Zygote pushforward unable to differentiate through generic broadcast

Good day all,
I am struggling a bit with something that I hope is simple. I would like to use Zygote.pushforward to AD neural network outputs with respect to inputs. For standard fully connected nets everything works fine, but as soon as I use a modified network with residual connections I get the following error message:
Generic broadcast of * not supported yet
Is there any way to get around this? Below is an MWE. Any help will be appreciated.

using Zygote
using Statistics
using Flux

data = randn(2,10)
data_ones = ones(2,10)

fc1 = Dense(2,50)
fc2 = Dense(50,50)
fc3 = Dense(50,6)
fc_U = Dense(2,50)

function model(x)
    U = fc_U(x)
    h = fc1(x)
    h = fc2(h) .* U   # elementwise product with the second branch; this broadcast triggers the error
    h = fc3(h)
    h
end

test = model(data)
grad_fun = Zygote.pushforward(model, data)   # returns a closure expecting an input tangent
grad_fun(data_ones)                          # errors: Generic broadcast of * not supported yet

I believe this is a broader issue in Zygote, and it affects downstream packages as well; for example, NeuralPDE.jl falls back to numerical derivatives for its interior loss calculations (see: question about pinn solver · Issue #283 · SciML/NeuralPDE.jl · GitHub). It would be great if this were resolved soon, since Python DL frameworks handle this case.

Here’s another MWE that generates a “Generic broadcast of * not supported yet” error.

using Flux
using Zygote

m = Dense(5, 1; bias=false)
p = rand(Float32, 5, 1)

function diaghessian(m, x)
    ∇m(x) = Zygote.pushforward(m, x)(1)       # first derivative via a forward-mode JVP
    Zygote.pushforward(x -> ∇m(x)[1], x)(1)   # forward over forward for the Hessian diagonal
end

diaghessian(m, p)

Zygote’s forward mode is, in my understanding, an abandoned prototype.

The second one can be done with Zygote.diaghessian(sum∘m, p)[1], although the result is zero since the function is linear. This uses ForwardDiff over Zygote.
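For reference, here is a minimal sketch of that call, re-declaring the arrays from the MWE above:

using Flux
using Zygote

m = Dense(5, 1; bias=false)
p = rand(Float32, 5, 1)

# diagonal of the Hessian of x -> sum(m(x)); all zeros here because m is linear in p
Zygote.diaghessian(sum ∘ m, p)[1]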

Hi all, thanks for the comments.
This is quite unfortunate. The only way I could get a PINN working using CUDA and Zygote was to use forward mode AD (pushforward) for the network output-input gradients and reverse mode for the parameter gradients. I will have to wait for Diffractor.jl and hope it sorts out the mentioned issue. In the meantime I will revert to JAX in Python.
Thanks again.

The forward AD which actually works is ForwardDiff, and Zygote uses this internally for many things. Whether it can be made to work for your problem is hard to guess; Dual numbers do work with CUDA, but probably not for everything. It’s not so clear from your example why reverse mode won’t work: by default Zygote computes gradients for everything, including the data, not just the parameters.
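As a minimal sketch of that last point (the names here are illustrative, not from the thread):

using Flux
using Zygote

net = Dense(2, 3)
x = randn(Float32, 2, 5)

# one reverse-mode call returns gradients w.r.t. both the layer and the data
g_net, g_x = Zygote.gradient((m, x) -> sum(m(x)), net, x)
size(g_x) == size(x)   # true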

Hi @mcabbott, I have used ForwardDiff for the internal gradients before (as shown below), but I always end up with the “Mutating arrays not supported” error. The MWE below works perfectly with Zygote.pushforward, but with pushforward I can’t generically broadcast the network layer operations (as shown in my initial example).

using Zygote
using ForwardDiff
using Statistics

X = randn(2, 100)

W1 = randn(10, 2)
b1 = zeros(10)
W2 = randn(10, 10)
b2 = zeros(10)
W3 = randn(2, 10)
b3 = zeros(2)

θ = Zygote.Params([W1, b1, W2, b2, W3, b3])

function model(x)
    h = tanh.(W1 * x .+ b1)
    h = tanh.(W2 * h .+ b2)
    h = W3 * h .+ b3
    h
end

function loss(x)
    # inner derivative: gradient of the summed network output w.r.t. the input, via ForwardDiff
    ∂o = ForwardDiff.gradient(x -> sum(model(x)), x)
    mean((cos.(x) .- ∂o).^2)
end

∇θ = Zygote.gradient(() -> loss(X), θ)

This code runs, but the use of ForwardDiff.gradient(f, x) within Zygote needs a giant warning sign: it does not track gradients with respect to f (i.e. the parameters closed over inside it), and thus you get nothing:

julia> function loss(x)
           ∂o = ForwardDiff.gradient(x -> sum(model(x)), x)
           mean((cos.(x) .- ∂o).^2)
       end
loss (generic function with 1 method)

julia> ∇θ = Zygote.gradient(() -> loss(X), θ)
Grads(...)

julia> ∇θ[W1] == nothing
true

I had plans somewhere to automate the warning, but I don’t think this can be made to track the result. It is a second derivative, and the second derivatives that work best are ForwardDiff over Zygote (e.g. this is how Zygote.hessian works). But to do that, you cannot use implicit Params, so you would have to re-organise things to start with an explicit parameter vector (or several).
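For illustration, Zygote.hessian on an explicit argument is exactly this forward-over-reverse pattern (a small sketch, not from the thread):

using Zygote

f(v) = sum(abs2, v) * sum(v)
Zygote.hessian(f, rand(3))   # forward mode (ForwardDiff) over reverse mode (Zygote)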

You can also experiment with reverse over reverse. The simplest idea doesn’t work; it’s possible that avoiding implicit parameters would improve this, and it’s possible that using, say, ReverseDiff + Zygote may improve things.

julia> function loss(x)
           ∂o = sum(gradient(x -> sum(model(x)), x))
           mean((cos.(x) .- ∂o).^2)
       end
loss (generic function with 1 method)

julia> ∇θ = Zygote.gradient(() -> loss(X), θ)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0

Thanks for the detailed response @mcabbott, I will try the ReverseDiff over Zygote strategy you proposed.

Reverse over reverse doesn’t even make sense asymptotically in terms of flops. So not only does forward over reverse work, but it’s also going to be asymptotically faster. It’s likely the thing to use.

Hi @ChrisRackauckas, I tried ForwardDiff over Zygote on a simpler net. I noted that reverse (Zygote) over forward (Zygote) is significantly faster than ForwardDiff over Zygote (reverse); this is most likely because ForwardDiff now has to compute a one-to-many set of gradients, sweeping forward once per chunk of the parameter vector. Please see the MWE below along with the execution times. Maybe I misunderstood your message… most likely 🙂

using Zygote
using ForwardDiff
using Statistics
using Flux

X = randn(2, 100)
data_ones = ones(2,100)

fc1 = Dense(2, 10, tanh)
fc2 = Dense(10, 10, tanh)
fc3 = Dense(10, 2, identity)

model = Chain(fc1, fc2, fc3)
θ, re = Flux.destructure(model)

function loss_FD_over_Z(x, θ)
    network = re(θ)
    # reverse mode (Zygote) inside; ForwardDiff differentiates this w.r.t. θ below
    ∂o = Zygote.gradient(p -> sum(network(p)), x)[1]
    mean((cos.(x) .- ∂o).^2)
end
loss_FD_over_Z(X, θ)

@time ∇θ = ForwardDiff.gradient(p -> loss_FD_over_Z(X, p), θ)
  4.893996 seconds (21.12 M allocations: 1.006 GiB, 4.20% gc time, 99.46% compilation time)

function loss_Z_over_Z(x)
    # forward mode (Zygote.pushforward) inside; reverse mode differentiates w.r.t. the parameters below
    ∂o = Zygote.pushforward(x -> model(x), x)(data_ones)
    mean((cos.(x) .- ∂o).^2)
end
loss_Z_over_Z(X)

@time ∇θ = Zygote.gradient(() -> loss_Z_over_Z(X), Flux.params(model))
  0.896909 seconds (43.72 k allocations: 3.084 MiB, 99.60% compilation time)

Reverse over forward should be fine as well. It’s really that double reverse is usually not a great idea.