Zygote and ODESolution Manipulation/Indexing Compatibility

Here is a solution using symbolic indexing (`newloss`), shown next to the original workaround (`myloss`):

using ModelingToolkit, DifferentialEquations, Zygote, SymbolicIndexingInterface, SciMLSensitivity
using SciMLStructures: replace, replace!, Tunable

using ModelingToolkit: t_nounits as t, D_nounits as D


"""
    make_problem()

Build the test `ODEProblem` for `D(u) ~ k*sin(c*t)` with `u(0) = 0` on `t ∈ [0, 1]`.

`k = 1.0` is declared non-tunable and `c = 5.0` is the only tunable parameter, so
`replace(Tunable(), parameter_values(prob), newps)` swaps out just `c`.
"""
function make_problem()
    @variables u(t)
    @parameters k [tunable = false]
    @parameters c [tunable = true]

    eqs = [D(u) ~ k * sin(c * t)]

    @mtkbuild sys = ODESystem(eqs, t)

    # Float literals on purpose: integer values (`0`, `(0, 1)`, `5`) can leave
    # integer-typed state/time/tunable fields in the problem, which is fragile once
    # the solver takes floating-point steps or AD substitutes Float64 parameters.
    u0 = 0.0
    tspan = (0.0, 1.0)

    return ODEProblem(sys, [u => u0], tspan, [k => 1.0, c => 5.0])
end

"""
    myloss(newps, odeprob)

Sum-of-squares loss between the solution of `odeprob`, with its tunable parameters
replaced by `newps`, and the target curve `-0.1*cos(10t) + 0.1` sampled at a fixed
set of time points.

The solution is saved exactly at those time points (`saveat`), so the samples do not
depend on which adaptive steps the solver happens to take for `newps`.
"""
function myloss(newps, odeprob)
    # Replace only the tunable parameters; non-tunable ones are carried over.
    ps = replace(Tunable(), parameter_values(odeprob), newps)
    newprob = remake(odeprob; p = ps)

    # These look like the step times of one particular default solve; thanks to
    # `saveat` below any grid would work here.
    tvals = [0.0, 9.999999999999999e-5, 0.0010999999999999998, 0.011099999999999997,
             0.040385829282772054, 0.08289185228479545, 0.13651127701877777,
             0.2049194533112576, 0.2887710337916548, 0.39568436269877305,
             0.5152310667421693, 0.631415423924754, 0.7606225650456293,
             0.8853311676455744, 1.0]

    # Without `saveat` the solution is stored at the solver's own, parameter-dependent
    # step times, which need not match `tvals` in count or position.
    newsol = solve(newprob; saveat = tvals)

    yvals = Array(newsol)'
    # Keep the target in its own expression: writing `y .- -a .+ b` would add `b`
    # instead of subtracting the whole target.
    target = -0.1 .* cos.(10.0 .* tvals) .+ 0.1
    return sum((yvals .- target) .^ 2)
end

"""
    newloss(newps, odeprob)

Sum-of-squares loss between the solution of `odeprob`, with its tunable parameters
replaced by `newps`, and the target curve `-0.1*cos(10t) + 0.1`.

The residual is built as a symbolic expression and evaluated on the solution via
symbolic indexing, at every time point the solver saved.
"""
function newloss(newps, odeprob)
    # Replace only the tunable parameters; non-tunable ones are carried over.
    ps = replace(Tunable(), parameter_values(odeprob), newps)
    newprob = remake(odeprob; p = ps)
    newsol = solve(newprob)

    # Look up the symbolic state `u` from the system stored on the problem's function.
    # NOTE(review): confirm this is the preferred ModelingToolkit API for retrieving
    # a system variable from a problem.
    u = newprob.f.sys.u

    # Parenthesise the target so the residual is u(t) - (-0.1*cos(10t) + 0.1).
    residual = newsol[u - (-0.1 * cos(10.0 * t) + 0.1)]
    return sum(abs2, residual)
end

"""
    test()

Build the problem, then time both loss functions and their Zygote gradients with the
tunable parameters set to `[10.0]`. Returns the (timed) gradient of `newloss`.
"""
function test()
    prob = make_problem()
    # One untimed solve of the unmodified problem before any measured call;
    # its result is not used afterwards.
    solve(prob)

    @time myloss([10.0], prob)
    @time newloss([10.0], prob)

    @time Zygote.gradient(θ -> myloss(θ, prob), [10.0])
    grad = @time Zygote.gradient(θ -> newloss(θ, prob), [10.0])
    return grad
end
1 Like