Mutable objects in Enzyme.jl

Hi everyone

I am new to automatic differentiation. I decided to use Enzyme because I need to handle objects that get mutated (even though I am not differentiating with respect to these objects).

I tried to write a minimal, self-contained example:

using Enzyme

mutable struct my_struct
    p1::Vector{Float64}
    p2::Float64
end

function edit_str!(str::my_struct, x, i)
    str.p2 = str.p2 + 2x^2
    return nothing
end

function foo(x, str)
    edit_str!(str, x, 1)
    return str.p2
end

str = my_struct([2.0, 3.0], 1.0)

foo(1.0, str), autodiff(Reverse, foo, Active, Active(1.0), str)

Here I try to differentiate the function foo at x = 1.0. Note that the function edit_str!, called inside foo, mutates the object str. However, I do not need to differentiate with respect to str.
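As a plain-Julia sanity check (no Enzyme involved) of the value autodiff should return, here is a central finite difference of foo at x = 1.0 for this first version. Because foo mutates its argument, each evaluation runs on a `deepcopy` of `str`, so both evaluations start from the same state. The helper name `fd_derivative` and the step size `h` are illustrative assumptions, not part of the original post.

```julia
mutable struct my_struct
    p1::Vector{Float64}
    p2::Float64
end

function edit_str!(str::my_struct, x, i)
    str.p2 = str.p2 + 2x^2
    return nothing
end

function foo(x, str)
    edit_str!(str, x, 1)
    return str.p2
end

# Hypothetical helper: central finite difference of foo with respect to x.
# Each call gets a fresh deepcopy so the mutation in edit_str! does not
# leak from one evaluation into the next (or into the caller's str).
function fd_derivative(x, str; h = 1e-6)
    fplus = foo(x + h, deepcopy(str))
    fminus = foo(x - h, deepcopy(str))
    return (fplus - fminus) / (2h)
end

str = my_struct([2.0, 3.0], 1.0)
d = fd_derivative(1.0, str)   # close to 4.0, since d/dx (p2 + 2x^2) = 4x
println(d)
println(str.p2)               # still 1.0: str itself was never mutated
```

The derivative of this version does not depend on p2, which is why 4.0 is the expected result regardless of the struct's contents.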

I found that, depending on how str is mutated, the code may or may not work. In the example above it gives the correct result, 4.0, but if I change it slightly:

function edit_str!(str::my_struct, x, i)
    str.p1 = str.p1 .+ 1.0
    str.p2 = str.p2 + 2x^2
    return nothing
end

it no longer works and returns 0.0, even though str.p1 is not used anywhere. Perhaps even stranger, this version works:

function edit_str!(str::my_struct, x, i)
    str.p1 = 2.0 * str.p1
    str.p2 = str.p2 + 2x^2
    return nothing
end

I guess it depends on the kind of mutation applied to the mutable struct, but I would like to know whether there is a robust way of dealing with this.

Thanks a lot.


The following should work:

autodiff(Reverse, foo, Active, Active(1.0), Duplicated(str, my_struct([0., 0.], 0.)))

See caveats.


Thank you very much for your answer. This does work, but how can I generalise it? How do I know how to initialise the shadow of a more complex mutable struct?


You mean the second argument of Duplicated? I would initialize everything that is mutated to zero, but I don't know whether the Enzyme authors have an official recommendation.
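One way to apply this suggestion to larger structs is a small recursive helper that builds a zero-valued copy with the same layout, allocating new arrays so the shadow never shares memory with the primal. `zero_shadow` is a hypothetical name, and the sketch assumes every field is a float, a float array, or another struct whose default positional constructor accepts its fields in order (an `Int` or `String` field, for example, would need its own method). Recent Enzyme.jl versions also provide `Enzyme.make_zero` for this purpose; check the documentation for your version before rolling your own.

```julia
mutable struct my_struct
    p1::Vector{Float64}
    p2::Float64
end

# Hypothetical helper: build a zero-initialised shadow with the same layout.
# Leaves: floats become 0.0, float arrays become freshly allocated zero arrays.
zero_shadow(x::AbstractFloat) = zero(x)
zero_shadow(x::AbstractArray{<:AbstractFloat}) = zero(x)

# Structs: zero every field recursively, then rebuild with the default
# positional constructor (assumption: T(fields...) is a valid constructor).
function zero_shadow(x::T) where {T}
    fields = map(f -> zero_shadow(getfield(x, f)), fieldnames(T))
    return T(fields...)
end

str = my_struct([2.0, 3.0], 1.0)
shadow = zero_shadow(str)
println(shadow.p1, " ", shadow.p2)   # [0.0, 0.0] 0.0

# Intended use with Enzyme (not run here):
# autodiff(Reverse, foo, Active, Active(1.0), Duplicated(str, zero_shadow(str)))
```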

Thank you very much @jbrea

Here is the official answer from the developers:

Yeah, essentially if you store active (i.e. differentiable) data into constant memory (such as the struct here), there are some additional complications.

Specifically, since it's stored into constant data, Enzyme can assume that uses of it in that constant data won't impact the derivative, since it's loading/using a constant!

If it acts as a buffer/temporary storage of derivative data, you'll need to mark it as Duplicated (see the Enzyme.jl documentation). The reason is that if you build up some arbitrary data structure, we also need to construct the shadow (aka derivative) data structure. Without the shadow data structure, we have no memory location to use for temporary storage of the derivative!

Your first case is actually why I was very careful to say that Enzyme can assume that uses of it won't impact the derivative, rather than that Enzyme will assume it. In that case, the computation was simple enough that an optimization realized it was returning a value computed directly from the input variable x, and returned that directly rather than actually loading it from the constant struct. Therefore it returns an active value, which has a meaningful derivative. If you want a guarantee that no use of a variable will change the derivative, we have a construct for that too. In essence it's equivalent to making a custom rule that behaves identity-like but is marked inactive.


My question about how to initialise the shadow was related to the fact that, if I change the operation on the struct, an initial value of 0.0 no longer works.

Example:

using Enzyme

mutable struct my_struct
    p1::Vector{Float64}
    p2::Float64
end

function edit_str!(str::my_struct, x, i)
    str.p1 = str.p1 .+ 1.0
    # Here I changed the operation to a product
    str.p2 = str.p2 * 2x^2
    return nothing
end

function foo(x, str)
    edit_str!(str, x, 1)
    return str.p2
end

str = my_struct([2.0, 3.0], 1.0)

str_copy = my_struct([0.0, 0.0], 0.0)

foo(1.0, str), autodiff(Reverse, foo, Active, Active(1.0), Duplicated(str, str_copy))

This gives 8.0 instead of 4.0. Initialising str_copy.p2 to other numbers only increases the result.

Do you know any workaround?

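Your call to foo(1.0, str) changes the value of str before autodiff runs (the tuple is evaluated left to right), so autodiff starts from str.p2 = 2.0. As a result, it gives you the correct derivative for the new input values.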
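The effect can be seen without Enzyme. This sketch uses the product version of edit_str! from the question: one call to foo changes both fields of str, while evaluating on a `deepcopy` (one possible way to structure such a check, not something proposed in the thread) leaves the original untouched for a later autodiff call.

```julia
mutable struct my_struct
    p1::Vector{Float64}
    p2::Float64
end

function edit_str!(str::my_struct, x, i)
    str.p1 = str.p1 .+ 1.0
    str.p2 = str.p2 * 2x^2
    return nothing
end

function foo(x, str)
    edit_str!(str, x, 1)
    return str.p2
end

str = my_struct([2.0, 3.0], 1.0)

# Evaluating on a deepcopy leaves str in its original state.
y = foo(1.0, deepcopy(str))
println(y, " ", str.p2)        # 2.0 1.0

# Evaluating on str itself mutates it: any later autodiff call
# now starts from p1 = [3.0, 4.0], p2 = 2.0.
foo(1.0, str)
println(str.p1, " ", str.p2)   # [3.0, 4.0] 2.0
```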


julia> using Enzyme


julia> mutable struct my_struct
           p1::Vector{Float64}
           p2::Float64
       end

julia> function edit_str!(str::my_struct, x, i)
           str.p1 = str.p1 .+ 1.0
           # Here I changed the operation to a product
           str.p2 = str.p2 * 2x^2
           return nothing
       end
edit_str! (generic function with 1 method)

julia> function foo(x, str)
           edit_str!(str, x, 1)
           return str.p2
       end
foo (generic function with 1 method)

julia> str = my_struct([2.0, 3.0], 1.0)
my_struct([2.0, 3.0], 1.0)

julia> str_copy = my_struct([0.0, 0.0], 0.0)
my_struct([0.0, 0.0], 0.0)

julia> foo(1.0, str)
2.0

julia> str
my_struct([3.0, 4.0], 2.0)

julia> str = my_struct([2.0, 3.0], 1.0)
my_struct([2.0, 3.0], 1.0)

julia> autodiff(Reverse, foo, Active, Active(1.0), Duplicated(str, str_copy))
((4.0, nothing),)

julia> str = my_struct([2.0, 3.0], 2.0)
my_struct([2.0, 3.0], 2.0)

julia> str_copy = my_struct([0.0, 0.0], 0.0)
my_struct([0.0, 0.0], 0.0)

julia> autodiff(Reverse, foo, Active, Active(1.0), Duplicated(str, str_copy))
((8.0, nothing),)

The expression here is equivalent to str.p2 * 2x^2. Its derivative is thus 4x * str.p2 * dx + 2x^2 * d(str.p2). You set shadow(str.p2) = 0, and using Active means dx = 1, so you should get 4x * str.p2, which is 4.0 at x = 1 when str.p2 = 1. If, however, you run this on an input where p2 = 2 (as after your call to foo), a derivative of 8 is indeed the correct result.
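The derivative of p2 * 2x^2 with respect to x, holding p2 fixed, is 4x * p2, and this can be checked numerically in plain Julia by comparing it with a central finite difference of foo for two starting values of p2. The helper names `analytic` and `fd_derivative` are hypothetical, and the step size is an arbitrary choice.

```julia
mutable struct my_struct
    p1::Vector{Float64}
    p2::Float64
end

function edit_str!(str::my_struct, x, i)
    str.p1 = str.p1 .+ 1.0
    str.p2 = str.p2 * 2x^2
    return nothing
end

function foo(x, str)
    edit_str!(str, x, 1)
    return str.p2
end

# Analytic derivative of p2 * 2x^2 with respect to x, holding p2 fixed.
analytic(x, p2) = 4x * p2

# Central finite difference; deepcopy keeps each evaluation independent.
function fd_derivative(x, str; h = 1e-6)
    return (foo(x + h, deepcopy(str)) - foo(x - h, deepcopy(str))) / (2h)
end

for p2 in (1.0, 2.0)
    s = my_struct([2.0, 3.0], p2)
    println("p2 = ", p2, ": analytic = ", analytic(1.0, p2),
            ", finite difference = ", fd_derivative(1.0, s))
end
```

Both columns agree, giving about 4 for p2 = 1 and about 8 for p2 = 2, which matches the two autodiff results in the REPL session.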


Not sure what you're trying to do, but you may be interested in Enzyme's split mode, which gives you the result of the original code plus a separate pullback function and tape that you can use to evaluate the derivatives at the values the inputs had when you called the original code.
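The idea behind split mode can be illustrated without Enzyme: run the primal once, record on a "tape" the values the derivative needs at that moment, and return a pullback closure that uses those recorded values even if the struct is mutated afterwards. This is a hand-written sketch for the product version of edit_str!, with hypothetical names (`foo_with_pullback`, `tape_p2`); it is not Enzyme's split-mode API, whose actual entry points are described in the Enzyme.jl documentation.

```julia
mutable struct my_struct
    p1::Vector{Float64}
    p2::Float64
end

# Hypothetical hand-written "split mode" for the product version of foo.
# Returns the primal result plus a pullback closing over a tape that holds
# the pre-mutation value of p2.
function foo_with_pullback(x, str::my_struct)
    tape_p2 = str.p2              # record the input value before mutating
    str.p1 = str.p1 .+ 1.0
    str.p2 = str.p2 * 2x^2
    result = str.p2
    # Pullback: given the adjoint of the result, return the adjoint of x.
    pullback(dresult) = dresult * 4x * tape_p2
    return result, pullback
end

str = my_struct([2.0, 3.0], 1.0)
y, back = foo_with_pullback(1.0, str)
str.p2 = 100.0                    # later mutation does not affect the tape
println(y, " ", back(1.0))        # 2.0 4.0
```

Because the pullback reads from the tape rather than from `str`, it still returns the derivative at the original input (p2 = 1), which is the property the split-mode suggestion is about.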