Hi everyone,
I’m trying to use Enzyme.jl to compute gradients for a Flux model during training. My goal is to differentiate a custom loss function that combines a regular mean squared error term with an additional gradient-based force term computed via Enzyme’s forward mode AD.
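For context, the force term on its own is computed roughly like this (a simplified, self-contained sketch with a stand-in scalar function energy; the full example is further down):

using Enzyme

energy(x) = sum(x .^ 2)                          # stand-in for the scalar energy term
x = randn(Float32, 3)
forces = Enzyme.gradient(Forward, energy, x)[1]  # gradient of energy w.r.t. x via forward-mode AD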
However, I’m running into errors related to how Enzyme treats my variables (constant or active).
Here’s the error I get:
Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
I tried to solve the issue by applying suggestion b), but I can’t make it work.
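Concretely, this is roughly what I tried (the same call as in the training loop below, just with the mode wrapped in set_runtime_activity, as the error message suggests; I may well be using it incorrectly):

grads = Enzyme.gradient(
    set_runtime_activity(Reverse),  # turn on runtime activity, per suggestion b)
    (p, x, y) -> loss(p, x, y),
    ps,
    Const(x),
    Const(y)
)[1]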
Here is a minimal example of my setup:
using Flux, Enzyme, Statistics # Statistics provides `mean` for the monitoring step
# Define a simple feedforward neural network model with two Dense layers
model = Chain(
    Dense(3 => 2, sigmoid), # Input layer: 3 features to 2 neurons with sigmoid activation
    Dense(2 => 1)           # Output layer: 2 neurons to 1 output (regression)
)
ndata = 20 # Number of data points
# Generate random training data (input features and targets)
x_data = randn(Float32, 3, ndata) # Input matrix: 3 features × 20 samples
y_data = randn(Float32, 1, ndata) # Output vector: 1 target × 20 samples
# Pick first sample for initial testing
x = x_data[:, 1]
y = y_data[1]
opt = Descent(0.01) # Gradient descent optimizer with learning rate 0.01
# Define loss function accepting model parameters, input x, and target y
function loss(params, x, y)
    n = length(y)           # Number of target elements (usually 1 here)
    println(typeof(params)) # Debug print to show parameter type
    model = re(params)      # Reconstruct model from flattened parameter vector
    # Mean squared error between model output and target
    energy_loss(x) = sum((model(x) .- y).^2) / n
    # Compute "forces" as the gradient of the energy loss w.r.t. the input x using Enzyme forward-mode AD
    # Divided by 3 and n as normalization factors (specific to my use case)
    forces = Enzyme.gradient(Forward, (m, x) -> m(x), Const(energy_loss), x)[2] / 3 / n
    # Return combined loss: energy loss plus a term coupling x and the forces
    return energy_loss(x) + sum(x .* forces)
end
# Destructure model into a parameter vector `ps` and a reconstruction function `re`
ps, re = Flux.destructure(model)
println("Initial loss:", loss(ps, x, y)) # Print initial loss for sanity check
# Training loop for 20 epochs
for epoch in 1:20
    for i in 1:ndata
        x = x_data[:, i]
        y = y_data[i]
        # Compute gradients of loss w.r.t. parameters using Enzyme reverse-mode AD
        grads = Enzyme.gradient(
            Reverse,
            (p, x, y) -> loss(p, x, y),
            ps,       # Parameter vector treated as active (differentiated)
            Const(x), # Input x treated as constant (not differentiated)
            Const(y)  # Target y treated as constant
        )[1]
        println(grads) # Print gradients for debugging
        # NOTE: Temporary exit after first gradient calculation for debugging
        # Remove this exit when ready to train fully
        println("Stopping early for debugging.")
        exit(0)
        # Update parameters with the computed gradients using the Flux optimizer
        Flux.Optimise.update!(opt, ps, grads)
    end
    # Compute average loss over all training samples (optional, for monitoring)
    l = mean(loss(ps, x_data[:, i], y_data[i]) for i in 1:ndata)
    println("Epoch $epoch: loss = $l")
end
Thank you in advance to anyone who can help!