Using Enzyme.jl with Flux: Issues Computing Gradients of a Model with Duplicated Parameters and Mixed Forward/Reverse AD

Hi everyone,

I’m trying to use Enzyme.jl to compute gradients for a Flux model during training. My goal is to differentiate a custom loss function that combines a regular mean squared error term with an additional gradient-based force term computed via Enzyme’s forward mode AD.

However, I’m running into errors related to how Enzyme treats my variables (constant or active).
Here’s the error I get:

Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.

I tried to solve the issue by applying suggestion b), but I can’t make it work.

Here is a minimal example of my setup:

using Flux, Enzyme
using Statistics: mean  # used below to average the loss over the dataset

# Define a simple feedforward neural network model with two Dense layers
model = Chain(
    Dense(3 => 2, sigmoid),  # Input layer: 3 features to 2 neurons with sigmoid activation
    Dense(2 => 1)            # Output layer: 2 neurons to 1 output (regression)
)

ndata = 20  # Number of data points

# Generate random training data (input features and targets)
x_data = randn(Float32, 3, ndata)  # Input matrix: 3 features × 20 samples
y_data = randn(Float32, 1, ndata)  # Output vector: 1 target × 20 samples

# Pick first sample for initial testing
x = x_data[:, 1]
y = y_data[1]

opt = Descent(0.01)  # Gradient descent optimizer with learning rate 0.01

# Define loss function accepting model parameters, input x, and target y
function loss(params, x, y)
    n = length(y)  # Number of target elements (usually 1 here)
    println(typeof(params))  # Debug print to show parameter type

    model = re(params)  # Reconstruct model from flattened parameter vector

    # Mean squared error loss function
    energy_loss(x) = sum((model(x) .- y).^2) / n

    # Compute "forces" as the gradient of the energy loss w.r.t. the input x,
    # using Enzyme forward-mode AD (divided by 3 and n as normalization factors
    # specific to my use case)
    forces = Enzyme.gradient(Forward, (m, x) -> m(x), Const(energy_loss), x)[2] / 3 / n

    # Return combined loss: energy loss plus a term involving x and forces
    return energy_loss(x) + sum(x .* forces)
end

# Destructure model into a parameter vector `ps` and a reconstruction function `re`
ps, re = Flux.destructure(model)

println("Initial loss:", loss(ps, x, y))  # Print initial loss for sanity check

# Training loop for 20 epochs
for epoch in 1:20
    for i in 1:ndata
        x = x_data[:, i]
        y = y_data[i]

        # Compute gradients of loss w.r.t. parameters using Enzyme Reverse AD
        grads = Enzyme.gradient(
            Reverse,
            (p, x, y) -> loss(p, x, y),
            ps,  # Parameter vector marked as active (duplicated for forward and backward)
            Const(x),             # Input x treated as constant (not differentiated)
            Const(y)              # Target y treated as constant
        )[1]

        println(grads)  # Print gradients for debugging

        # NOTE: Temporary exit after first gradient calculation for debug
        # Remove this exit when ready to train fully
        println("Stopping early for debugging.")
        exit(0)

        # Update parameters with computed gradients using Flux optimizer
        Flux.Optimise.update!(opt, ps, grads)
    end

    # Compute average loss over all training samples (optional, for monitoring)
    l = mean(loss(ps, x_data[:, i], y_data[i]) for i in 1:ndata)
    println("Epoch $epoch: loss = $l")
end

Thank you in advance to anyone that can help me!

If you set runtime activity on, like the error suggests, does it work?

i.e., something like this:

# Compute gradients of loss w.r.t. parameters using Enzyme Reverse AD,
# with runtime activity enabled as the error message suggests
grads = Enzyme.gradient(
    Enzyme.set_runtime_activity(Reverse),
    (p, x, y) -> loss(p, x, y),
    ps,        # Parameter vector marked as active (duplicated for forward and backward)
    Const(x),  # Input x treated as constant (not differentiated)
    Const(y)   # Target y treated as constant
)[1]

Oh, you’re taking a second derivative. We should fix this, but in this case (higher order, and also ML), I’d really recommend putting all of this in a Reactant compile. That will remove the need for runtime activity, avoid the higher-order error you see, and generally run a lot faster in these kinds of cases.
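
Roughly, the pattern would be something like this (a sketch from memory, not tested against your exact loss; I’m assuming an explicit-parameter loss(ps, x, y) like the one you already have):

using Reactant, Enzyme

# Sketch only: wrap the gradient computation in a function and compile it once with Reactant.
# `loss(ps, x, y)` is assumed to be the scalar loss defined earlier in this thread.
getgrad(ps, x, y) = Enzyme.gradient(Reverse, loss, ps, Const(x), Const(y))[1]

ps_ra = Reactant.to_rarray(ps)       # move parameters and data to Reactant arrays
x_ra  = Reactant.to_rarray(x_data)
y_ra  = Reactant.to_rarray(y_data)

getgrad_compiled = @compile getgrad(ps_ra, x_ra, y_ra)  # compile for these argument types
grads = getgrad_compiled(ps_ra, x_ra, y_ra)             # cheap to call repeatedly afterwards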

The easiest fix is to change to Lux and run it with Reactant. That’s the low-effort approach and it will also make things a lot faster. There is a nice page for this too:

and most of this is automated with FromFluxAdaptor if you’re lazy:

We set this up for NeuralPDE/NeuralOperators and found that auto-converting from Flux to Lux fixed most of the AD performance issues and bugs, so our standard approach in those libraries is now to always convert to Lux. This is particularly important for higher-order derivatives, as found in applications like PINNs, and it looks like you have some kind of physics-informed loss here with the same behavior.
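
For reference, the conversion itself is roughly this (a sketch from memory, so double-check the Lux docs; I believe the adaptor is spelled FromFluxAdaptor and the keyword names may differ):

using Adapt, Flux, Lux, Random

# Sketch: convert an existing Flux Chain into a Lux model.
flux_model = Flux.Chain(Flux.Dense(3 => 2, sigmoid), Flux.Dense(2 => 1))

# preserve_ps_st = true keeps the Flux weights instead of re-initialising them.
lux_model = adapt(FromFluxAdaptor(preserve_ps_st = true), flux_model)

# From here on it is the usual explicit-parameter Lux workflow.
rng = Random.default_rng()
ps, st = Lux.setup(rng, lux_model)
y, _ = lux_model(randn(Float32, 3, 4), ps, st)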

Hello, just to say that I went through this thread and tried it with Reactant.jl; here is the code:

using Reactant,Lux,Random,Optimisers,Enzyme

dev = reactant_device()

model = Chain(
    Dense(3,16, relu),
    Dense(16,16, relu),
    Dense(16,1)
)
rng = Random.default_rng()
Random.seed!(rng, 1234)
ps,st = Lux.setup(rng, model) |> dev

ndata = 20
T = Float32

x_data = randn(T,3,ndata) |> dev
y_data = randn(T,1,ndata) |> dev
opt = Optimisers.setup(Adam(0.01f0),ps)

function energy_loss(x,model,ps,st,y)
    r = model(x,ps,st)[1]
    return sum((r .- y).^2)
end

function loss(ps,model,st,x,y)
    n = length(y)
    g = Enzyme.gradient(Reverse,energy_loss,x,Const(model),Const(ps),Const(st),Const(y))[1]
    forces = g ./ n ./ 3
    return energy_loss(x,model,ps,st,y) + sum(x .* forces)
end

function getgrad(model,ps,st,x,y)
    dps = Enzyme.gradient(Reverse,loss,ps,Const(model),Const(st),Const(x),Const(y))[1]
    return dps
end
get_grad = @compile getgrad(model,ps,st,x_data,y_data)
get_grad(model,ps,st,x_data,y_data)
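
For completeness, I think wiring the compiled gradient into the optimiser I set up above would look roughly like this (untested sketch using the same names as above; in a real run you would probably want to compile the update step as well):

# Sketch: plain training loop reusing the compiled gradient function.
for epoch in 1:100
    global opt, ps                               # needed when running as a script
    dps = get_grad(model, ps, st, x_data, y_data)
    opt, ps = Optimisers.update!(opt, ps, dps)   # Optimisers returns (new state, new params)
end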

However, I have no idea whether the gradient I get is correct, since I see this precision warning and the differences seem large:

E0000 00:00:1753734658.998318    5108 buffer_comparator.cc:147] Difference at 16: 1.77842, expected 6.94062
E0000 00:00:1753734658.998385    5108 buffer_comparator.cc:147] Difference at 17: 1.22021, expected 6.39888
E0000 00:00:1753734658.998391    5108 buffer_comparator.cc:147] Difference at 18: 0.833048, expected 5.19233
E0000 00:00:1753734658.998393    5108 buffer_comparator.cc:147] Difference at 19: 1.00677, expected 5.6695
E0000 00:00:1753734658.998395    5108 buffer_comparator.cc:147] Difference at 20: 0.743891, expected 5.56046
E0000 00:00:1753734658.998410    5108 buffer_comparator.cc:147] Difference at 21: 1.79193, expected 5.33211
E0000 00:00:1753734658.998414    5108 buffer_comparator.cc:147] Difference at 22: 1.36404, expected 5.33454
E0000 00:00:1753734658.998419    5108 buffer_comparator.cc:147] Difference at 23: 0.880308, expected 5.57699
E0000 00:00:1753734658.998423    5108 buffer_comparator.cc:147] Difference at 24: 1.50643, expected 4.88717
E0000 00:00:1753734658.998425    5108 buffer_comparator.cc:147] Difference at 25: 1.54981, expected 5.79724

Also, this does not happen with Reactant.set_default_backend("cpu"), and the gradients differ a lot between the two backends. I thought using Forward mode for the inner gradient was the right thing to do, but it just crashes.

EDIT: it’s a Float32 issue; everything is fine with Float64. No idea why it works on CPU.

Oh, that’s weird, though I guess GPUs are also weird. Open an issue? (and cc @avikpal)

This is the buffer comparator issue between Triton and cuBLAS. Inside Lux, I set the precision to HIGH to avoid this (from my benchmarks it doesn’t really affect performance): Lux.jl/ext/LuxReactantExt/training.jl at e479795af82f9ca03422cc5bfc161c95fcedaec3 · LuxDL/Lux.jl · GitHub

You mean Reactant.jl? I’ve never used it. Should it work with Flux, or should I switch to Lux as the response below suggests?

You may hit weird bugs with Flux for now, I think, though it’s being worked on; you can try, though.

Lux very aggressively tests all (or at least most) of its layers with Reactant.

Most of the PRs made to Flux for Reactant testing/integration ("reinstate reactant's tests" by CarloLucibello, Pull Request #2609, FluxML/Flux.jl; "Enable other reactant tests" by wsmoses, Pull Request #2600, FluxML/Flux.jl; "Support for Reactant.jl" by mcabbott, Pull Request #28, FluxML/Fluxperimental.jl) have been stalled for ages now. I am not sure anyone is actively working on Flux integration (I could be wrong here)…

For the most part, I think the issues with Flux + Reactant are that the existing Flux tests are wrong and can’t test Reactant results (more specifically, the check_recursive utility assumes Zygote-style results of nothing instead of Enzyme-style shadows), and/or things that would improve ease of use rather than functionality. So if something goes wrong, definitely open an issue!

That said, I’d also recommend using Lux here; the team is really active and nice, and explicitly tests all the pieces against Reactant (in other words, @avikpal, you’re the real MVP).