Zygote and ODESolution Manipulation/Indexing Compatibility

I’m trying to write fairly complicated loss functions which tend to require manipulation of the ODESolution object.

I would like to perform reverse-mode differentiation with Zygote.jl over my loss function so I can obtain gradients for a large number of parameters at once, but I have been hindered by a couple of errors when accessing parts of the solution. In particular, I am interested in retrieving sol.t and in indexing with sol(sol.t, idxs=u).u, in both cases to obtain vectors or matrices that I can manipulate further.

A small example is shown below which can be easily modified to show both errors.

Is there a different way I’m supposed to go about flexibly manipulating the solutions, or do I need to write specific and more rigid solve statements with the save_idxs and saveat keywords to avoid post-solve solution indexing and time-value retrieval?

Thanks!

Example

using ModelingToolkit, DifferentialEquations, Zygote, SymbolicIndexingInterface, SciMLSensitivity
using SciMLStructures: replace, replace!, Tunable

using ModelingToolkit: t_nounits as t, D_nounits as D


@variables u(t)
@parameters k [tunable=false]
@parameters c [tunable=true]

eqs = [D(u) ~ k*sin(c*t)]

@mtkbuild sys = ODESystem(eqs, t)

u0 = 0

odeprob = ODEProblem(sys, [u => u0], (0, 1), [k =>1.0, c => 5])
sol = solve(odeprob)


function myloss(newps)
    #set just tunable parameters
    ps = parameter_values(odeprob)
    ps = replace(Tunable(), ps, newps)
    newprob = remake(odeprob, p = ps)
    newsol = solve(newprob);
    
    #tvals = newsol.t #does not work, gives addition error
    tvals = [0.0, 9.999999999999999e-5, 0.0010999999999999998, 0.011099999999999997, 
    0.040385829282772054, 0.08289185228479545, 0.13651127701877777, 0.2049194533112576, 
    0.2887710337916548, 0.39568436269877305, 0.5152310667421693, 0.631415423924754, 
    0.7606225650456293, 0.8853311676455744, 1.0] #works, but is inconvenient and requires knowing sol.t in advance
    
    #yvals = newsol(tvals, idxs=u).u #does not work, gives mutating array error
    yvals = Array(newsol)' #works, but with more complicated solutions is quite inconvenient
    loss = sum((yvals .- -0.1*cos.(10.0*tvals) .+ 0.1).^2)
end

@time myloss([10.0]) #works in all cases and gives the same result

@time Zygote.gradient(myloss, [10.0]) #only works with second tvals, yvals definitions

By swapping out the tvals and yvals definitions, the errors I’ve encountered can be reproduced.

tvals error:

ERROR: MethodError: no method matching +(::ODESolution{…}, ::@NamedTuple{…})
The function `+` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:596
  +(::ChainRulesCore.NotImplemented, ::Any)
   @ ChainRulesCore C:\Users\johnb\.julia\packages\ChainRulesCore\U6wNx\src\tangent_arithmetic.jl:24
  +(::Any, ::ChainRulesCore.NoTangent)
   @ ChainRulesCore C:\Users\johnb\.julia\packages\ChainRulesCore\U6wNx\src\tangent_arithmetic.jl:60
  ...

Stacktrace:
 [1] accum(x::ODESolution{…}, y::@NamedTuple{…})
   @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\lib\lib.jl:17
 [2] getproperty
   @ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:145 [inlined]
 [3] myloss
   @ .\Untitled-1:28 [inlined]
 [4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
   @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [5] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
   @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:91
 [6] gradient(f::Function, args::Vector{Float64})
   @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:148
 [7] macro expansion
   @ .\timing.jl:581 [inlined]
 [8] top-level scope
   @ .\Untitled-1:43

yvals error:

ERROR: Mutating arrays is not supported -- called setindex!(Vector{Vector{Float64}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Vector{Float64}})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\lib\array.jl:70
  [3] (::Zygote.var"#544#545"{Vector{Vector{Float64}}})(::Nothing)
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\lib\array.jl:82
  [4] (::Zygote.var"#2623#back#546"{Zygote.var"#544#545"{Vector{Vector{Float64}}}})(Δ::Nothing)
    @ Zygote C:\Users\johnb\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:72
  [5] invpermute!
    @ .\combinatorics.jl:237 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(invpermute!), Vector{…}, Vector{…}}, Tuple{Zygote.var"#2623#back#546"{…}}})(Δ::Nothing)
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
  [7] ode_interpolation
    @ C:\Users\johnb\.julia\packages\OrdinaryDiffEqCore\33WQj\src\dense\generic_dense.jl:584 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.RefValue{Any})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
  [9] InterpolationData
    @ C:\Users\johnb\.julia\packages\OrdinaryDiffEqCore\33WQj\src\interp_func.jl:46 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.RefValue{Any})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [11] AbstractODESolution
    @ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:331 [inlined]
 [12] (::Zygote.Pullback{Tuple{ODESolution{…}, Vector{…}, Type{…}, Num, Symbol}, Any})(Δ::Base.RefValue{Any})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [13] #_#543
    @ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:221 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [15] AbstractODESolution
    @ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:216 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [17] AbstractODESolution
    @ C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\src\solutions\ode_solutions.jl:216 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.RefValue{Any})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [19] myloss
    @ .\Untitled-1:34 [inlined]
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [21] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:91
 [22] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:148
 [23] macro expansion
    @ .\timing.jl:581 [inlined]
 [24] top-level scope
    @ .\Untitled-1:43

Edit: removed an unnecessary using Plots line that I had been using to look at the solutions.

Here is a solution using symbolic indexing:

using ModelingToolkit, DifferentialEquations, Zygote, SymbolicIndexingInterface, SciMLSensitivity
using SciMLStructures: replace, replace!, Tunable

using ModelingToolkit: t_nounits as t, D_nounits as D


function make_problem()
    @variables u(t)
    @parameters k [tunable=false]
    @parameters c [tunable=true]

    eqs = [D(u) ~ k*sin(c*t)]

    @mtkbuild sys = ODESystem(eqs, t)

    u0 = 0

    odeprob = ODEProblem(sys, [u => u0], (0, 1), [k =>1.0, c => 5])
end

function myloss(newps, odeprob)
    #set just tunable parameters
    ps = parameter_values(odeprob)
    ps = replace(Tunable(), ps, newps)
    newprob = remake(odeprob, p = ps)
    newsol = solve(newprob);

    #tvals = newsol.t #does not work, gives addition error
    tvals = [0.0, 9.999999999999999e-5, 0.0010999999999999998, 0.011099999999999997,
    0.040385829282772054, 0.08289185228479545, 0.13651127701877777, 0.2049194533112576,
    0.2887710337916548, 0.39568436269877305, 0.5152310667421693, 0.631415423924754,
    0.7606225650456293, 0.8853311676455744, 1.0] #works, but is inconvenient and requires knowing sol.t in advance

    #yvals = newsol(tvals, idxs=u).u #does not work, gives mutating array error
    yvals = Array(newsol)' #works, but with more complicated solutions is quite inconvenient
    loss = sum((yvals .- -0.1*cos.(10.0*tvals) .+ 0.1).^2)
end

function newloss(newps, odeprob)
    ps = parameter_values(odeprob)
    ps = replace(Tunable(), ps, newps)
    newprob = remake(odeprob, p = ps)
    newsol = solve(newprob);
    u = newprob.f.sys.u # this is probably the wrong way to do this, but
                        # I don't know the correct way.
    sum(abs2, newsol[u-0.1*cos(10.0*t)+0.1])
end

function test()
    odeprob = make_problem()
    sol = solve(odeprob)
    @time myloss([10.0], odeprob)
    @time newloss([10.0], odeprob)
    @time Zygote.gradient(p->myloss(p, odeprob), [10.0])
    @time Zygote.gradient(p->newloss(p, odeprob), [10.0])
end

I had completely forgotten about solution slicing! And obtaining symbolic u from the problem is much more useful than what I have been doing. That seems like it works well for dependent variables like u(t), but oddly enough, it does not work for t, as in newsol[t].

This gives an interesting error if you try to obtain tvals as tvals = newsol[t]. Fortunately, obtaining tvals in advance is much easier than obtaining yvals in advance, so this solution works for what I need it to do.

Thanks a lot!

tvals = newsol[t] error for reference:

ERROR: Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] (::SciMLBaseZygoteExt.var"#ODESolution_getindex_pullback#59"{Num})(Δ::Vector{Float64})
    @ SciMLBaseZygoteExt C:\Users\johnb\.julia\packages\SciMLBase\Pma4a\ext\SciMLBaseZygoteExt.jl:119
  [3] #89#back
    @ C:\Users\johnb\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:72 [inlined]
  [4] myloss_hybrid
    @ .\Untitled-1:63 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
  [6] #25
    @ .\Untitled-1:74 [inlined]
  [7] (::Zygote.Pullback{Tuple{var"#25#26", Vector{…}}, Tuple{Zygote.var"#1986#back#198"{…}, Zygote.Pullback{…}}})(Δ::Float64)
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
  [8] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:91
  [9] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\johnb\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:148
 [10] macro expansion
    @ .\timing.jl:581 [inlined]
 [11] top-level scope
    @ .\Untitled-1:74

If you need the times, it is probably better to specify them with saveat:

newsol = solve(newprob; saveat=0.0:0.05:1.0);

This prevents adaptive time stepping from changing where the loss is evaluated.
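
Putting the saveat suggestion together with the Array(newsol) route that already worked under Zygote earlier in the thread, a loss on a fixed grid might look like the sketch below. It is untested and makes a few assumptions: the names TS, reference and saveat_loss are made up for illustration, and the reference curve is taken to be the closed-form solution of D(u) ~ k*sin(c*t) with u(0) = 0, k = 1 and c = 10, that is u(t) = 0.1 - 0.1*cos(10t). The time points never have to be read back from newsol.t, because they are chosen before the solve.

```julia
using ModelingToolkit, DifferentialEquations, Zygote, SymbolicIndexingInterface, SciMLSensitivity
using SciMLStructures: replace, Tunable
using ModelingToolkit: t_nounits as t, D_nounits as D

# Same toy system as in the thread, wrapped in a function.
function make_problem()
    @variables u(t)
    @parameters k [tunable=false]
    @parameters c [tunable=true]
    eqs = [D(u) ~ k*sin(c*t)]
    @mtkbuild sys = ODESystem(eqs, t)
    return ODEProblem(sys, [u => 0.0], (0.0, 1.0), [k => 1.0, c => 5.0])
end

# Hypothetical fixed evaluation grid, chosen before solving.
const TS = collect(0.0:0.05:1.0)

# Assumed reference curve: closed-form solution for k = 1, c = 10, u(0) = 0.
reference(ts) = @. 0.1 - 0.1 * cos(10.0 * ts)

function saveat_loss(newps, odeprob)
    # Replace only the tunable parameters, as in the original myloss.
    ps = parameter_values(odeprob)
    ps = replace(Tunable(), ps, newps)
    newprob = remake(odeprob, p = ps)
    # Save exactly at TS so the saved values line up with the reference.
    newsol = solve(newprob; saveat = TS)
    # One state variable, so the 1 x length(TS) matrix flattens to a vector.
    yvals = vec(Array(newsol))
    return sum(abs2, yvals .- reference(TS))
end

odeprob = make_problem()
loss_value = saveat_loss([10.0], odeprob)
grad = Zygote.gradient(p -> saveat_loss(p, odeprob), [10.0])
```

With more than one state, Array(newsol) has one row per state, so the row for the quantity of interest would have to be picked out instead of calling vec; the symbolic indexing shown earlier (newsol[u]) avoids having to know that row order.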