I’m trying to write fairly complicated loss functions that tend to require manipulating the `ODESolution` object.
I would like to use reverse-mode differentiation with Zygote.jl over my loss function so that I can obtain gradients for a large number of parameters at once, but I have been hindered by a couple of errors when accessing parts of the solution. In particular, I want to retrieve `sol.t` and to index the solution as in `sol(sol.t, idxs=u).u`, in both cases to obtain vectors or matrices that I can manipulate further.
A small example is shown below which can be easily modified to show both errors.
Is there a different way I’m supposed to go about flexibly manipulating the solutions, or do I need to write more specific, rigid `solve` calls with the `save_idxs` and `saveat` keywords to avoid indexing into the solution and retrieving the time values after the solve?
Thanks!
Example
```julia
using ModelingToolkit, DifferentialEquations, Zygote, SymbolicIndexingInterface, SciMLSensitivity
using SciMLStructures: replace, replace!, Tunable
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables u(t)
@parameters k [tunable=false]
@parameters c [tunable=true]
eqs = [D(u) ~ k*sin(c*t)]
@mtkbuild sys = ODESystem(eqs, t)

u0 = 0
odeprob = ODEProblem(sys, [u => u0], (0, 1), [k => 1.0, c => 5])
sol = solve(odeprob)

function myloss(newps)
    # set just the tunable parameters
    ps = parameter_values(odeprob)
    ps = replace(Tunable(), ps, newps)
    newprob = remake(odeprob, p = ps)
    newsol = solve(newprob)

    #tvals = newsol.t  # does not work, gives addition error
    tvals = [0.0, 9.999999999999999e-5, 0.0010999999999999998, 0.011099999999999997,
        0.040385829282772054, 0.08289185228479545, 0.13651127701877777, 0.2049194533112576,
        0.2887710337916548, 0.39568436269877305, 0.5152310667421693, 0.631415423924754,
        0.7606225650456293, 0.8853311676455744, 1.0]  # works, but is inconvenient and requires knowing sol.t in advance

    #yvals = newsol(tvals, idxs=u).u  # does not work, gives mutating array error
    yvals = Array(newsol)'  # works, but with more complicated solutions is quite inconvenient

    # residual against the analytic solution for k = 1, c = 10: u(t) = 0.1 - 0.1*cos(10t)
    loss = sum((yvals .- (-0.1*cos.(10.0*tvals) .+ 0.1)).^2)
end

@time myloss([10.0])  # works in all cases and gives the same result
@time Zygote.gradient(myloss, [10.0])  # only works with the second tvals, yvals definitions
```
Swapping in the commented-out `tvals` and `yvals` definitions reproduces the errors I’ve encountered.
`tvals` error:
```
ERROR: MethodError: no method matching +(::ODESolution{…}, ::@NamedTuple{…})
The function `+` exists, but no method is defined for this combination of argument types.
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...)
@ Base operators.jl:596
+(::ChainRulesCore.NotImplemented, ::Any)
@ ChainRulesCore C:\Users\johnb\.julia\packages\ChainRulesCore\U6wNx\src\tangent_arithmetic.jl:24
+(::Any, ::ChainRulesCore.NoTangent)
@ ChainRulesCore C:\Users\johnb\.julia\packages\ChainRulesCore\U6wNx\src\tangent_arithmetic.jl:60
...
Stacktrace:
[1] accum(x::ODESolution{…}, y::@NamedTuple{…})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\lib\lib.jl:17
[2] getproperty
@ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:145 [inlined]
[3] myloss
@ .\Untitled-1:28 [inlined]
[4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[5] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:91
[6] gradient(f::Function, args::Vector{Float64})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:148
[7] macro expansion
@ .\timing.jl:581 [inlined]
[8] top-level scope
@ .\Untitled-1:43
```
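One possible workaround for the `tvals` error, sketched here under assumptions rather than as a confirmed fix: read `newsol.t` inside `ignore_derivatives`, so Zygote treats the adaptive time grid as a constant and never has to add a tangent for the `t` field to the tangent coming from `Array(newsol)` (the failing `accum` in the trace). This assumes ChainRulesCore is added to the active environment and that ignoring how the step locations depend on `c` is acceptable for this loss. The name `myloss_t` is hypothetical, and floating-point `u0`, `tspan`, and `c` are used instead of integers.

```julia
using ModelingToolkit, DifferentialEquations, Zygote, SymbolicIndexingInterface, SciMLSensitivity
using ChainRulesCore: ignore_derivatives  # assumes ChainRulesCore is in the active environment
using SciMLStructures: replace, Tunable
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables u(t)
@parameters k [tunable=false]
@parameters c [tunable=true]
@mtkbuild sys = ODESystem([D(u) ~ k*sin(c*t)], t)
odeprob = ODEProblem(sys, [u => 0.0], (0.0, 1.0), [k => 1.0, c => 5.0])

# hypothetical variant of myloss that reads newsol.t directly
function myloss_t(newps)
    ps = replace(Tunable(), parameter_values(odeprob), newps)
    newsol = solve(remake(odeprob, p = ps))
    # Fetch the solver's time points without differentiating through them:
    # the time grid is treated as a constant in the backward pass.
    tvals = ignore_derivatives(() -> newsol.t)
    yvals = Array(newsol)'  # N×1 matrix of u values at tvals
    # residual against the analytic solution for k = 1, c = 10
    return sum((yvals .- (-0.1 .* cos.(10.0 .* tvals) .+ 0.1)) .^ 2)
end

g = Zygote.gradient(myloss_t, [10.0])
```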
`yvals` error:
```
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Vector{Float64}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:35
[2] _throw_mutation_error(f::Function, args::Vector{Vector{Float64}})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\lib\array.jl:70
[3] (::Zygote.var"#544#545"{Vector{Vector{Float64}}})(::Nothing)
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\lib\array.jl:82
[4] (::Zygote.var"#2623#back#546"{Zygote.var"#544#545"{Vector{Vector{Float64}}}})(Δ::Nothing)
@ Zygote C:\Users\johnb\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:72
[5] invpermute!
@ .\combinatorics.jl:237 [inlined]
[6] (::Zygote.Pullback{Tuple{typeof(invpermute!), Vector{…}, Vector{…}}, Tuple{Zygote.var"#2623#back#546"{…}}})(Δ::Nothing)
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[7] ode_interpolation
@ C:\Users\johnb\.julia\packages\OrdinaryDiffEqCore\33WQj\src\dense\generic_dense.jl:584 [inlined]
[8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.RefValue{Any})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[9] InterpolationData
@ C:\Users\johnb\.julia\packages\OrdinaryDiffEqCore\33WQj\src\interp_func.jl:46 [inlined]
[10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.RefValue{Any})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[11] AbstractODESolution
@ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:331 [inlined]
[12] (::Zygote.Pullback{Tuple{ODESolution{…}, Vector{…}, Type{…}, Num, Symbol}, Any})(Δ::Base.RefValue{Any})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[13] #_#543
@ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:221 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[15] AbstractODESolution
@ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:216 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[17] AbstractODESolution
@ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:216 [inlined]
[18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.RefValue{Any})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[19] myloss
@ .\Untitled-1:34 [inlined]
[20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
[21] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:91
[22] gradient(f::Function, args::Vector{Float64})
@ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:148
[23] macro expansion
@ .\timing.jl:581 [inlined]
[24] top-level scope
@ .\Untitled-1:43
```
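For the `yvals` error, the trace shows the mutation coming from the dense-interpolation path (`ode_interpolation` calling `invpermute!`). A sketch of the `saveat` route mentioned in the question avoids that call entirely: choose the evaluation grid before solving, pass it as `saveat`, and read the state with symbolic indexing. This assumes `newsol[u]` has a Zygote adjoint in your SciMLBase version; if it does not, `Array(newsol)'` from the example is the fallback. `tgrid` and `myloss_saveat` are hypothetical names, and the 51-point grid is an arbitrary choice.

```julia
using ModelingToolkit, DifferentialEquations, Zygote, SymbolicIndexingInterface, SciMLSensitivity
using SciMLStructures: replace, Tunable
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables u(t)
@parameters k [tunable=false]
@parameters c [tunable=true]
@mtkbuild sys = ODESystem([D(u) ~ k*sin(c*t)], t)
odeprob = ODEProblem(sys, [u => 0.0], (0.0, 1.0), [k => 1.0, c => 5.0])

# hypothetical fixed evaluation grid, chosen before any solve
const tgrid = collect(range(0.0, 1.0; length = 51))

function myloss_saveat(newps)
    ps = replace(Tunable(), parameter_values(odeprob), newps)
    # saveat makes the solver store the state at exactly these times,
    # so no post-solve interpolation call (and no invpermute!) is needed
    newsol = solve(remake(odeprob, p = ps); saveat = tgrid)
    # symbolic indexing: a Vector of u values at the saved times
    # (assumes this getindex is differentiable in your SciMLBase version)
    yvals = newsol[u]
    # residual against the analytic solution for k = 1, c = 10
    return sum((yvals .- (-0.1 .* cos.(10.0 .* tgrid) .+ 0.1)) .^ 2)
end

g = Zygote.gradient(myloss_saveat, [10.0])
```

The trade-off is the one raised in the question: the time grid is fixed up front, but the loss no longer depends on which steps the adaptive solver happened to take.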
Edited to remove an unnecessary `using Plots` line I was using to look at the solutions.