Zygote push forward unable to differentiate through generic broadcast

Hi @ChrisRackauckas, I tried ForwardDiff over Zygote on a simpler net. I noticed that reverse (Zygote) over forward (Zygote) is significantly faster than ForwardDiff over Zygote (reverse). This is most likely because ForwardDiff now has to compute a one-to-many set of derivatives, one partial for every entry of θ. Please see the MWE below, along with the execution times. Maybe I misunderstood your message (most likely) :)

using Zygote
using ForwardDiff
using Statistics
using Flux

X = randn(2, 100)
data_ones = ones(2,100)

fc1 = Dense(2, 10, tanh)
fc2 = Dense(10, 10, tanh)
fc3 = Dense(10, 2, identity)

model = Chain(fc1, fc2, fc3)
θ, re = Flux.destructure(model)

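# Forward-over-reverse: the inner input gradient is computed with Zygote (reverse mode),
# and the outer gradient w.r.t. θ is taken with ForwardDiff.
function loss_FD_over_Z(x, θ)
    network = re(θ)  # rebuild the model from the flat parameter vector
    ∂o = Zygote.gradient(p -> sum(network(p)), x)[1]  # ∂(sum of outputs)/∂x, same shape as x
    mean((cos.(x) .- ∂o).^2)
end
loss_FD_over_Z(X, θ)  # evaluate the loss once before timing the gradient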

@time ∇θ = ForwardDiff.gradient(p -> loss_FD_over_Z(X, p), θ)
  4.893996 seconds (21.12 M allocations: 1.006 GiB, 4.20% gc time, 99.46% compilation time)

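# Reverse-over-forward: the inner derivative is a Zygote pushforward (forward mode)
# along the tangent data_ones, and the outer gradient is taken with Zygote (reverse mode).
function loss_Z_over_Z(x)
    ∂o = Zygote.pushforward(x -> model(x), x)(data_ones)
    mean((cos.(x) .- ∂o).^2)
end
loss_Z_over_Z(X)  # takes only x; the parameters come from the global `model`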

@time ∇θ = Zygote.gradient(() -> loss_Z_over_Z(X), Zygote.Params(model))
  0.896909 seconds (43.72 k allocations: 3.084 MiB, 99.60% compilation time)
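
To make the "one-to-many" point concrete, here is a toy sketch of why a forward-mode gradient gets more expensive as the number of parameters grows. It is not how ForwardDiff is implemented (ForwardDiff carries several partials per pass in chunks); the `Dual` type, `toy_loss`, and `forward_gradient` are hypothetical names made up for this illustration. The length 162 is chosen to match the number of parameters in the net above (30 + 110 + 22).

```julia
# Minimal dual number: a value plus ONE directional derivative.
struct Dual
    val::Float64
    der::Float64
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)

# Hypothetical scalar loss sum(sin(t) * t); the counter records how often it runs.
const ncalls = Ref(0)
function toy_loss(θ)
    ncalls[] += 1
    acc = Dual(0.0, 0.0)
    for t in θ
        acc = acc + sin(t) * t
    end
    return acc
end

# Forward-mode gradient with a single partial per pass:
# seed one parameter direction at a time, so n parameters need n evaluations.
function forward_gradient(f, θ::Vector{Float64})
    n = length(θ)
    g = zeros(n)
    for i in 1:n
        seeded = [Dual(θ[j], j == i ? 1.0 : 0.0) for j in 1:n]
        g[i] = f(seeded).der
    end
    return g
end

θ_toy = collect(range(0.1, 1.0; length = 162))
g = forward_gradient(toy_loss, θ_toy)
println("loss evaluations: ", ncalls[])
```

In this sketch the loss is evaluated once per parameter, whereas a reverse-mode outer gradient would obtain all entries from a single backward pass. Chunking reduces the number of passes but each pass then carries more partials, so the total work still grows with the length of θ.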