Correcting ODE initial conditions using automatic differentiation

Hi all, I’m trying to find a way to correct initial conditions of a system of ODEs by using automatic differentiation through the ODE solver. I have implemented a similar algorithm in Python using torchdiffeq, but I’m new to Julia and having trouble debugging the errors. I’ve gone through the docs and examples of both Zygote and ReverseDiff (which I think are best suited to my purpose) but I’m having difficulty adapting them to my problem.

What I want to do is first solve the system of ODEs with an initial guess for the state at time t0 and hence find the state at some later time t1. The loss function is the absolute difference between the actual and predicted states at t1. The problem is now to find the gradients of the loss function with respect to the initial state and hence correct the initial state (like differential correction). However, I am having trouble finding the gradients of the loss function. A minimal working example is shown below:

"""
    twobody(du, u, t, p)

Fill `du` in place with the time derivative of the six-element state `u`:
`du[1:3]` receives `u[4:6]`, and `du[4:6]` receives `-mu*u[1:3]/r^3`, where
`r = norm(u[1:3])` and `mu` is taken directly from the fourth argument `p`.
The third argument `t` is not used.

NOTE(review): SciML solvers are documented to call in-place ODE functions as
`f(du, u, p, t)`; with this argument order the solver's parameter would arrive
in `t` and the time in `p` — confirm against the solver docs.
"""
function twobody(du,u,t,p)
    r = norm(u[1:3])
    mu = p
    r_cubed = r^3   # shared denominator of the three acceleration terms

    # The first three derivative slots are copies of the last three state entries.
    for k in 1:3
        du[k] = u[k + 3]
    end

    # The last three derivative slots scale the first three state entries by -mu/r^3.
    du[4] = -mu*u[1]/r_cubed
    du[5] = -mu*u[2]/r_cubed
    du[6] = -mu*u[3]/r_cubed
end

# Fixed problem data
Δt = 0.125                    # end of the integration window (0, Δt)
p = 1.0                       # parameter handed to ODEProblem
r1 = [1.0, 0.0, 0.0]          # first three components of the initial state
r2 = [1.0, 0.125, 0.125]      # vector the final state's first three components are compared with

# Trial value for the last three components of the initial state
v1 = [0.05, 1.0, 1.0]
u0 = vcat(r1, v1)             # full six-component initial state

# Build the ODE problem and solve it once with tight tolerances
prob = ODEProblem(twobody, u0, (0, Δt), p)
sol = solve(prob; reltol=1e-8, abstol=1e-8)

"""
    loss_r(v)

Return the Euclidean norm of `r2` minus the first three components of the
final saved state, obtained by solving `prob` with its initial state replaced
by `[r1; v]`.

Reads the globals `r1`, `r2`, `prob` and `Δt`. The solve uses `Tsit5()`, saves
at `Δt`, and requests `QuadratureAdjoint()` as the sensitivity algorithm.
"""
function loss_r(v)
    # Initial state: the fixed first three components followed by the trial `v`.
    state0 = vcat(r1, v)
    trial_prob = remake(prob; u0=state0)
    trajectory = solve(trial_prob, Tsit5(); saveat=Δt, sensealg=QuadratureAdjoint())
    # First three components of the last saved state.
    r_final = trajectory[end][1:3]
    norm(r2 - r_final)
end

# Differentiate loss_r with respect to the trial vector v1 using Zygote.
# NOTE(review): Zygote.gradient is documented to return a tuple with one entry per
# differentiated argument, so the gradient vector itself would be dLdv[1] — confirm.
dLdv = Zygote.gradient(loss_r, v1)

The last line gives me the error:

ERROR: MethodError: no method matching similar(::Float64)
Closest candidates are:
  similar(::JuliaInterpreter.Compiled, ::Any) at C:\Users\Komal\.julia\packages\JuliaInterpreter\Eyi3R\src\types.jl:7
  similar(::Sundials.NVector) at C:\Users\Komal\.julia\packages\Sundials\YfkdE\src\nvector_wrapper.jl:72
  similar(::Array{T,1}) where T at array.jl:375
  ...
Stacktrace:
 [1] (::ReverseDiff.var"#657#658"{Float64,Array{ReverseDiff.AbstractInstruction,1}})(::Float64) at C:\Users\Komal\.julia\packages\ReverseDiff\E4Tzn\src\api\Config.jl:46
 [2] map(::ReverseDiff.var"#657#658"{Float64,Array{ReverseDiff.AbstractInstruction,1}}, ::Tuple{Array{Float64,1},Float64,Array{Float64,1}}) at .\tuple.jl:159
 [3] ReverseDiff.GradientConfig(::Tuple{Array{Float64,1},Float64,Array{Float64,1}}, ::Type{Float64}, ::Array{ReverseDiff.AbstractInstruction,1}) at C:\Users\Komal\.julia\packages\ReverseDiff\E4Tzn\src\api\Config.jl:46
 [4] ReverseDiff.GradientConfig(::Tuple{Array{Float64,1},Float64,Array{Float64,1}}, ::Array{ReverseDiff.AbstractInstruction,1}) at C:\Users\Komal\.julia\packages\ReverseDiff\E4Tzn\src\api\Config.jl:37 (repeats 2 times)
Array{Float64,1},Float64,Array{Float64,1}}) at C:\Users\Komal\.julia\packages\ReverseDiff\E4Tzn\src\api\tape.jl:204
 [6] adjointdiffcache(::Function, ::QuadratureAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Float64,ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing}; quad::Bool, noiseterm::Bool) at 
C:\Users\Komal\.julia\packages\DiffEqSensitivity\EpC0d\src\adjoint_common.jl:144
 [7] DiffEqSensitivity.ODEQuadratureAdjointSensitivityFunction(::Function, ::QuadratureAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Float64,ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Nothing) at C:\Users\Komal\.julia\packages\DiffEqSensitivity\EpC0d\src\quadrature_adjoint.jl:12
 [8] ODEAdjointProblem(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Float64,ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::QuadratureAdjoint{0,true,Val{:central},Bool}, ::Function, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing, ::Nothing) at C:\Users\Komal\.julia\packages\DiffEqSensitivity\EpC0d\src\quadrature_adjoint.jl:60
 [9] _adjoint_sensitivities(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Float64,ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::QuadratureAdjoint{0,true,Val{:central},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#178"{Array{Array{Float64,1},1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; abstol::Float64, reltol::Float64, callback::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at C:\Users\Komal\.julia\packages\DiffEqSensitivity\EpC0d\src\quadrature_adjoint.jl:192
 [10] adjoint_sensitivities(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Float64,ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(twobody),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::QuadratureAdjoint{0,true,Val{:central},Bool}, kwargs::Base.Iterators.Pairs{Symbol,Nothing,Tuple{Symbol},NamedTuple{(:callback,),Tuple{Nothing}}}) at C:\Users\Komal\.julia\packages\DiffEqSensitivity\EpC0d\src\sensitivity_interface.jl:6
 [11] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#177"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Tsit5,QuadratureAdjoint{0,true,Val{:central},Bool},Array{Float64,1},Float64,Tuple{},NamedTuple{(),Tuple{}},Colon})(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\DiffEqSensitivity\EpC0d\src\concrete_solve.jl:183
 [12] ZBack at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\compiler\chainrules.jl:77 [inlined]
 [13] (::Zygote.var"#kw_zpullback#40"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#177"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Tsit5,QuadratureAdjoint{0,true,Val{:central},Bool},Array{Float64,1},Float64,Tuple{},NamedTuple{(),Tuple{}},Colon}})(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\compiler\chainrules.jl:103
 [14] (::Zygote.var"#150#151"{Zygote.var"#kw_zpullback#40"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#177"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Tsit5,QuadratureAdjoint{0,true,Val{:central},Bool},Array{Float64,1},Float64,Tuple{},NamedTuple{(),Tuple{}},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191
 [15] (::Zygote.var"#1733#back#152"{Zygote.var"#150#151"{Zygote.var"#kw_zpullback#40"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#177"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Tsit5,QuadratureAdjoint{0,true,Val{:central},Bool},Array{Float64,1},Float64,Tuple{},NamedTuple{(),Tuple{}},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [16] #solve#57 at C:\Users\Komal\.julia\packages\DiffEqBase\QiFNl\src\solve.jl:70 [inlined]
 [17] (::typeof(∂(#solve#57)))(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [18] (::Zygote.var"#150#151"{typeof(∂(#solve#57)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\lib\lib.jl:191
 [19] (::Zygote.var"#1733#back#152"{Zygote.var"#150#151"{typeof(∂(#solve#57)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [20] (::typeof(∂(solve##kw)))(::Array{Array{Float64,1},1}) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [21] loss_r at C:\Repositories\neuralODEs\lambert_problem.jl:44 [inlined]
 [22] (::typeof(∂(loss_r)))(::Float64) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\compiler\interface2.jl:0
 [23] (::Zygote.var"#41#42"{typeof(∂(loss_r))})(::Float64) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:40
 [24] gradient(::Function, ::Array{Float64,1}) at C:\Users\Komal\.julia\packages\Zygote\ggM8Z\src\compiler\interface.jl:49
 [25] top-level scope at C:\Repositories\neuralODEs\lambert_problem.jl:57

I’ve gone all over the docs, examples and forums several times but I haven’t encountered a similar issue. I’ve also tried computing the gradients using ReverseDiff and then using Flux, but I get the same error every time. There’s something wrong with my loss function definition but I can’t figure out what. Any help would be greatly appreciated.

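Try the following. Two things differ from your code: the ODE function takes its arguments in the order `(du,u,p,t)` (parameters before time), and the parameter is now a vector (`p = [1.]`, read inside the function as `p[1]`) instead of a bare scalar: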

"""
    twobody(du, u, p, t)

Fill `du` in place with the time derivative of the six-element state `u`:
`du[1:3]` receives `u[4:6]`, and `du[4:6]` receives `-mu*u[1:3]/r^3`, where
`r = norm(u[1:3])` and `mu = p[1]`. The time argument `t` is not used.
"""
function twobody(du,u,p,t)
    r = norm(u[1:3])
    mu = p[1]
    r_cubed = r^3   # shared denominator of the three acceleration terms

    # The first three derivative slots are copies of the last three state entries.
    for k in 1:3
        du[k] = u[k + 3]
    end

    # The last three derivative slots scale the first three state entries by -mu/r^3.
    du[4] = -mu*u[1]/r_cubed
    du[5] = -mu*u[2]/r_cubed
    du[6] = -mu*u[3]/r_cubed
end

# Fixed problem data
Δt = 0.125                    # end of the integration window (0, Δt)
p = [1.0]                     # parameter vector; twobody reads p[1] as mu
r1 = [1.0, 0.0, 0.0]          # first three components of the initial state
r2 = [1.0, 0.125, 0.125]      # vector the final state's first three components are compared with

# Trial value for the last three components of the initial state
v1 = [0.05, 1.0, 1.0]
u0 = vcat(r1, v1)             # full six-component initial state

# Build the ODE problem and solve it once with tight tolerances
prob = ODEProblem(twobody, u0, (0, Δt), p)
sol = solve(prob; reltol=1e-8, abstol=1e-8)

"""
    loss_r(v)

Return the Euclidean norm of `r2` minus the first three components of the
final saved state, obtained by solving `prob` with its initial state replaced
by `[r1; v]`.

Reads the globals `r1`, `r2`, `prob` and `Δt`. The solve uses `Tsit5()`, saves
at `Δt`, and requests `QuadratureAdjoint()` as the sensitivity algorithm.
"""
function loss_r(v)
    # Initial state: the fixed first three components followed by the trial `v`.
    state0 = vcat(r1, v)
    trial_prob = remake(prob; u0=state0)
    trajectory = solve(trial_prob, Tsit5(); saveat=Δt, sensealg=QuadratureAdjoint())
    # First three components of the last saved state.
    r_final = trajectory[end][1:3]
    norm(r2 - r_final)
end

# Differentiate loss_r with respect to the trial vector v1 using Zygote.
# NOTE(review): Zygote.gradient is documented to return a tuple with one entry per
# differentiated argument, so the gradient vector itself would be dLdv[1] — confirm.
dLdv = Zygote.gradient(loss_r, v1)
5 Likes

My god, such a tiny problem. Coming from Python it’s so easy to overlook. Thanks a lot for your help!

2 Likes