Help on ForwardDiff and Zygote for ODEProblem with mass matrix

Hi there,

I’m new to SciMLSensitivity.jl and Zygote.jl and have been playing with these two libraries for a toy ODE system with a mass matrix. However, the Jacobians produced by the two libraries are different, and I haven’t been able to figure out why.

Could you please take a look at my minimal code example and give me a hint on what might be causing this issue?

Regards,
Charlie

using DifferentialEquations
using SciMLSensitivity
using Zygote
using ForwardDiff

K = [1.0 0 0
     0 1.0 0
     0 0 1.0]
function dudt(du, u, p, t)
    du .= K*u
    # nothing
end
M = [1.0 0 0
     0 1.0 0
     0 0 1.0]
f = ODEFunction(dudt, mass_matrix = M)

# target (vector) function
function forward_vector_target(u)
    prob = ODEProblem(f, u, (0.0, 1.0), [1.0, 1.0, 1.0])
    sol = solve(prob, Rodas5(), reltol = 1e-8, abstol = 1e-8)
    return sol[end]
end

u0 = [1.0, 1.0, 1.0]

# 1. reverse mode
rs_jac = Zygote.jacobian(forward_vector_target, u0)[1]
println(rs_jac)
# output:
# 3×3 Matrix{Float64}:
#  2.71828  0.0      0.0
#  2.71828  2.71828  0.0
#  2.71828  2.71828  2.71828

# 2. forward mode
fs_jac = ForwardDiff.jacobian(forward_vector_target, u0)
println(fs_jac)
# output:
# 3×3 Matrix{Float64}:
#  2.71828  0.0      0.0
#  0.0      2.71828  0.0
#  0.0      0.0      2.71828

Here’s a very rough experiment which points towards the ForwardDiff one being correct:

julia> y0 = forward_vector_target(u0)
3-element Vector{Float64}:
 2.7182818287174166
 2.7182818287174166
 2.7182818287174166

julia> forward_vector_target(u0 .+ [1,0,0] ./ 10) .- y0
3-element Vector{Float64}:
  0.2718281828685605
 -2.8883562208648073e-12
 -2.8883562208648073e-12

julia> forward_vector_target(u0 .+ [0,1,0] ./ 10) .- y0
3-element Vector{Float64}:
 -2.8883562208648073e-12
  0.2718281828685605
 -2.8883562208648073e-12

julia> forward_vector_target(u0 .+ [0,0,1] ./ 10) .- y0
3-element Vector{Float64}:
 -2.8883562208648073e-12
 -2.8883562208648073e-12
  0.2718281828685605

julia> fs_jac = ForwardDiff.jacobian(forward_vector_target, u0)
3×3 Matrix{Float64}:
 2.71828  0.0      0.0
 0.0      2.71828  0.0
 0.0      0.0      2.71828

(I wondered whether a known bug in ForwardDiff v0.10 might be the problem, so I tried with ForwardDiff v0.11-DEV to check… but both versions give the same.)
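For reference, the same check can be done without the solver. This is only a sketch under an assumption: it replaces forward_vector_target with the closed-form solution of u' = u on (0, 1), which is u(1) = e * u0. The names analytic_target and fd_jacobian are made up for illustration and are not part of any package.

```julia
# Closed-form stand-in for forward_vector_target (assumption):
# u' = u on t in (0, 1) has the solution u(1) = e * u0.
analytic_target(u) = exp(1.0) .* u

# Forward-difference Jacobian: column j is (f(u + h*e_j) - f(u)) / h.
function fd_jacobian(f, u; h = 1e-6)
    y0 = f(u)
    J = zeros(length(y0), length(u))
    for j in eachindex(u)
        up = copy(u)
        up[j] += h                      # perturb only component j
        J[:, j] = (f(up) .- y0) ./ h
    end
    return J
end

J = fd_jacobian(analytic_target, [1.0, 1.0, 1.0])

# Expected Jacobian for this linear map: e on the diagonal, zero elsewhere.
expected = [i == j ? exp(1.0) : 0.0 for i in 1:3, j in 1:3]
isapprox(J, expected; atol = 1e-6)
```

Since the map is linear, the finite difference is exact up to rounding, and the diagonal structure agrees with the ForwardDiff result above.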


It looks like it’s a Zygote thing. When I try to recreate the Jacobian it’s fine:

julia> [Zygote.pullback(forward_vector_target, u0)[2]([1.0,0.0,0.0])[1]';
        Zygote.pullback(forward_vector_target, u0)[2]([0.0,1.0,0.0])[1]';
        Zygote.pullback(forward_vector_target, u0)[2]([0.0,0.0,1.0])[1]']
3×3 Matrix{Float64}:
 2.71828  0.0      0.0
 0.0      2.71828  0.0
 0.0      0.0      2.71828

That’s pulling back the basis vectors, and it seems fine. That’s what I would’ve assumed Zygote.jacobian was doing, so I’m surprised it gives a different answer. @mcabbott do you know what Zygote.jacobian is actually calculating with?

julia> [Zygote.gradient((u)->forward_vector_target(u)[1], u0)[1]';
       Zygote.gradient((u)->forward_vector_target(u)[2], u0)[1]';
       Zygote.gradient((u)->forward_vector_target(u)[3], u0)[1]']
3×3 Matrix{Float64}:
 2.71828  0.0      0.0
 0.0      2.71828  0.0
 0.0      0.0      2.71828

Gradients are also fine, so this is very specific to the Zygote.jacobian interface, and I don’t know its implementation.

I didn’t manage to run the Zygote version locally.

But the difference between what you write and Zygote.jacobian is that you run the forward pass 3 times, while Zygote.jacobian does something more like this (untested):

y, bk = Zygote.pullback(forward_vector_target, u0)
hcat(bk([1,0,0.]), bk([0,1,0.]), bk([0,0,1.]))

i.e. saving and re-using the pullback function.

If this snippet differs from what you write, then I believe the rule you are providing Zygote is invalid, e.g. it is overwriting some buffer captured by bk such that it does not work on the second call.

(Overwriting such buffers does allow more efficient rules, if the pullback is certain to only be used once. It’s easy to forget to test for this. At some point, there was work towards adding a flag for this to ChainRulesCore, and making Zygote use it, but IIRC this never got merged. Edit, see this CRC PR.)
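To make the distinction concrete, here is a minimal sketch in plain Julia with no Zygote involved. The names pure_pullback, stateful_pullback, jac_rebuild and jac_reuse are hypothetical, and the stateful rule is only an assumed stand-in for whatever is being mutated; this is not Zygote.jacobian's implementation.

```julia
# Hand-written "pullback" constructors for f(x) = x .* x.
# They only mimic the shape `y, back = pullback(f, x)`.

# Pure pullback: each call depends only on its input dy.
pure_pullback(x) = (x .* x, dy -> 2 .* x .* dy)

# Stateful pullback (deliberately invalid): accumulates into a captured
# buffer, so every call also sees the results of the earlier calls.
function stateful_pullback(x)
    buf = zero(x)
    return x .* x, dy -> copy(buf .+= 2 .* x .* dy)
end

# Basis vector e_i of length n.
basis(i, n) = [j == i ? 1.0 : 0.0 for j in 1:n]

# Strategy 1: rebuild the pullback for every row (forward pass runs n times).
function jac_rebuild(make, x)
    n = length(x)
    return reduce(vcat, [make(x)[2](basis(i, n))' for i in 1:n])
end

# Strategy 2: build the pullback once and reuse it for every row.
function jac_reuse(make, x)
    n = length(x)
    _, back = make(x)
    return reduce(vcat, [back(basis(i, n))' for i in 1:n])
end

x = [1.0, 2.0, 3.0]
J_rebuild = jac_rebuild(stateful_pullback, x)  # diagonal: state never carries over
J_reuse   = jac_reuse(stateful_pullback, x)    # lower-triangular: rows accumulate
J_pure    = jac_reuse(pure_pullback, x)        # diagonal: reuse is safe here
```

With the stateful rule, rebuilding hides the problem because each row gets a fresh buffer, while reusing one pullback produces the same lower-triangular pattern as in the original post.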


Okay, that narrows it down. Here’s a leaner MWE; it doesn’t actually involve mass matrices at all:

using OrdinaryDiffEq
using SciMLSensitivity
using Zygote
using ForwardDiff

function dudt(du, u, p, t)
    du .= u
    nothing
end

# target (vector) function
function forward_vector_target(u)
    prob = ODEProblem(dudt, u, (0.0, 1.0), [1.0, 1.0, 1.0])
    sol = solve(prob, FBDF(), reltol = 1e-8, abstol = 1e-8, sensealg = ForwardDiffSensitivity())
    return sol[end]
end

u0 = [1.0, 1.0, 1.0]  # same initial condition as in the original post
y, back = Zygote.pullback(forward_vector_target, u0);
[back([1.0,0.0,0.0])[1]';
back([0.0,1.0,0.0])[1]';
back([0.0,0.0,1.0])[1]']
julia> y, back = Zygote.pullback(forward_vector_target, u0);

julia> [back([1.0,0.0,0.0])[1]';
       back([0.0,1.0,0.0])[1]';
       back([0.0,0.0,1.0])[1]']
unthunk(du0) = [2.7182818436166185, 0.0, 0.0]
unthunk(du0) = [0.0, 2.7182818436166185, 0.0]
unthunk(du0) = [0.0, 0.0, 2.7182818436166185]
3×3 Matrix{Float64}:
 2.71828  0.0      0.0
 2.71828  2.71828  0.0
 2.71828  2.71828  2.71828

I added a printout of the du0 computed by the adjoint method. That confirms that SciMLSensitivity is computing the correct value each time, so the adjoint system is completely fine with being reused with new values. So this means I’m handing

            (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
                ntuple(_ -> NoTangent(), length(args))...)

back to Zygote and that du0 is correct. The next thing above it is the ODEProblem, so I check:

function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
    function ODEProblemAdjoint(ȳ)
        @show ȳ.u0
        (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
    end

    ODEProblem(args...; kwargs...), ODEProblemAdjoint
end
unthunk(du0) = [2.7182818436166185, 0.0, 0.0]
ȳ.u0 = [2.7182818436166185, 0.0, 0.0]
unthunk(du0) = [0.0, 2.7182818436166185, 0.0]
ȳ.u0 = [2.7182818436166185, 2.7182818436166185, 0.0]
unthunk(du0) = [0.0, 0.0, 2.7182818436166185]
ȳ.u0 = [2.7182818436166185, 2.7182818436166185, 2.7182818436166185]
3×3 Matrix{Float64}:
 2.71828  0.0      0.0
 2.71828  2.71828  0.0
 2.71828  2.71828  2.71828

so the buffer that Zygote gives it has already merged the runs.

So since the pullback of solve is correct, what is changing the value before it ends up in the ODEProblem adjoint? That surely looks like a Zygote bug then, since it should just be handing that same du0 back up to the ODEProblem adjoint, which is then just the derivative of the identity function?


Something closed over by back is certainly being mutated. For example, here the value of back([0,1,0.]) is inconsistent:

julia> y, back = Zygote.pullback(forward_vector_target, u0);

julia>  du1, = back([1,0,0.])
([2.7182818436166185, 0.0, 0.0],)

julia> back([0,1,0.])  # second call
([2.7182818436166185, 2.7182818436166185, 0.0],)

julia> back([0,1,0.])  # third call, with same input as second
([2.7182818436166185, 5.436563687233237, 0.0],)

julia> (du1,)  # check the return wasn't mutated after printing
([2.7182818436166185, 0.0, 0.0],)

julia> y, back = Zygote.pullback(forward_vector_target, u0);

julia> back([0,1,0.])  # first call
([0.0, 2.7182818436166185, 0.0],)

The returned arrays here have distinct objectids, so it isn’t literally mutating the returned array.

You haven’t shown what exactly is printing unthunk(du0), but I agree that stage looks correct. Surely the straight-through rrule you show isn’t introducing any mutable state.

Note that ChainRulesTestUtils.jl tests for this problem, so assuming you define rrules, you can try to make a Zygote-free reproducer:

julia> using Zygote, ChainRulesCore, ChainRulesTestUtils

julia> sq1(x) = x .* x;

julia> sq2(x) = x .* x;  # same function, but will give it a bad rrule

julia> function ChainRulesCore.rrule(::typeof(sq1), x)
         sq1(x), dy -> (NoTangent(), 2 .* x .* dy)  # good (at least for real numbers)
       end

julia> function ChainRulesCore.rrule(::typeof(sq2), x)
         BUF = zero.(x)
         return sq2(x), dy -> (NoTangent(), copy(BUF .+= 2 .* x .* dy))  # bad!
       end

julia> Zygote.gradient(x -> sum(x.*x), [1,2,3.])
([2.0, 4.0, 6.0],)

julia> Zygote.gradient(sum∘sq1, [1,2,3.])
([2.0, 4.0, 6.0],)

julia> Zygote.gradient(sum∘sq2, [1,2,3.])  # ok
([2.0, 4.0, 6.0],)

julia> Zygote.jacobian(x -> x.*x, [1,2,3.])
([2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0],)

julia> Zygote.jacobian(sq1, [1,2,3.])
([2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0],)

julia> Zygote.jacobian(sq2, [1,2,3.])  # wrong!
([2.0 0.0 0.0; 2.0 4.0 0.0; 2.0 4.0 6.0],)

julia> _, back = Zygote.pullback(sq2, [1,2,3.]);

julia> back([1,0,0.])
([2.0, 0.0, 0.0],)

julia> back([1,0,0.])  # second call, different answer
([4.0, 0.0, 0.0],)

julia> test_rrule(sq1, [1,2,3.])
Test Summary:                      | Pass  Total  Time
test_rrule: sq1 on Vector{Float64} |    7      7  0.0s
Test.DefaultTestSet("test_rrule: sq1 on Vector{Float64}", Any[], 7, false, false, true, 1.72291392558364e9, 1.7229139256149e9, false, "/Users/me/.julia/packages/ChainRulesTestUtils/Ko1Wr/src/testers.jl")

julia> test_rrule(sq2, [1,2,3.])
test_rrule: sq2 on Vector{Float64}: Test Failed at /Users/me/.julia/packages/ChainRulesTestUtils/Ko1Wr/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
  Problem: cotangent for input 2, 3-element Vector{Float64}
   Evaluated: isapprox([14.12, -46.72, 77.88], [7.059999999999687, -23.36000000000044, 38.94000000000155]; rtol = 1.0e-9, atol = 1.0e-9)
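A cruder check that needs no test package is to call the same pullback twice with the same cotangent and compare the results. The sketch below is plain Julia; is_repeatable, make_good_back and make_bad_back are hypothetical helpers written for illustration, and this is not a description of how test_rrule works internally.

```julia
# A pullback is "repeatable" if two calls with the same cotangent agree.
is_repeatable(back, dy) = back(copy(dy)) == back(copy(dy))

# Valid pullback for x .* x: no captured mutable state.
make_good_back(x) = dy -> 2 .* x .* dy

# Invalid pullback: accumulates into a buffer captured by the closure.
function make_bad_back(x)
    buf = zero(x)
    return dy -> copy(buf .+= 2 .* x .* dy)
end

x = [1.0, 2.0, 3.0]
dy = [1.0, 0.0, 0.0]

good_ok = is_repeatable(make_good_back(x), dy)  # true
bad_ok  = is_repeatable(make_bad_back(x), dy)   # false: second call returns [4.0, 0.0, 0.0]
```

Passing this check does not prove a rule is correct, but failing it is enough to explain a wrong Zygote.jacobian alongside correct gradients.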

When I isolate the SciMLSensitivity part, I can confirm it’s fine:

prob = ODEProblem(dudt, u0, (0.0, 1.0), [1.0, 1.0, 1.0])
ff(x) = DiffEqBase.solve_up(prob, ForwardDiffSensitivity(), x, prob.p, FBDF())[end]

y, back = Zygote.pullback(ff, u0);
[back([1.0,0.0,0.0])[1]';
back([0.0,1.0,0.0])[1]';
back([0.0,0.0,1.0])[1]']
unthunk(du0) = [2.7196609758646866, 0.0, 0.0]
unthunk(du0) = [0.0, 2.7196609758646866, 0.0]
unthunk(du0) = [0.0, 0.0, 2.7196609758646866]
3×3 Matrix{Float64}:
 2.71966  0.0      0.0
 0.0      2.71966  0.0
 0.0      0.0      2.71966

Isolating it to the ODEProblem overload:

julia> g(x) = ODEProblem(dudt, x, (0.0, 1.0), [1.0, 1.0, 1.0]).u0
g (generic function with 1 method)

julia> y, back = Zygote.pullback(g, u0);

julia> [back([1.0,0.0,0.0])[1]';
       back([0.0,1.0,0.0])[1]';
       back([0.0,0.0,1.0])[1]']
ȳ.u0 = [1.0, 0.0, 0.0]
ȳ.u0 = [1.0, 1.0, 0.0]
ȳ.u0 = [1.0, 1.0, 1.0]
3×3 Matrix{Float64}:
 1.0  0.0  0.0
 1.0  1.0  0.0
 1.0  1.0  1.0

So it is this overload:

function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
    function ODEProblemAdjoint(ȳ)
        (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
    end
    ODEProblem(args...; kwargs...), ODEProblemAdjoint
end

since that’s the only one involved on our end. So then here’s the new MWE:

using SciMLBase, Zygote
u0 = [1.0, 1.0, 1.0]
g(x) = ODEProblem((du,u,p,t)->du.=u, x, (0.0, 1.0), [1.0, 1.0, 1.0]).u0
y, back = Zygote.pullback(g, u0);
back([1.0,0.0,0.0])[1]
back([1.0,0.0,0.0])[1]

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
 1.0
 0.0
 0.0

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
 2.0
 0.0
 0.0

Note that ȳ.u0 is already wrong, so that means it’s not the SciML overload. There’s only one function left then, which is the getproperty(prob, :u0).

The conclusion then is that the problem must be the Zygote literal_getproperty fallback?


Indeed, I isolated it to Zygote’s getproperty fallback by simply defining a literal getproperty overload. Once that’s done, it goes away:

using SciMLBase, Zygote
u0 = [1.0, 1.0, 1.0]; odef(du,u,p,t)=nothing;
g(x) = ODEProblem(odef, x, (0.0, 1.0), [1.0, 1.0, 1.0]).u0
y, back = Zygote.pullback(g, u0);
back([1.0,0.0,0.0])[1]
back([1.0,0.0,0.0])[1]

Zygote.@adjoint function Zygote.literal_getproperty(prob::ODEProblem,
        ::Val{:u0})
    prob.u0, p̄ -> (ODEProblem(prob.f, p̄, prob.tspan, prob.p),)
end

y, back = Zygote.pullback(g, u0);
back([1.0,0.0,0.0])[1]
back([1.0,0.0,0.0])[1]
julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
 1.0
 0.0
 0.0

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
 2.0
 0.0
 0.0

julia> Zygote.@adjoint function Zygote.literal_getproperty(prob::ODEProblem,
               ::Val{:u0})
           prob.u0, p̄ -> (ODEProblem(prob.f, p̄, prob.tspan, prob.p),)
       end

julia> y, back = Zygote.pullback(g, u0);

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
 1.0
 0.0
 0.0

julia> back([1.0,0.0,0.0])[1]
3-element Vector{Float64}:
 1.0
 0.0
 0.0

@mcabbott is there a mutation in that getproperty fallback?

Thanks very much for digging so deep into this problem!

I want to add that if I 1) define the ODEProblem outside of the target function and 2) set its initial condition u0 with remake inside the function, then the pullback gives the correct Jacobian. I’m not sure if this tiny observation is relevant at all, but I’m pasting the code (based on your MWE) here:

using OrdinaryDiffEq
using SciMLSensitivity
using Zygote
using ForwardDiff

function dudt(du, u, p, t)
    du .= u
    nothing
end

u0 = [1.0, 1.0, 1.0]
prob = ODEProblem(dudt, u0, (0.0, 1.0), [1.0, 1.0, 1.0])

# target (vector) function
function forward_vector_target(u)
    _prob = remake(prob, u0 = u)
    sol = solve(_prob, FBDF(), reltol = 1e-8, abstol = 1e-8, sensealg = ForwardDiffSensitivity())
    return sol[end]
end

y, back = Zygote.pullback(forward_vector_target, u0);
[back([1.0,0.0,0.0])[1]';
back([0.0,1.0,0.0])[1]';
back([0.0,0.0,1.0])[1]']