I’m trying out Mooncake.jl to compute gradients through ODE solves, but I keep hitting an “uninitialized tangent” error and I’m not sure how to fix it.
A minimal example is below. Going through DifferentiationInterface.jl gives the same error, so to keep things simple I call Mooncake’s functions directly.
ForwardDiff.jl and Zygote.jl both work without any modification, so I suspect I’m missing something obvious; this seems like it should be very simple. I’ve tried Julia 1.11 and 1.12, and both give the same error.
```julia
using OrdinaryDiffEq, SciMLSensitivity, Mooncake

# Simple linear ODE: du/dt = p - u
function testode!(du, u, p, t)
    du[1] = p[1] - u[1]
    return nothing
end

u0 = [1.0]
params = [0.0]
prob = ODEProblem(testode!, u0, (0.0, 1.0), params)
sol = solve(prob)

# Test loss function: re-solve with new parameters and sum the saved states
function loss(ps)
    newprob = remake(prob, p = ps)
    newsol = solve(newprob)
    return sum(Array(newsol))
end

loss(params)                                      # this works
diffcache = prepare_gradient_cache(loss, params)  # this fails
```
Error message:
```
ERROR: Trying to convert uninitialized tangent to ChainRules tangent.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] to_cr_tangent
    @ C:\Users\johnb\.julia\packages\LinearSolve\re36c\ext\LinearSolveMooncakeExt.jl:30 [inlined]
  [3] map (repeats 4 times)
    @ .\tuple.jl:358 [inlined]
  [4] map(f::typeof(Mooncake.to_cr_tangent), t::Tuple{…})
    @ Base .\tuple.jl:358
  [5] map
    @ .\tuple.jl:358 [inlined]
  [6] map(::Function, ::@NamedTuple{…})
    @ Base .\namedtuple.jl:266
  [7] to_cr_tangent(t::Mooncake.MutableTangent{@NamedTuple{…}})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:331
  [8] map
    @ .\tuple.jl:357 [inlined]
  [9] map (repeats 6 times)
    @ .\tuple.jl:358 [inlined]
 [10] map
    @ .\namedtuple.jl:266 [inlined]
 [11] to_cr_tangent(t::Mooncake.Tangent{@NamedTuple{…}})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:330
 [12] map (repeats 9 times)
    @ .\tuple.jl:358 [inlined]
 [13] map
    @ .\namedtuple.jl:266 [inlined]
 [14] to_cr_tangent(t::Mooncake.Tangent{@NamedTuple{…}})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:330
 [15] (::Mooncake.var"#pb!!#330"{…})(y_rdata::Mooncake.NoRData)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\tools_for_rules.jl:631
 [16] #solve#37
    @ C:\Users\johnb\.julia\packages\DiffEqBase\5LeiG\src\solve.jl:579 [inlined]
 [17] (::Tuple{…})(_2::Any)
    @ Base.Experimental .\<missing>:0
 [18] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Any}, NTuple{9, Mooncake.NoRData}}})(x::Mooncake.NoRData)
    @ MistyClosures C:\Users\johnb\.julia\packages\MistyClosures\2vtLL\src\MistyClosures.jl:22
 [19] Pullback
    @ C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:957 [inlined]
 [20] solve
    @ C:\Users\johnb\.julia\packages\DiffEqBase\5LeiG\src\solve.jl:575 [inlined]
 [21] (::Tuple{Mooncake.Stack{…}, Base.RefValue{…}, Mooncake.LazyDerivedRule{…}, Mooncake.Stack{…}})(_2::Any)
    @ Base.Experimental .\<missing>:0
 [22] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{…}, Tuple{…}}})(x::Mooncake.NoRData)
    @ MistyClosures C:\Users\johnb\.julia\packages\MistyClosures\2vtLL\src\MistyClosures.jl:22
 [23] Pullback
    @ C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:957 [inlined]
 [24] (::Mooncake.RRuleWrapperPb{Mooncake.Pullback{…}, Mooncake.LazyZeroRData{…}})(dy::Mooncake.NoRData)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:324
 [25] loss
    @ c:\Users\johnb\test_env\mooncake_testing.jl:14 [inlined]
 [26] (::Tuple{…})(_2::Any)
    @ Base.Experimental .\<missing>:0
 [27] (::MistyClosures.MistyClosure{Core.OpaqueClosure{Tuple{Any}, Tuple{Mooncake.NoRData, Mooncake.NoRData}}})(x::Float64)
    @ MistyClosures C:\Users\johnb\.julia\packages\MistyClosures\2vtLL\src\MistyClosures.jl:22
 [28] (::Mooncake.Pullback{Tuple{…}, Tuple{…}, Tuple{…}, false, 2})(dy::Float64)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interpreter\reverse_mode.jl:957
 [29] prepare_gradient_cache(::Function, ::Vararg{Any}; config::Mooncake.Config)
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interface.jl:588
 [30] prepare_gradient_cache(::Function, ::Vararg{Any})
    @ Mooncake C:\Users\johnb\.julia\packages\Mooncake\vKe16\src\interface.jl:583
 [31] top-level scope
    @ c:\Users\johnb\test_env\mooncake_testing.jl:21
Some type information was truncated. Use `show(err)` to see complete types.
```
Package info:
```
[da2b9cff] Mooncake v0.5.7
[1dea7af3] OrdinaryDiffEq v6.108.0
[1ed8b502] SciMLSensitivity v7.96.0
```
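The exact ForwardDiff.jl and Zygote.jl calls aren’t shown above, so the following is only a sketch of the equivalent comparison, assuming the standard `ForwardDiff.gradient` and `Zygote.gradient` entry points. The two Mooncake routes (DifferentiationInterface.jl with the `AutoMooncake(; config = nothing)` backend, and Mooncake directly via `prepare_gradient_cache` followed by `value_and_gradient!!`) are left as comments because they fail with the error above.

```julia
using OrdinaryDiffEq, SciMLSensitivity
using ForwardDiff, Zygote

# Same minimal problem as above, repeated so this block runs on its own
function testode!(du, u, p, t)
    du[1] = p[1] - u[1]
    return nothing
end

u0 = [1.0]
params = [0.0]
prob = ODEProblem(testode!, u0, (0.0, 1.0), params)

function loss(ps)
    newprob = remake(prob, p = ps)
    newsol = solve(newprob)
    return sum(Array(newsol))
end

# Backends that work without modification
g_forwarddiff = ForwardDiff.gradient(loss, params)
g_zygote = only(Zygote.gradient(loss, params))  # Zygote returns a 1-tuple here

# Mooncake routes (need `using DifferentiationInterface, Mooncake`); both hit
# the "uninitialized tangent" error, so they are commented out:
# g_di = DifferentiationInterface.gradient(loss, AutoMooncake(; config = nothing), params)
# diffcache = Mooncake.prepare_gradient_cache(loss, params)
# val, (_, g_mooncake) = Mooncake.value_and_gradient!!(diffcache, loss, params)
```

For this ODE the exact solution is u(t) = p + (1 - p)e^(-t), so each saved state has a non-negative derivative with respect to p, and both working backends should return a single positive entry.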
Edit: on further reading, this looks like the same problem as in this thread: ChainRules error when using Mooncake in Optimization.jl