Something closed over by `back` is certainly being mutated. For example, here the value of `back([0,1,0.])` changes from one call to the next:
```julia
julia> y, back = Zygote.pullback(forward_vector_target, u0);
julia> du1, = back([1,0,0.])
([2.7182818436166185, 0.0, 0.0],)
julia> back([0,1,0.]) # second call
([2.7182818436166185, 2.7182818436166185, 0.0],)
julia> back([0,1,0.]) # third call, with same input as second
([2.7182818436166185, 5.436563687233237, 0.0],)
julia> (du1,) # check the return wasn't mutated after printing
([2.7182818436166185, 0.0, 0.0],)
julia> y, back = Zygote.pullback(forward_vector_target, u0);
julia> back([0,1,0.]) # first call
([0.0, 2.7182818436166185, 0.0],)
```
The returned arrays here have distinct `objectid`s, so it isn't literally mutating the returned array.
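To illustrate how both observations can hold at once, here is a minimal plain-Julia sketch with no AD involved. `make_accumulating_pullback` is a hypothetical helper, not your code: it mimics a pullback that closes over a buffer and updates it in place, then returns a fresh copy. Each call hands back a new array (different `objectid`), yet the answer drifts because the captured buffer keeps accumulating.

```julia
# Hypothetical helper mimicking a pullback that closes over mutable state.
function make_accumulating_pullback(x::Vector{Float64})
    buf = zero(x)                              # captured by the closure below
    return dy -> copy(buf .+= 2 .* x .* dy)    # mutates buf, returns a new array
end

back = make_accumulating_pullback([1.0, 2.0, 3.0])
a = back([1.0, 0.0, 0.0])   # [2.0, 0.0, 0.0]
b = back([1.0, 0.0, 0.0])   # [4.0, 0.0, 0.0]: same input, different answer
objectid(a) == objectid(b)  # false: each call returns a distinct array
a                           # still [2.0, 0.0, 0.0]; only buf was mutated
```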
You haven't shown exactly what is printing `unthunk(du0)`, but I agree that stage looks correct. Surely the straight-through `rrule` you show isn't introducing any mutable state.
Note that ChainRulesTestUtils.jl tests for this problem, so assuming you define `rrule`s, you can try to make a Zygote-free reproducer:
```julia
julia> using Zygote, ChainRulesCore, ChainRulesTestUtils
julia> sq1(x) = x .* x;
julia> sq2(x) = x .* x; # same function, but will give it a bad rrule
julia> function ChainRulesCore.rrule(::typeof(sq1), x)
           sq1(x), dy -> (NoTangent(), 2 .* x .* dy) # good (at least for real numbers)
       end
julia> function ChainRulesCore.rrule(::typeof(sq2), x)
           BUF = zero.(x)
           return sq2(x), dy -> (NoTangent(), copy(BUF .+= 2 .* x .* dy)) # bad!
       end
julia> Zygote.gradient(x -> sum(x.*x), [1,2,3.])
([2.0, 4.0, 6.0],)
julia> Zygote.gradient(sum∘sq1, [1,2,3.])
([2.0, 4.0, 6.0],)
julia> Zygote.gradient(sum∘sq2, [1,2,3.]) # ok
([2.0, 4.0, 6.0],)
julia> Zygote.jacobian(x -> x.*x, [1,2,3.])
([2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0],)
julia> Zygote.jacobian(sq1, [1,2,3.])
([2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0],)
julia> Zygote.jacobian(sq2, [1,2,3.]) # wrong!
([2.0 0.0 0.0; 2.0 4.0 0.0; 2.0 4.0 6.0],)
julia> _, back = Zygote.pullback(sq2, [1,2,3.]);
julia> back([1,0,0.])
([2.0, 0.0, 0.0],)
julia> back([1,0,0.]) # second call, different answer
([4.0, 0.0, 0.0],)
```
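The same symptom can also be caught without Zygote at all, by calling the `rrule` directly and invoking its pullback twice with the same cotangent. A pure pullback must agree with itself. This is a rough sketch: `sq_good`, `sq_bad` and `pullback_is_repeatable` are hypothetical names chosen to mirror `sq1`/`sq2`, and this is not how ChainRulesTestUtils implements its check.

```julia
using ChainRulesCore

sq_good(x) = x .* x
sq_bad(x) = x .* x   # same function, bad rule below

function ChainRulesCore.rrule(::typeof(sq_good), x)
    return sq_good(x), dy -> (NoTangent(), 2 .* x .* dy)
end

function ChainRulesCore.rrule(::typeof(sq_bad), x)
    BUF = zero.(x)   # captured by the pullback and mutated on every call
    return sq_bad(x), dy -> (NoTangent(), copy(BUF .+= 2 .* x .* dy))
end

# Hypothetical helper: call the pullback twice with the same cotangent
# and report whether both calls return the same gradient for x.
function pullback_is_repeatable(f, x, dy)
    _, pb = rrule(f, x)
    first_call = unthunk(pb(dy)[2])
    second_call = unthunk(pb(dy)[2])
    return first_call == second_call
end

x, dy = [1.0, 2.0, 3.0], [1.0, 0.0, 0.0]
pullback_is_repeatable(sq_good, x, dy)  # true
pullback_is_repeatable(sq_bad, x, dy)   # false: BUF accumulates between calls
```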
```julia
julia> test_rrule(sq1, [1,2,3.])
Test Summary:                      | Pass  Total  Time
test_rrule: sq1 on Vector{Float64} |    7      7  0.0s
Test.DefaultTestSet("test_rrule: sq1 on Vector{Float64}", Any[], 7, false, false, true, 1.72291392558364e9, 1.7229139256149e9, false, "/Users/me/.julia/packages/ChainRulesTestUtils/Ko1Wr/src/testers.jl")
julia> test_rrule(sq2, [1,2,3.])
test_rrule: sq2 on Vector{Float64}: Test Failed at /Users/me/.julia/packages/ChainRulesTestUtils/Ko1Wr/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
     Problem: cotangent for input 2, 3-element Vector{Float64}
   Evaluated: isapprox([14.12, -46.72, 77.88], [7.059999999999687, -23.36000000000044, 38.94000000000155]; rtol = 1.0e-9, atol = 1.0e-9)
```
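If a rule genuinely wants a scratch buffer, one possible repair is to allocate it inside the pullback, so every call starts from zero and nothing carries over between calls. A minimal sketch, using a hypothetical `sq3` rather than editing `sq2`:

```julia
using ChainRulesCore

sq3(x) = x .* x   # hypothetical name for the repaired version of sq2

function ChainRulesCore.rrule(::typeof(sq3), x)
    function sq3_pullback(dy)
        buf = zero.(x)          # fresh buffer on every call, so no hidden state
        buf .+= 2 .* x .* dy
        return (NoTangent(), buf)
    end
    return sq3(x), sq3_pullback
end

_, pb = rrule(sq3, [1.0, 2.0, 3.0])
pb([1.0, 0.0, 0.0])[2]   # [2.0, 0.0, 0.0]
pb([1.0, 0.0, 0.0])[2]   # [2.0, 0.0, 0.0] again
```

With the allocation moved inside the closure, `Zygote.jacobian(sq3, [1,2,3.])` should give the same diagonal as `sq1`, and `test_rrule` should no longer see an accumulated cotangent.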