# Help on ForwardDiff and Zygote for ODEProblem with mass matrix

Hi there,

I’m new to `SciMLSensitivity.jl` and `Zygote.jl` and have been playing with these two libraries for a toy ODE system with a mass matrix. However, the Jacobians produced by the two libraries are different, and I haven’t been able to figure out why.

Could you please take a look at my minimal code example and give me a hint on what might be causing this issue?

Regards,
Charlie

``````julia
using DifferentialEquations
using SciMLSensitivity
using Zygote
using ForwardDiff

K = [1.0 0 0
     0 1.0 0
     0 0 1.0]
function dudt(du, u, p, t)
    du .= K*u
    # nothing
end
M = [1.0 0 0
     0 1.0 0
     0 0 1.0]
f = ODEFunction(dudt, mass_matrix = M)

# target (vector) function
function forward_vector_target(u)
    prob = ODEProblem(f, u, (0.0, 1.0), [1.0, 1.0, 1.0])
    sol = solve(prob, Rodas5(), reltol = 1e-8, abstol = 1e-8)
    return sol[end]
end

u0 = [1.0, 1.0, 1.0]

# 1. reverse mode
rs_jac = Zygote.jacobian(forward_vector_target, u0)[1]
println(rs_jac)
# output:
# 3×3 Matrix{Float64}:
#  2.71828  0.0      0.0
#  2.71828  2.71828  0.0
#  2.71828  2.71828  2.71828

# 2. forward mode
fs_jac = ForwardDiff.jacobian(forward_vector_target, u0)
println(fs_jac)
# output:
# 3×3 Matrix{Float64}:
#  2.71828  0.0      0.0
#  0.0      2.71828  0.0
#  0.0      0.0      2.71828
``````

Here’s a very rough experiment which points towards the ForwardDiff one being correct:

``````julia
julia> y0 = forward_vector_target(u0)
3-element Vector{Float64}:
2.7182818287174166
2.7182818287174166
2.7182818287174166

julia> forward_vector_target(u0 .+ [1,0,0] ./ 10) .- y0
3-element Vector{Float64}:
0.2718281828685605
-2.8883562208648073e-12
-2.8883562208648073e-12

julia> forward_vector_target(u0 .+ [0,1,0] ./ 10) .- y0
3-element Vector{Float64}:
-2.8883562208648073e-12
0.2718281828685605
-2.8883562208648073e-12

julia> forward_vector_target(u0 .+ [0,0,1] ./ 10) .- y0
3-element Vector{Float64}:
-2.8883562208648073e-12
-2.8883562208648073e-12
0.2718281828685605

julia> fs_jac = ForwardDiff.jacobian(forward_vector_target, u0)
3×3 Matrix{Float64}:
2.71828  0.0      0.0
0.0      2.71828  0.0
0.0      0.0      2.71828
``````
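For this particular system the exact answer is available in closed form, which settles which Jacobian is right. With `M = K = I` the ODE reduces to `du/dt = u`, so the flow map is a scalar multiple of the identity:

``````math
M\,\dot u = K\,u,\qquad M = K = I_3
\;\Longrightarrow\; u(t) = e^{t}\,u_0
\;\Longrightarrow\; \frac{\partial u(1)}{\partial u_0} = e\,I_3 \approx 2.71828\,I_3
``````

So the diagonal matrix is the correct Jacobian, and the lower-triangular one cannot be.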

(I wondered whether a known bug in ForwardDiff v0.10 might be the problem, so I tried with ForwardDiff v0.11-DEV to check… but both versions give the same.)


It looks like it’s a Zygote thing. When I try to recreate the Jacobian it’s fine:

``````julia
julia> [Zygote.pullback(forward_vector_target, u0)[2]([1.0,0.0,0.0])[1]';
        Zygote.pullback(forward_vector_target, u0)[2]([0.0,1.0,0.0])[1]';
        Zygote.pullback(forward_vector_target, u0)[2]([0.0,0.0,1.0])[1]']
3×3 Matrix{Float64}:
2.71828  0.0      0.0
0.0      2.71828  0.0
0.0      0.0      2.71828
``````

That’s pulling back the basis vectors, and it seems fine. That’s what I would’ve assumed Zygote.jacobian was doing, so I’m surprised it gives a different answer. @mcabbott do you know what Zygote.jacobian is actually calculating?

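``````julia
julia> [Zygote.gradient((u)->forward_vector_target(u)[1], u0)[1]';
        Zygote.gradient((u)->forward_vector_target(u)[2], u0)[1]';
        Zygote.gradient((u)->forward_vector_target(u)[3], u0)[1]']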
3×3 Matrix{Float64}:
2.71828  0.0      0.0
0.0      2.71828  0.0
0.0      0.0      2.71828
``````

Gradients are also fine, so this is very specific to the Zygote.jacobian interface, and I don’t know its implementation.

I didn’t manage to run the Zygote version locally.

But the difference between what you write and `Zygote.jacobian` is that you run the forward pass 3 times, while it is something more like this (untested):

``````julia
y, bk = Zygote.pullback(forward_vector_target, u0)
vcat(bk([1,0,0.])[1]', bk([0,1,0.])[1]', bk([0,0,1.])[1]')
``````

i.e. saving and re-using the pullback function.
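To make the distinction concrete, here is a plain-Julia sketch (no Zygote involved) of building a Jacobian from a single, reused pullback. The pullback is hand-written for a linear map `f(x) = A * x`, and the names `linear_pullback` and `jacobian_from_pullback` are hypothetical; this mirrors the description above, not Zygote's actual source.

``````julia
# Hand-written stand-in for Zygote.pullback on the linear map f(x) = A * x:
# returns the value y and a closure back(ȳ) = (A' * ȳ,), i.e. one cotangent
# per input, wrapped in a tuple like Zygote's pullbacks.
function linear_pullback(A, x)
    y = A * x
    back(ȳ) = (A' * ȳ,)
    return y, back
end

# Build the Jacobian from ONE forward pass: row k is the pullback of the
# k-th basis vector, so the same `back` is called m times.
function jacobian_from_pullback(back, m)
    rows = map(1:m) do k
        e_k = zeros(m)
        e_k[k] = 1.0
        back(e_k)[1]'   # adjoint turns the gradient into a 1×n row
    end
    return reduce(vcat, rows)
end

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
y, back = linear_pullback(A, [1.0, 1.0])
J = jacobian_from_pullback(back, length(y))   # equals A, since f is linear
``````

Here `back` is a pure function, so every row is independent. If it kept state between calls, every row after the first would be contaminated by the earlier ones.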

If this snippet gives a different answer from yours, then I believe the rule you are providing Zygote is invalid, e.g. it overwrites some buffer captured by `bk`, so that it does not work on the second call.

(Overwriting such buffers does allow more efficient rules, if the pullback is certain to only be used once. It’s easy to forget to test for this. At some point, there was work towards adding a flag for this to ChainRulesCore, and making Zygote use it, but IIRC this never got merged. Edit: see this CRC PR.)
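A minimal plain-Julia illustration of that trade-off, using a hypothetical rule for `f(x) = 2 .* x` (not taken from any package): the pullback writes every result into one preallocated buffer and returns that buffer itself. This saves an allocation per call, but a second call silently changes the result of the first.

``````julia
# Hypothetical rule for f(x) = 2 .* x whose pullback reuses one buffer.
# Fine if the pullback is called exactly once; wrong if it is reused.
function doubling_pullback_reusing_buffer(x)
    buf = similar(x)
    back(ȳ) = (buf .= 2 .* ȳ; (buf,))   # overwrite buf, return it (aliased)
    return 2 .* x, back
end

y, back = doubling_pullback_reusing_buffer([1.0, 2.0, 3.0])
g1, = back([1.0, 0.0, 0.0])   # g1 is buf, currently [2.0, 0.0, 0.0]
g2, = back([0.0, 1.0, 0.0])   # overwrites buf in place
# g1 and g2 are the same array, so g1 now reads [0.0, 2.0, 0.0]
``````

Returning `copy(buf)` instead makes the pullback safe to reuse, at the cost of one allocation per call.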


Okay, that narrows it down. Here’s a leaner MWE; it actually doesn’t involve mass matrices at all:

``````julia
using OrdinaryDiffEq
using SciMLSensitivity
using Zygote
using ForwardDiff

function dudt(du, u, p, t)
    du .= u
    nothing
end

u0 = [1.0, 1.0, 1.0]

# target (vector) function
function forward_vector_target(u)
    prob = ODEProblem(dudt, u, (0.0, 1.0), [1.0, 1.0, 1.0])
    sol = solve(prob, FBDF(), reltol = 1e-8, abstol = 1e-8, sensealg = ForwardDiffSensitivity())
    return sol[end]
end

y, back = Zygote.pullback(forward_vector_target, u0);
[back([1.0,0.0,0.0])[1]';
 back([0.0,1.0,0.0])[1]';
 back([0.0,0.0,1.0])[1]']
``````
``````julia
julia> y, back = Zygote.pullback(forward_vector_target, u0);

julia> [back([1.0,0.0,0.0])[1]';
        back([0.0,1.0,0.0])[1]';
        back([0.0,0.0,1.0])[1]']
unthunk(du0) = [2.7182818436166185, 0.0, 0.0]
unthunk(du0) = [0.0, 2.7182818436166185, 0.0]
unthunk(du0) = [0.0, 0.0, 2.7182818436166185]
3×3 Matrix{Float64}:
2.71828  0.0      0.0
2.71828  2.71828  0.0
2.71828  2.71828  2.71828
``````

I added a printout of the `du0` computed by the adjoint method. That confirms that SciMLSensitivity is computing the correct value each time, so the adjoint system is completely fine with being reused with new values. So this means I’m handing

``````julia
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
 ntuple(_ -> NoTangent(), length(args))...)
``````

back to Zygote and that `du0` is correct. The next thing above it is the `ODEProblem`, so I check:

``````julia
function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
    function ODEProblemAdjoint(ȳ)
        @show ȳ.u0
        (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
    end

    ODEProblem(args...; kwargs...), ODEProblemAdjoint
end
``````
``````
unthunk(du0) = [2.7182818436166185, 0.0, 0.0]
ȳ.u0 = [2.7182818436166185, 0.0, 0.0]
unthunk(du0) = [0.0, 2.7182818436166185, 0.0]
ȳ.u0 = [2.7182818436166185, 2.7182818436166185, 0.0]
unthunk(du0) = [0.0, 0.0, 2.7182818436166185]
ȳ.u0 = [2.7182818436166185, 2.7182818436166185, 2.7182818436166185]
3×3 Matrix{Float64}:
2.71828  0.0      0.0
2.71828  2.71828  0.0
2.71828  2.71828  2.71828
``````

so the buffer that Zygote gives it has already merged the runs.

So, since the pullback of `solve` is correct, what is changing the value before it reaches the ODEProblem adjoint? That surely looks like a Zygote bug, since Zygote should just be handing that same `du0` back up to the ODEProblem adjoint, whose derivative is just the identity?
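One way to see why a stateful stage in the middle would produce exactly the lower-triangular pattern from the first post: below is a plain-Julia toy chain (all names hypothetical, no Zygote) where a correct `solve`-like pullback (`ȳ -> e .* ȳ`) is followed by a wrapper stage that should be the identity. In the buggy variant the wrapper accumulates into a buffer it closes over, and building the Jacobian from one reused pullback then reproduces the observed output.

``````julia
# Toy two-stage chain x -> wrap(x) -> y = e .* wrap(x), with hand-written
# pullbacks. wrap is the identity in the forward pass; its pullback should be
# the identity too, but the buggy variant accumulates into a captured buffer.
function chained_pullback(x; buggy::Bool)
    acc = zeros(length(x))                         # state shared by all calls
    wrap_back(w̄) = buggy ? copy(acc .+= w̄) : w̄    # identity unless buggy
    solve_back(ȳ) = exp(1.0) .* ȳ                  # correct pullback of y = e .* x
    y = exp(1.0) .* x
    back(ȳ) = (wrap_back(solve_back(ȳ)),)          # chain rule, reverse order
    return y, back
end

# Jacobian rows from one reused pullback (row k = pullback of basis vector k).
jacobian_rows(back, n) =
    reduce(vcat, [back([i == k ? 1.0 : 0.0 for i in 1:n])[1]' for k in 1:n])

_, good = chained_pullback(ones(3); buggy = false)
_, bad = chained_pullback(ones(3); buggy = true)
J_good = jacobian_rows(good, 3)   # e on the diagonal
J_bad = jacobian_rows(bad, 3)     # lower-triangular, like the Zygote.jacobian output
``````

Each individual call of the "solve" stage is correct in both variants; only the state carried between calls differs.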


Something closed over by `back` is certainly being mutated. For example, here the value of `back([0,1,0.])` is inconsistent:

``````julia
julia> y, back = Zygote.pullback(forward_vector_target, u0);

julia>  du1, = back([1,0,0.])
([2.7182818436166185, 0.0, 0.0],)

julia> back([0,1,0.])  # second call
([2.7182818436166185, 2.7182818436166185, 0.0],)

julia> back([0,1,0.])  # third call, with same input as second
([2.7182818436166185, 5.436563687233237, 0.0],)

julia> (du1,)  # check the return wasn't mutated after printing
([2.7182818436166185, 0.0, 0.0],)

julia> y, back = Zygote.pullback(forward_vector_target, u0);

julia> back([0,1,0.])  # first call
([0.0, 2.7182818436166185, 0.0],)
``````

The returned arrays here have distinct `objectid`s, so it isn’t literally mutating the returned array.
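The check done by hand above (call `back` twice, compare the results, look at object identity) can be packaged as a small Zygote-free helper. This is a hypothetical sketch, not an existing API: it calls a pullback twice with the same cotangent and reports whether the second result matches the first and whether any mutable output is shared between calls. The accumulating pullback mimics the symptom in this thread: distinct arrays, but values that drift from call to call.

``````julia
# Hypothetical helper (not part of Zygote or ChainRulesTestUtils): call a
# pullback twice with the same cotangent and check it behaves like a pure function.
function pullback_is_reusable(back, ȳ)
    r1 = back(ȳ)
    snapshot = deepcopy(r1)          # independent copy of the first result
    r2 = back(ȳ)
    same_values = r2 == snapshot     # a repeated call must reproduce the result
    # mutable outputs (e.g. arrays) must not be the same object across calls
    no_aliasing = all(!(ismutable(a) && a === b) for (a, b) in zip(r1, r2))
    return same_values && no_aliasing
end

good_back(ȳ) = (2 .* ȳ,)               # fresh array every call

function make_accumulating_back(n)
    buf = zeros(n)                       # captured state, never reset
    return ȳ -> (copy(buf .+= 2 .* ȳ),)  # distinct arrays, drifting values
end

pullback_is_reusable(good_back, [1.0, 0.0, 0.0])                  # true
pullback_is_reusable(make_accumulating_back(3), [1.0, 0.0, 0.0])  # false
``````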

You haven’t shown exactly what prints `unthunk(du0)`, but I agree that stage looks correct. And surely the straight-through `rrule` you show isn’t introducing any mutable state.

Note that ChainRulesTestUtils.jl tests for this problem, so assuming you define `rrule`s, you can try to make a Zygote-free reproducer:

``````julia
julia> using Zygote, ChainRulesCore, ChainRulesTestUtils

julia> sq1(x) = x .* x;

julia> sq2(x) = x .* x;  # same function, but will give it a bad rrule

julia> function ChainRulesCore.rrule(::typeof(sq1), x)
           sq1(x), dy -> (NoTangent(), 2 .* x .* dy)  # good (at least for real numbers)
       end

julia> function ChainRulesCore.rrule(::typeof(sq2), x)
           BUF = zero.(x)
           return sq2(x), dy -> (NoTangent(), copy(BUF .+= 2 .* x .* dy))  # bad!
       end

([2.0, 4.0, 6.0],)

([2.0, 4.0, 6.0],)

([2.0, 4.0, 6.0],)

julia> Zygote.jacobian(x -> x.*x, [1,2,3.])
([2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0],)

julia> Zygote.jacobian(sq1, [1,2,3.])
([2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0],)

julia> Zygote.jacobian(sq2, [1,2,3.])  # wrong!
([2.0 0.0 0.0; 2.0 4.0 0.0; 2.0 4.0 6.0],)

julia> _, back = Zygote.pullback(sq2, [1,2,3.]);

julia> back([1,0,0.])
([2.0, 0.0, 0.0],)

julia> back([1,0,0.])  # second call, different answer
([4.0, 0.0, 0.0],)

julia> test_rrule(sq1, [1,2,3.])
Test Summary:                      | Pass  Total  Time
test_rrule: sq1 on Vector{Float64} |    7      7  0.0s
Test.DefaultTestSet("test_rrule: sq1 on Vector{Float64}", Any[], 7, false, false, true, 1.72291392558364e9, 1.7229139256149e9, false, "/Users/me/.julia/packages/ChainRulesTestUtils/Ko1Wr/src/testers.jl")

julia> test_rrule(sq2, [1,2,3.])
test_rrule: sq2 on Vector{Float64}: Test Failed at /Users/me/.julia/packages/ChainRulesTestUtils/Ko1Wr/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Problem: cotangent for input 2, 3-element Vector{Float64}
Evaluated: isapprox([14.12, -46.72, 77.88], [7.059999999999687, -23.36000000000044, 38.94000000000155]; rtol = 1.0e-9, atol = 1.0e-9)
``````

When I isolate the SciMLSensitivity part, I can confirm it’s fine:

``````julia
# assumes dudt and u0 from the leaner MWE above
prob = ODEProblem(dudt, u0, (0.0, 1.0), [1.0, 1.0, 1.0])
ff(x) = DiffEqBase.solve_up(prob, ForwardDiffSensitivity(), x, prob.p, FBDF())[end]

y, back = Zygote.pullback(ff, u0);
[back([1.0,0.0,0.0])[1]';
 back([0.0,1.0,0.0])[1]';
 back([0.0,0.0,1.0])[1]']
``````
``````
unthunk(du0) = [2.7196609758646866, 0.0, 0.0]
unthunk(du0) = [0.0, 2.7196609758646866, 0.0]
unthunk(du0) = [0.0, 0.0, 2.7196609758646866]
3×3 Matrix{Float64}:
2.71966  0.0      0.0
0.0      2.71966  0.0
0.0      0.0      2.71966
``````

Isolating it to the ODEProblem overload:

``````julia
julia> g(x) = ODEProblem(dudt, x, (0.0, 1.0), [1.0, 1.0, 1.0]).u0
g (generic function with 1 method)

julia> y, back = Zygote.pullback(g, u0);

julia> [back([1.0,0.0,0.0])[1]';
        back([0.0,1.0,0.0])[1]';
        back([0.0,0.0,1.0])[1]']
ȳ.u0 = [1.0, 0.0, 0.0]
ȳ.u0 = [1.0, 1.0, 0.0]
ȳ.u0 = [1.0, 1.0, 1.0]
3×3 Matrix{Float64}:
1.0  0.0  0.0
1.0  1.0  0.0
1.0  1.0  1.0
``````

``````julia
function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
    function ODEProblemAdjoint(ȳ)
        (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
    end

    ODEProblem(args...; kwargs...), ODEProblemAdjoint
end
``````

That overload just passes each field of the incoming tangent straight through, and it is the only rule involved on our end. So here’s the new MWE:

``````julia
using SciMLBase, Zygote
u0 = [1.0, 1.0, 1.0]
g(x) = ODEProblem((du,u,p,t)->du.=u, x, (0.0, 1.0), [1.0, 1.0, 1.0]).u0
y, back = Zygote.pullback(g, u0);
back([1.0,0.0,0.0])[1]
back([1.0,0.0,0.0])[1]
``````
``````
julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
1.0
0.0
0.0

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
2.0
0.0
0.0
``````

Note that `ȳ.u0` is already wrong, so it’s not the SciML overload. That leaves only one function: `getproperty(prob, :u0)`.

So the conclusion is that the problem must be in Zygote’s `literal_getproperty` fallback?
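For comparison, this is roughly what a well-behaved field-access pullback has to do: hand back a fresh structural tangent containing the incoming cotangent and nothing else. The sketch below is plain Julia with a hypothetical `ToyProblem` struct and a hand-written pullback; it uses a NamedTuple as the structural tangent, which is the form Zygote generally uses for struct gradients, but it is not Zygote’s actual `literal_getproperty` code.

``````julia
# Hypothetical two-field stand-in for ODEProblem.
struct ToyProblem
    u0::Vector{Float64}
    p::Vector{Float64}
end

# Stateless pullback for y = prob.u0. Each call builds a fresh structural
# tangent (a NamedTuple) with the cotangent in the u0 slot and `nothing`
# for the untouched field; nothing is carried over between calls.
function u0_field_pullback(prob::ToyProblem)
    back(ȳ) = ((u0 = copy(ȳ), p = nothing),)
    return prob.u0, back
end

prob = ToyProblem([1.0, 1.0, 1.0], [1.0, 1.0, 1.0])
y, back = u0_field_pullback(prob)
t1, = back([1.0, 0.0, 0.0])
t2, = back([1.0, 0.0, 0.0])   # same cotangent in, same tangent out
``````

Both calls return `[1.0, 0.0, 0.0]` in the `u0` slot, whereas the fallback in the MWE returned `2.0` in the first slot on the second call.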


Indeed, I isolated it to Zygote’s getproperty fallback by simply defining a literal getproperty overload. Once that’s done, it goes away:

``````julia
using SciMLBase, Zygote
u0 = [1.0, 1.0, 1.0]; odef(du,u,p,t)=nothing;
g(x) = ODEProblem(odef, x, (0.0, 1.0), [1.0, 1.0, 1.0]).u0
y, back = Zygote.pullback(g, u0);
back([1.0,0.0,0.0])[1]
back([1.0,0.0,0.0])[1]

Zygote.@adjoint function Zygote.literal_getproperty(prob::ODEProblem, ::Val{:u0})
    prob.u0, p̄ -> (ODEProblem(prob.f, p̄, prob.tspan, prob.p),)
end

y, back = Zygote.pullback(g, u0);
back([1.0,0.0,0.0])[1]
back([1.0,0.0,0.0])[1]
``````
``````julia
julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
1.0
0.0
0.0

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
2.0
0.0
0.0

julia> Zygote.@adjoint function Zygote.literal_getproperty(prob::ODEProblem, ::Val{:u0})
           prob.u0, p̄ -> (ODEProblem(prob.f, p̄, prob.tspan, prob.p),)
       end

julia> y, back = Zygote.pullback(g, u0);

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
1.0
0.0
0.0

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
1.0
0.0
0.0
``````

@mcabbott is there a mutation in that getproperty fallback?

Thanks very much for digging so deep into this problem!

I want to add that if I 1) define the ODEProblem outside the target function and 2) set its `u0` with `remake` inside the function, then the pullback gives the correct Jacobian. I’m not sure whether this small observation is relevant, but I’m pasting the code (based on your MWE) here:

``````julia
using OrdinaryDiffEq
using SciMLSensitivity
using Zygote
using ForwardDiff

function dudt(du, u, p, t)
    du .= u
    nothing
end

u0 = [1.0, 1.0, 1.0]
prob = ODEProblem(dudt, u0, (0.0, 1.0), [1.0, 1.0, 1.0])

# target (vector) function
function forward_vector_target(u)
    _prob = remake(prob, u0 = u)
    sol = solve(_prob, FBDF(), reltol = 1e-8, abstol = 1e-8, sensealg = ForwardDiffSensitivity())
    return sol[end]
end

y, back = Zygote.pullback(forward_vector_target, u0);
[back([1.0,0.0,0.0])[1]';
 back([0.0,1.0,0.0])[1]';
 back([0.0,0.0,1.0])[1]']
``````