Mooncake+ODE gradient error

I’m trying out Mooncake.jl to compute gradients through ODE solves, but I keep hitting an "uninitialized tangent" error and I’m not sure how to fix it.

A minimal example is below. DifferentiationInterface.jl gave the same error, so to keep things simple I’m calling Mooncake’s own interface directly.

ForwardDiff.jl and Zygote.jl both work without any modification, so it feels like I’m missing something obvious; this seems like it should be very simple. I’ve tried Julia 1.11 and 1.12 and get the same error on both.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Mooncake

# Scalar linear ODE: du/dt = p - u
function testode!(du, u, p, t)
    du[1] = p[1] - u[1]
    return nothing
end

u0 = [1.0]
params = [0.0]
prob = ODEProblem(testode!, u0, (0.0, 1.0), params)
sol = solve(prob)

# Test loss: sum of the solution over all saved time points
function loss(ps)
    newprob = remake(prob, p = ps)
    newsol = solve(newprob)
    return sum(Array(newsol))
end

loss(params) # this works

diffcache = prepare_gradient_cache(loss, params) # this fails
```

Error message:

```
ERROR: Trying to convert uninitialized tangent to ChainRules tangent.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] to_cr_tangent
    @ C:\Users\johnb\.julia\packages\LinearSolve\re36c\ext\LinearSolveMooncakeExt.jl:30 [inlined]
  [3] map (repeats 4 times)
    @ .\tuple.jl:358 [inlined]
  [4] map(f::typeof(Mooncake.to_cr_tangent), t::Tuple{…})
    @ Base .\tuple.jl:358
  [5] map
    @ .\tuple.jl:358 [inlined]
  [6] map(::Function, ::@NamedTuple{…})
    @ Base .\namedtuple.jl:266
  [7] to_cr_tangent(t::Mooncake.MutableTangent{@NamedTuple{…}})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:331
  [8] map
    @ .\tuple.jl:357 [inlined]
  [9] map (repeats 6 times)
    @ .\tuple.jl:358 [inlined]
 [10] map
    @ .\namedtuple.jl:266 [inlined]
 [11] to_cr_tangent(t::Mooncake.Tangent{@NamedTuple{…}})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:330
 [12] map (repeats 9 times)
    @ .\tuple.jl:358 [inlined]
 [13] map
    @ .\namedtuple.jl:266 [inlined]
 [14] to_cr_tangent(t::Mooncake.Tangent{@NamedTuple{…}})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:330
 [15] (::Mooncake.var"#pb!!#330"{…})(y_rdata::Mooncake.NoRData)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:631
 [16] #solve#37
    @ C:\Users\johnb\.julia\packages\DiffEqBase\5LeiG\src\solve.jl:579 [inlined]
 [17] (::Tuple{…})(_2::Any)
    @ Base.Experimental .\<missing>:0
 [18] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Any}, NTuple{9, Mooncake.NoRData}}})(x::Mooncake.NoRData)
    @ MistyClosures C:\Users\johnb\.julia\packages\MistyClosures\2vtLL\src\MistyClosures.jl:22
 [19] Pullback
    @ C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:957 [inlined]
 [20] solve
    @ C:\Users\johnb\.julia\packages\DiffEqBase\5LeiG\src\solve.jl:575 [inlined]
 [21] (::Tuple{Mooncake.Stack{…}, Base.RefValue{…}, Mooncake.LazyDerivedRule{…}, Mooncake.Stack{…}})(_2::Any)
    @ Base.Experimental .\<missing>:0
 [22] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{…}, Tuple{…}}})(x::Mooncake.NoRData)
    @ MistyClosures C:\Users\johnb\.julia\packages\MistyClosures\2vtLL\src\MistyClosures.jl:22
 [23] Pullback
    @ C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:957 [inlined]
 [24] (::Mooncake.RRuleWrapperPb{Mooncake.Pullback{…}, Mooncake.LazyZeroRData{…}})(dy::Mooncake.NoRData)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:324
 [25] loss
    @ c:\Users\johnb\test_env\mooncake_testing.jl:14 [inlined]
 [26] (::Tuple{…})(_2::Any)
    @ Base.Experimental .\<missing>:0
 [27] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Any}, Tuple{Mooncake.NoRData, Mooncake.NoRData}}})(x::Float64)
    @ MistyClosures C:\Users\johnb\.julia\packages\MistyClosures\2vtLL\src\MistyClosures.jl:22
 [28] (::Mooncake.Pullback{Tuple{…}, Tuple{…}, Tuple{…}, false, 2})(dy::Float64)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:957
 [29] prepare_gradient_cache(::Function, ::Vararg{Any}; config::Mooncake.Config)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interface.jl:588
 [30] prepare_gradient_cache(::Function, ::Vararg{Any})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interface.jl:583
 [31] top-level scope
    @ c:\Users\johnb\test_env\mooncake_testing.jl:21
Some type information was truncated. Use `show(err)` to see complete types.
```

Package info:

```
  [da2b9cff] Mooncake v0.5.7
  [1dea7af3] OrdinaryDiffEq v6.108.0
  [1ed8b502] SciMLSensitivity v7.96.0
```
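For comparison, here is a sketch of the same loss differentiated through DifferentiationInterface.jl with both the ForwardDiff.jl backend (reported above to work) and the Mooncake.jl backend (reported above to give the same error). It assumes DifferentiationInterface.jl and ForwardDiff.jl are installed alongside the packages listed. Because the Mooncake call is expected to fail on these versions, it is wrapped in `try`/`catch` so the script still completes.

```julia
using OrdinaryDiffEq, SciMLSensitivity
using DifferentiationInterface   # re-exports the ADTypes backends used below
import ForwardDiff, Mooncake     # loading the packages activates the DI backends

# Same scalar linear ODE as the example above: du/dt = p - u
function testode!(du, u, p, t)
    du[1] = p[1] - u[1]
    return nothing
end

prob = ODEProblem(testode!, [1.0], (0.0, 1.0), [0.0])

# Sum of the solution over all saved time points
loss(ps) = sum(Array(solve(remake(prob; p = ps))))

params = [0.0]

# Forward-mode gradient: the path reported to work without modification
g_fd = gradient(loss, AutoForwardDiff(), params)

# Reverse-mode gradient via Mooncake: expected to throw the
# uninitialized-tangent error on the versions listed above
g_mc = try
    gradient(loss, AutoMooncake(; config = nothing), params)
catch err
    @warn "Mooncake backend failed" exception = err
    nothing
end
```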

Edit: On further reading, this looks like the same problem as in this thread: ChainRules error when using Mooncake in Optimization.jl

Open an issue on SciMLSensitivity. I’m working on better Mooncake support right now, but it’s not fully robust yet.