Lux, ComponentArrays and flat parameters: computing the gradient works with Zygote but not with Enzyme

Hi all, I'm quite new (on and off) to Julia and I'm trying to play with it more consistently.
I have an issue with AD: I can compute the gradient with respect to the weights of a very simple neural network with Zygote, but it fails with Enzyme. I do not have the same problem when I compute the gradient with respect to the input of the neural network. I guess this is related to ComponentArrays.
Would you have any clue about that? Thanks.

Running the code below, I get the following error:

MethodError: no method matching EnzymeCore.Duplicated(::Vector{Float32}, ::Vector{Float64})
Closest candidates are:
EnzymeCore.Duplicated(::T, !Matched::T) where T at ~/data/.julia/packages/EnzymeCore/cwKkU/src/EnzymeCore.jl:64

Here is the code:

using Lux, ComponentArrays, Random, Enzyme
import Zygote

rng = Random.MersenneTwister(1234)

# Define a basic neural network structure

NN = Lux.Chain( Lux.Dense(5 => 5, tanh),
Lux.Dense(5 => 1) )

# Setup the network

ps, st = Lux.setup(rng, NN)

# Test the initialized network with some input values

xtest = [0.1, 0.2, 0.3, 0.4, 0.5]
NN(xtest, ps, st)[1][1]

# No problem to compute the gradient with respect to the network input with Enzyme...

dx = zeros(size(xtest)[1])
autodiff(Reverse, xt -> NN(xt, ps, st)[1][1], Active, Duplicated(xtest, dx))
dx

#...or Zygote
Zygote.gradient(xt -> NN(xt, ps, st)[1][1], xtest)

############################################################

# Now let's try to differentiate wrt the weights of the neural network

ax_test = getaxes( ComponentArray(ps) )
θ_test = getdata( ComponentArray(ps) )

# That works fine with Zygote...

Zygote.gradient( θ -> NN(xtest, ComponentArray(θ, ax_test), st)[1][1], θ_test )

# ...but not with Enzyme

dθ = zeros(size(θ_test))
autodiff(Reverse, θ -> NN(xtest, ComponentArray(θ, ax_init), st)[1][1], Active, Duplicated(θ_test, dθ))

From the error message MethodError: no method matching EnzymeCore.Duplicated(::Vector{Float32}, ::Vector{Float64}) it looks like you're trying to construct a Duplicated with a primal of type Vector{Float32} and a shadow of type Vector{Float64}. The primal and shadow must have the same type.
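
For example, a minimal sketch using the names from your snippet (the shadow only has to match the primal's element type and shape):

# Sketch: build the shadow from the primal so the types always match.
dθ = zero(θ_test)        # Vector{Float32} of zeros, same length as θ_test
Duplicated(θ_test, dθ)   # primal and shadow are now both Vector{Float32}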

What are the types of θ_test and dθ?

Also, you should not pass a capturing closure to autodiff; instead, pass the additional arguments explicitly. E.g.

autodiff(Reverse, (NN, xtest, θ, ax_init, st) → NN(xtest, ComponentArray(θ, ax_init), st)[1][1], Active, Const(NN), Const(xtest), Duplicated(θ_test, dθ), Const(ax_init), Const(st))

You were right: θ_test is a Vector{Float32} (I guess that's the default type for weights in Lux), while I defined dθ with the zeros function, which returns a Vector{Float64} by default.
However, defining
dθ = zeros(Float32, size(θ_test))

and using the code you suggested:
autodiff(Reverse, (NN, xtest, θ, ax_init, st) → NN(xtest, ComponentArray(θ, ax_init), st)[1][1], Active, Const(NN), Const(xtest), Duplicated(θ_test, dθ), Const(ax_init), Const(st))

now gives the following error:
UndefVarError: θ not defined

That’s quite cryptic to me…

The theta Unicode symbol seems to have two different glyphs in the snippet (sorry, I wrote my response on my phone). Can you write theta in ASCII instead of Unicode, just to confirm?

Thank you for your reply. And sorry to all for the wrong code formatting in my previous messages.

  1. There was an error in my code, apologies: the ax_init variable doesn't exist, it's ax_test. So I don't understand the error message about θ that was returned.

  2. My initial code autodiff(Reverse, θ -> NN(xtest, ComponentArray(θ, ax_test), st)[1][1], Active, Duplicated(θ_test, dθ)) runs and gives me the same value as Zygote. However, you advised not to compute the gradient that way and, more problematic, every time I change xtest the line has to compile again, which takes ages…

  3. The code you suggested (corrected with ax_test instead of the wrong variable name ax_init): autodiff(Reverse, (NN, xtest, θ, ax_test, st) → NN(xtest, ComponentArray(θ, ax_test), st)[1][1], Active, Const(NN), Const(xtest), Duplicated(θ_test, dθ), Const(ax_test), Const(st)) still gives me the following error: UndefVarError: θ not defined

If you have any clue about that, thanks a lot!

Mind posting an entire runnable test.jl file? Also, did you try the Unicode renaming (e.g. replacing all θ with theta)? I'm still wondering if it's a Unicode copy-paste bug that creates two different but similar-looking symbols.

Re 2, that's great! The latter form is better because variables will probably not be captured in a type-unstable way, which can be significant for performance, and also because we're not yet feature-complete on Julia's type-unstable code (you'd get a very loud error if you hit something unsupported). The constant recompile time is odd, since compilation should be cached, but looking at the whole file will let me give better insight into it.
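
As a side note, another option (just a sketch, not from this thread, and const redefinitions can be awkward in Pluto) is to make the globals the closure refers to const, so the references inside it are concretely typed:

# Sketch: const globals give the closure body concrete types to work with.
const NN_c = NN
const x_c  = xtest
const ax_c = ax_test
const st_c = st

autodiff(Reverse, θ -> NN_c(x_c, ComponentArray(θ, ax_c), st_c)[1][1],
         Active, Duplicated(θ_test, dθ))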

Sure, here is the code:

using Lux, ComponentArrays, Random, Enzyme
import Zygote

rng = Random.MersenneTwister(1234)

# Define a basic neural network structure
NN = Lux.Chain( Lux.Dense(5 => 5, tanh),
				Lux.Dense(5 => 1) )

# Setup the network
ps, st = Lux.setup(rng, NN)

# Test the initialized network with some input values
xtest = [0.1, 0.2, 0.3, 0.4, 0.5]
NN(xtest, ps, st)[1][1]

# No problem to compute the gradient with respect to the network input with Enzyme...
dx = zeros(size(xtest)[1])
autodiff(Reverse, xt -> NN(xt, ps, st)[1][1], Active, Duplicated(xtest, dx))
dx

#...or Zygote
Zygote.gradient(xt -> NN(xt, ps, st)[1][1], xtest)

############################################################
# Now let's try to differentiate wrt the weights of the neural network

ax_test = getaxes( ComponentArray(ps) )
θ_test = getdata( ComponentArray(ps) )
 
# No problem to differentiate with Zygote
Zygote.gradient( θ -> NN(xtest, ComponentArray(θ, ax_test), st)[1][1], θ_test )

# Works with Enzyme.autodiff with the following code
dθ = zeros(Float32 , size(θ_test))
autodiff(Reverse, θ -> NN(xtest .+ 0.01, ComponentArray(θ, ax_test), st)[1][1], Active, Duplicated(θ_test, dθ))
dθ

# But returns an error with this one
theta_test = copy(θ_test)
dtheta = copy(dθ) 
	
autodiff(Reverse, (NN, xtest, theta, ax_test, st) → NN(xtest, ComponentArray(theta, ax_test), st)[1][1], Active, Const(NN), Const(xtest), Duplicated(theta_test, dtheta), Const(ax_test), Const(st))

The last instruction returns :

UndefVarError: theta not defined

A few remarks:

  • as you can see, the problem doesn’t seem to be the character θ
  • I’m using Pluto.jl on Juliahub (very convenient by the way)

Hope this will help.

That's right, the problem is the arrow character in the anonymous function: it should be -> (like in the lines above), but it is → in the last line. Since → is not the anonymous-function arrow, theta gets treated as an ordinary (undefined) global variable, hence the UndefVarError. Pretty sneaky :sweat_smile:
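
A quick way to confirm which character ended up in the file (a small sketch, not from the original posts) is to inspect its code point:

julia> codepoint('→')   # U+2192 RIGHTWARDS ARROW, not the two ASCII characters ->
0x00002192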

I do, however, get an Enzyme execution failed. error with this line (on Julia 1.8.5 and Enzyme 0.11.6):

autodiff(Reverse, θ -> NN(xtest .+ 0.01, ComponentArray(θ, ax_test), st)[1][1], Active, Duplicated(θ_test, dθ))

What’s the full error?

Enzyme execution failed just means that it caught an exception (and then should print the actual error type and message below).

Doing a quick local test, your code (with only one parameter passed to the closure) captures the variables xtest, ax_test, st, etc. in a type-unstable way, which leads to the currently-unsupported error.

If you pass them in as additional args rather than capturing them in the closure, as in my example above (but now with the arrow fixed), it seems to work:

julia> autodiff(Reverse, (NN, xtest, theta, ax_test, st) -> NN(xtest, ComponentArray(theta, ax_test), st)[1][1], Active, Const(NN), Const(xtest), Duplicated(theta_test, dtheta), Const(ax_test), Const(st))

((nothing, nothing, nothing, nothing, nothing),)

I came across this discussion while trying to efficiently compute the gradient of a Lux.jl NN with respect to its inputs. I noticed that all the examples mentioned here work, both with Zygote.jl and Enzyme.jl, but the running time to compute the gradient seems very slow considering that we are only differentiating a small neural network.

Here is the example from @mothmani:

using Lux, ComponentArrays, Random, Enzyme
import Zygote
using ReverseDiff
using Base: @time

rng = Random.MersenneTwister(1234)

# Define a basic neural network structure
NN = Lux.Chain( Lux.Dense(5 => 5, tanh),
				Lux.Dense(5 => 1) )

# Setup the network
ps, st = Lux.setup(rng, NN)

# Test the initialized network with some input values
xtest = [0.1, 0.2, 0.3, 0.4, 0.5]
NN(xtest, ps, st)[1][1]

# No problem to compute the gradient with respect to the network input with Enzyme...
dx = zeros(size(xtest)[1])
@time autodiff(Reverse, xt -> NN(xt, ps, st)[1][1], Active, Duplicated(xtest, dx))
dx

#...or Zygote
@time Zygote.gradient(xt -> NN(xt, ps, st)[1][1], xtest)

The example with Zygote takes ~0.20s while Enzyme does it in ~0.10s.

I have a use case where I need to compute the gradient of the NN with respect to the input layer multiple times per training epoch, which leads to very long running times. Is there any way to make this code run faster and/or compute the gradient with respect to multiple inputs at the same time? I wonder if one of the tricks that @wsmoses mentioned here may work.

Timing the code in global scope like this triggers recompilation of a new xt -> NN(xt, ps, st) closure every time. Your @time output should show that most of the time is compilation.

@time Zygote.gradient(sum ∘ first ∘ NN, xtest, ps, st)
# 0.000220 seconds (107 allocations: 8.172 KiB)
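
For the Enzyme side, a similar pattern should avoid the per-call compilation; a minimal sketch under the same setup (loss here is a hypothetical top-level helper, not from the thread):

# Sketch: a named top-level loss is reused across calls, so changing the input
# does not define (and compile) a new anonymous function each time.
loss(model, x, ps, st) = sum(first(model(x, ps, st)))

dx = zero(xtest)
@time autodiff(Reverse, loss, Active,
               Const(NN), Duplicated(xtest, dx), Const(ps), Const(st))

For steady-state numbers, BenchmarkTools.@btime on such a top-level function would exclude compilation from the measurement entirely.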