Enzyme Autodiff readonly error and working with batches of data

I have the following MWE:

using Enzyme, Lux, Random
n = 10
x_batch = randn(2,n)
y_batch = randn(2,n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2,1,tanh)), Dense(2,1,tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);

function f(xb,yb)
    for k = 1 : n
        f1(x) = first(model((x, yb[:,k]), ps, st))[1]
        z = xb[:,k]
        dz = [0.0,0.0]
        Enzyme.autodiff(Enzyme.Reverse, f1, Active, Duplicated(z,dz))
    end
end

f(x_batch,y_batch) 

results in the following error:

Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)

I have read the Enzyme docs but cannot work out why this happens. Further, is this the best way to compute gradients w.r.t. one of the inputs over a batch?

Thank you

I have realized that the problem is the second argument to the network (y), which Enzyme is probably unable to prove readonly. The following code seems to work:

using Enzyme, Lux, Random
n = 10
x_batch = randn(2, n)
y_batch = randn(2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);

nnfunc(x,y) = first(model((x, y), ps, st))[1]

function f(xb, yb)
    for k = 1:n
        z = xb[:, k]
        dz = [0.0, 0.0]
        Enzyme.autodiff(Enzyme.Reverse, nnfunc, Active, Duplicated(z, dz), Duplicated(yb[:,k], zeros(2)))
        println(dz)
    end
end

f(x_batch, y_batch)

However, is this the best way to handle batch data?
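One possible pattern for the batch case, sketched under the assumption of a recent Enzyme.jl/Lux.jl setup (not something confirmed in this thread): instead of a closure, pass everything the network needs as explicit arguments, mark the ones you do not differentiate with Const, and collect the per-sample gradients into a matrix. The helper names `nnfunc_explicit` and `batch_input_grads` are hypothetical.

```julia
using Enzyme, Lux, Random

# Scalar network output for one sample. model, ps and st are explicit arguments,
# so the function object itself captures no data (hypothetical helper name).
nnfunc_explicit(x, y, model, ps, st) = first(model((x, y), ps, st))[1]

# d(output)/dx for every column of xb (hypothetical helper name).
function batch_input_grads(model, ps, st, xb, yb)
    G = zero(xb)                          # column k will hold the gradient for sample k
    for k in axes(xb, 2)
        z  = xb[:, k]
        dz = zero(z)                      # shadow: Enzyme accumulates the gradient here
        Enzyme.autodiff(Enzyme.Reverse, Const(nnfunc_explicit), Active,
                        Duplicated(z, dz),        # differentiate w.r.t. x
                        Const(yb[:, k]),          # y is an input we do not differentiate
                        Const(model), Const(ps), Const(st))
        G[:, k] .= dz
    end
    return G
end

model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
ps, st = Lux.setup(Xoshiro(0), model)
xb, yb = randn(2, 10), randn(2, 10)
G = batch_input_grads(model, ps, st, xb, yb)
```

Marking `y` as `Const` instead of `Duplicated(yb[:,k], zeros(2))` simply skips a derivative that is never used.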

This happens when the function captures some variables (a closure or callable struct).

julia> f! =let a=[1.0,2.0,3.0]
       function f!(x,y)
           x.=a.*y
           nothing
       end
       end
f! (generic function with 1 method)

julia> Enzyme.autodiff(Reverse,f!,Const,Duplicated(rand(3),rand(3)),Duplicated(rand(3),zeros(3)))
ERROR: Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)
See https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.
The potentially writing call is   store double %62, double addrspace(13)* %63, align 8, !dbg !201, !tbaa !189

This happens because the variables that f! captures require shadow variables too; otherwise you will get wrong results.

You can use make_zero(f!) to create a copy of f! whose captured variables are zeroed out, and pass it as the shadow of f!:

julia> Enzyme.autodiff(Reverse,Duplicated(f!,make_zero(f!)),Const,Duplicated(rand(3),rand(3)),Duplicated(rand(3),zeros(3)))
((nothing, nothing),)
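To see what the shadow of the function is for, here is a sketch using a callable struct in place of the `let` closure (the type name `ScaleInto` is hypothetical; a closure behaves the same way because its captured variables are stored as fields). With `Duplicated(f, Enzyme.make_zero(f))`, the derivative w.r.t. the captured array `a` accumulates in the shadow copy, while the derivative w.r.t. `y` accumulates in `dy`.

```julia
using Enzyme

# Callable struct that carries data, like the closure f! above (hypothetical name).
struct ScaleInto
    a::Vector{Float64}
end

function (s::ScaleInto)(x, y)
    x .= s.a .* y          # in place: writes into x
    return nothing
end

f  = ScaleInto([1.0, 2.0, 3.0])
df = Enzyme.make_zero(f)               # same structure as f, with df.a == zeros(3)

x, dx = zeros(3), ones(3)              # dx is the incoming adjoint, as if the loss were sum(x)
y, dy = [4.0, 5.0, 6.0], zeros(3)

Enzyme.autodiff(Reverse, Duplicated(f, df), Const, Duplicated(x, dx), Duplicated(y, dy))

dy      # accumulated a .* adjoint  -> [1.0, 2.0, 3.0]
df.a    # accumulated y .* adjoint  -> [4.0, 5.0, 6.0]
```

Marking `f` as `Const` here would silently drop the `df.a` contribution, which is why Enzyme refuses to guess.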

Edit: the second case with nnfunc works because the variables it uses (model, ps, st) are globals, which are handled differently.

I mean, in this specific case you don’t need to do that; you can just call something like:

using Enzyme, Lux, Random
n = 10
x_batch = randn(2,n)
y_batch = randn(2,n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2,1,tanh)), Dense(2,1,tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);

function f(xb,yb)
    for k = 1 : n
        f1(x) = first(model((x, yb[:,k]), ps, st))[1]
        z = xb[:,k]
        dz = [0.0,0.0]
        Enzyme.autodiff(Enzyme.Reverse, Const(f1), Active, Duplicated(z,dz))
    end
end

f(x_batch,y_batch) 

Note that the only change I made was making it Enzyme.autodiff(Enzyme.Reverse, Const(f1), Active, Duplicated(z,dz)), as suggested in the error message (and I didn’t run it myself).

The error message specifically says that the function argument passed to autodiff cannot be proven readonly. In this case that is f1, which is indeed a closure. Unfortunately, Enzyme's alias analysis isn’t strong enough to prove that the function is read-only (which it is here), so Enzyme raises that error and suggests marking it Const, as I did above, if it really is read-only. Alternatively, you can make it Duplicated, as @abitrandomuser showed above.
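A minimal, dependency-free sketch of the `Const` fix: the closure `g` (hypothetical) captures an array but only reads it, so wrapping it in `Const` is safe, and the gradient w.r.t. `x` comes out as the captured weights.

```julia
using Enzyme

w = [2.0, -1.0]
g = let w = w
    x -> sum(w .* x)       # closure: captures w but never writes to it
end

x, dx = [0.5, 3.0], zeros(2)
# Const(g): no derivative is taken w.r.t. the captured w, only w.r.t. x
Enzyme.autodiff(Reverse, Const(g), Active, Duplicated(x, dx))
dx                         # == w == [2.0, -1.0]
```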

Thank you. This works.

Thank you. Does “make_zero” make the arguments to the function, i.e., a, x, y, all zero? What does the first “Const” mean in this case? I thought Duplicated(x,dx) evaluates the derivative at x and adds it to dx (if dx is a variable), and that Duplicated(y,zeros(2)) does not calculate the derivative of y. Therefore, I am curious what Duplicated(rand(3),rand(3)) does in this case.

The function f1 itself could contain data: in Julia, functions (such as the closures here) can contain data themselves.

For example:

struct MulBy
    x::Float64
end

function (func::MulBy)(y)
  return func.x * y
end

Here, we mark the function as Const, meaning that we don’t want to take the derivative with respect to f1 itself (that is, with respect to any data stored inside f1). This does not change the meaning of the other arguments.
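To connect this with the mathematical view: f_1 is still a map from \mathbb{R}^2 to \mathbb{R} in x, but the Julia object f1 also stores the local values it closed over (in the MWE, yb and k; model, ps and st are globals). A sketch using the MulBy struct above, with 3.0 and 2.0 as illustrative values: `Const` tells Enzyme to hold the stored field fixed and differentiate only w.r.t. `y`, and the last lines show that an ordinary closure stores its captured variable as a field in the same way.

```julia
using Enzyme

struct MulBy
    x::Float64
end

(func::MulBy)(y) = func.x * y

m = MulBy(3.0)
# Const(m): m.x is a fixed constant; only y is differentiated
res = Enzyme.autodiff(Reverse, Const(m), Active, Active(2.0))
dy = res[1][1]               # d(3.0 * y)/dy == 3.0

# A closure is the same kind of object: its captured variable becomes a field
clos = let c = 3.0
    y -> c * y
end
fieldnames(typeof(clos))     # (:c,)
```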

I just put rand(3) there because I simply wanted to demonstrate the call. You are right: you should use a preallocated array there if you want to use the results.

The first Const is the activity annotation for the function's return value. Since this function mutates in place and returns nothing, its return is marked Const.
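The return annotation only says whether the return value is differentiated; it does not describe what the function does. A sketch with a scalar-returning function `h` (hypothetical): `Active` seeds the reverse pass with 1 on the return value, whereas `Const` treats the return as a constant, so no gradient flows back. A Const return therefore does not by itself imply that the function returns nothing or works in place.

```julia
using Enzyme

h(x) = sum(abs2, x)            # returns a Float64: sum of squares

x = [1.0, 2.0]

dx_active = zeros(2)
Enzyme.autodiff(Reverse, Const(h), Active, Duplicated(x, dx_active))
dx_active                      # 2 .* x == [2.0, 4.0]

dx_const = zeros(2)
Enzyme.autodiff(Reverse, Const(h), Const, Duplicated(x, dx_const))
dx_const                       # no seed on the return, so it stays [0.0, 0.0]
```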

Thank you. I was not aware that functions can contain data. Mathematically, f_1 : \mathbb{R}^2 \to \mathbb{R}. I did not understand the comment that “we don’t want to take the derivative w.r.t. f_1.” I want to find the derivative of f_1 w.r.t. x \in \mathbb{R}^2. Maybe I am missing something, or my knowledge of Julia/Enzyme is too limited.

Thank you. Is the converse true, i.e., does a Const return imply that the function returns nothing and does all its operations in place?