Different loss function for training an ODE vs. a DDE?

I am learning to use the DiffEqFlux package by working through the tutorial and experimenting, and I have run into a strange problem that I cannot figure out. In the tutorial, the model is trained by adjusting the parameters of the Lotka-Volterra system so that both populations stabilize at 1, using this loss function:

loss_rd() = sum(abs2, x-1 for x in predict_rd())
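
As a rough illustration of what this loss measures, here is a minimal pure-Julia sketch that applies the same expression to a plain 2×N matrix standing in for the solver output (rows are the two species, columns are the saved time points). The matrix `fake_pred` and the function name `constant_target_loss` are hypothetical; the real `predict_rd()` returns a solution object, not this matrix.

```julia
# Hypothetical stand-in for the solver output: 2 species × 4 saved time points.
fake_pred = [1.0 1.5 2.0 1.0;
             1.0 0.5 1.0 3.0]

# Same expression as the tutorial loss: the generator visits every entry
# (both species, all time points) and penalizes its squared distance from 1.
constant_target_loss(pred) = sum(abs2, x - 1 for x in pred)

println(constant_target_loss(fake_pred))  # prints 5.5
```

Because the target is the constant 1, the loss does not care about the shape of the trajectory over time, only about how far every entry is from 1.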

I wanted to train the model to fit oscillating populations instead, using the following code:

using DifferentialEquations, Plots, Flux, DiffEqFlux

function lotka_volterra(du,u,p,t)
    x, y = u
    α, β, δ, γ= p
    du[1] = dx = α*x - β*x*y
    du[2] = dy = -δ*y + γ*x*y
end

u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)

sol = solve(prob, Tsit5(), saveat=0.1)

A1 = sol[1,:]
A2 = sol[2,:]
t  = 0:0.1:10.0

p = [4.0,1.0,2.0,0.4]
params = Flux.params(p)

function predict_rd()
    solve(prob, Tsit5(), p=p, saveat=0.1)#[1,:]
end
#loss_rd() = sum(abs2, x-1 for x in predict_rd())
loss_rd() = sum(abs2, predict_rd() .- sol)

data = Iterators.repeated((), 1000)
opt = ADAM(0.1)
cb = function ()
    display(hcat( loss_rd(), params))
    scatter(t, A1, color=[1], label = "conejos")
    scatter!(t, A2, color=[2], label = "lobos")
    display(plot!(solve(remake(prob, p=p), Tsit5(),saveat=0.1), ylim=(0,7)))
end
cb()
Flux.train!(loss_rd, params, data, opt, cb = cb)

It works nicely. The problem appears when I try the same thing with a DDE; the code is almost identical:

using DifferentialEquations, Plots, Flux, DiffEqFlux, DiffEqSensitivity

function delay_lotka_volterra(du,u,h,p,t)
    x, y = u
    α, β, δ, γ= p
    du[1] = dx = (α - β*y)*h(p,t-0.1)[1]
    du[2] = dy = (-γ + δ*x)*y
end

u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
h(p, t) = ones(eltype(p),2)
prob = DDEProblem(delay_lotka_volterra,u0,h,tspan,constant_lags=[0.1])

data_sol = solve(prob,MethodOfSteps(Tsit5()), p=p, sensealg=TrackerAdjoint(), saveat = 0.1)
A1 = data_sol[1,:]
A2 = data_sol[2,:]
t  = 0:0.1:10.0

p = [2.2,1.0,2.0,0.4]
params = Flux.params(p)

function predict_rd_dde()
    solve(prob, MethodOfSteps(Tsit5()), p=p, sensealg=TrackerAdjoint(), saveat=0.1)#[1,:]
end
#loss_rd_dde() = sum(abs2, x-1 for x in predict_rd_dde())
loss_rd_dde() = sum(abs2, predict_rd_dde() .- data_sol)

data = Iterators.repeated((), 100)
opt = ADAM(0.1)
cb = function ()
    display(loss_rd_dde())
    scatter(t, A1, color=[1], label = "conejos")
    scatter!(t, A2, color=[2], label = "lobos")
    display(plot!(solve(remake(prob, p=p), MethodOfSteps(Tsit5()),saveat=0.1), ylim=(0,6)))
end
cb()
Flux.train!(loss_rd_dde, params, data, opt, cb = cb)

But it gives the following error:

ERROR: LoadError: MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
  Float64(::T) where T<:Number at boot.jl:716
  Float64(::Irrational{:sqrtπ}) at irrationals.jl:189
  ...
Stacktrace:
 [1] convert(::Type{Float64}, ::Tracker.TrackedReal{Float64}) at ./number.jl:7
 [2] setindex!(::Array{Float64,1}, ::Tracker.TrackedReal{Float64}, ::Int64) at ./array.jl:847
 [3] (::Tracker.var"#374#376"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},TrackedArray{…,Array{Float64,1}},Tuple{Int64}})(::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/lib/array.jl:105
 [4] back_(::Tracker.Grads, ::Tracker.Call{Tracker.var"#374#376"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},TrackedArray{…,Array{Float64,1}},Tuple{Int64}},Tuple{Tracker.Tracked{Array{Float64,1}},Nothing}}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:110
 [5] back(::Tracker.Grads, ::Tracker.Tracked{Float64}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:125
 [6] #16 at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:113 [inlined]
 [7] foreach at ./abstractarray.jl:2010 [inlined]
 [8] back_(::Tracker.Grads, ::Tracker.Call{Tracker.var"#201#202",Tuple{Tracker.Tracked{Float64}}}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:113
 ... (the last 4 lines are repeated 6 more times)
 [33] back(::Tracker.Grads, ::Tracker.Tracked{Float64}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:125
 [34] (::Tracker.var"#369#370"{Tracker.Grads})(::Tracker.Tracked{Float64}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/lib/real.jl:156
 [35] foreach(::Function, ::Array{Tracker.Tracked{Float64},2}, ::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at ./abstractarray.jl:2010
 [36] back_(::Tracker.Grads, ::Tracker.Call{typeof(Tracker.collect),Tuple{Array{Tracker.Tracked{Float64},2}}}, ::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/lib/real.jl:156
 [37] back(::Tracker.Grads, ::Tracker.Tracked{Array{Float64,2}}, ::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:125
 [38] #18 at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:140 [inlined]
 [39] (::Tracker.var"#21#23"{Tracker.var"#18#19"{Tracker.Params,TrackedArray{…,Array{Float64,2}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:149
 [40] (::DiffEqSensitivity.var"#tracker_adjoint_backpass#164"{Tuple{},Tracker.var"#21#23"{Tracker.var"#18#19"{Tracker.Params,TrackedArray{…,Array{Float64,2}}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/DiffEqSensitivity/ZdaQE/src/local_sensitivity/concrete_solve.jl:298
 [41] #694#back at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65 [inlined]
 [42] #175 at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [43] (::Zygote.var"#359#back#177"{Zygote.var"#175#176"{DiffEqBase.var"#694#back#474"{DiffEqSensitivity.var"#tracker_adjoint_backpass#164"{Tuple{},Tracker.var"#21#23"{Tracker.var"#18#19"{Tracker.Params,TrackedArray{…,Array{Float64,2}}}}}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [44] #solve#460 at /home/qualium/.julia/packages/DiffEqBase/V7P18/src/solve.jl:102 [inlined]
 [45] (::typeof(∂(#solve#460)))(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [46] (::Zygote.var"#175#176"{typeof(∂(#solve#460)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182
 [47] (::Zygote.var"#359#back#177"{Zygote.var"#175#176"{typeof(∂(#solve#460)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [48] (::typeof(∂(solve##kw)))(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [49] predict_rd_dde at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:25 [inlined]
 [50] (::typeof(∂(predict_rd_dde)))(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [51] loss_rd_dde at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:28 [inlined]
 [52] (::typeof(∂(loss_rd_dde)))(::Tracker.TrackedReal{Float64}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [53] #175 at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [54] #359#back at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [55] #17 at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
 [56] (::Zygote.var"#50#51"{Zygote.Params,Zygote.Context,typeof(∂(#17))})(::Tracker.TrackedReal{Float64}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:177
 [57] gradient(::Function, ::Zygote.Params) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:54
 [58] macro expansion at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
 [59] macro expansion at /home/qualium/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [60] train!(::typeof(loss_rd_dde), ::Zygote.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM; cb::var"#15#16") at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [61] top-level scope at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:39
 [62] include(::String) at ./client.jl:457
in expression starting at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:39

I’m pretty much a beginner in Julia, so I don’t understand the error messages very well, but I checked the types of the variables and they are the same: the numbers are all Float64, etc.
I suppose it has to do with the new package DiffEqSensitivity or with how it works with TrackerAdjoint, but I can’t make sense of any of this.

Any thoughts? Thank you

Is there any reason you used TrackerAdjoint instead of ReverseDiffAdjoint? ReverseDiffAdjoint works well here, so I think it’s a Tracker issue.

To be honest, I was just copying the tutorial examples and changing bits to understand how it works. I’ve read the AD page, but I’m too much of a newbie to understand anything beyond the fact that there are different ways to differentiate; I’ll have to dig deeper into that.

Anyway, changing sensealg to ReverseDiffAdjoint() gives the following error:

ERROR: LoadError: MethodError: no method matching increment_deriv!(::ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::Array{Float64,1})
Closest candidates are:
  increment_deriv!(::ReverseDiff.TrackedArray, ::AbstractArray, ::Any) at /home/qualium/.julia/packages/ReverseDiff/vScHI/src/derivatives/propagation.jl:33
  increment_deriv!(::AbstractArray, ::AbstractArray, ::Any) at /home/qualium/.julia/packages/ReverseDiff/vScHI/src/derivatives/propagation.jl:35
  increment_deriv!(::AbstractArray, ::Any) at /home/qualium/.julia/packages/ReverseDiff/vScHI/src/derivatives/propagation.jl:38
  ...
Stacktrace:
 [1] increment_deriv! at /home/qualium/.julia/packages/ReverseDiff/vScHI/src/derivatives/propagation.jl:35 [inlined]
 [2] increment_deriv!(::Array{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},2}, ::VectorOfArray{Float64,2,Array{Array{Float64,1},1}}) at /home/qualium/.julia/packages/ReverseDiff/vScHI/src/derivatives/propagation.jl:40
 [3] reversediff_adjoint_backpass at /home/qualium/.julia/packages/DiffEqSensitivity/ZdaQE/src/local_sensitivity/concrete_solve.jl:354 [inlined]
 [4] #694#back at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65 [inlined]
 [5] #175 at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [6] (::Zygote.var"#359#back#177"{Zygote.var"#175#176"{DiffEqBase.var"#694#back#474"{DiffEqSensitivity.var"#reversediff_adjoint_backpass#171"{DDEProblem{Array{Float64,1},Tuple{Float64,Float64},Array{Float64,1},Tuple{},true,DiffEqBase.NullParameters,DDEFunction{true,typeof(delay_lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(h),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},Tuple{},ReverseDiff.GradientTape{DiffEqSensitivity.var"#reversediff_adjoint_forwardpass#168"{Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:saveat,),Tuple{Float64}}},DDEProblem{Array{Float64,1},Tuple{Float64,Float64},Array{Float64,1},Tuple{},true,DiffEqBase.NullParameters,DDEFunction{true,typeof(delay_lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(h),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},MethodOfSteps{Tsit5,NLFunctional{Rational{Int64},Rational{Int64}},false},Tuple{}},Tuple{ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},Array{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},2}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}},Array{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},2}}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::VectorOfArray{Float64,2,Array{Array{Float64,1},1}}) at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [7] #solve#460 at /home/qualium/.julia/packages/DiffEqBase/V7P18/src/solve.jl:102 [inlined]
 [8] (::typeof(∂(#solve#460)))(::VectorOfArray{Float64,2,Array{Array{Float64,1},1}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [9] (::Zygote.var"#175#176"{typeof(∂(#solve#460)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::VectorOfArray{Float64,2,Array{Array{Float64,1},1}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182
 [10] (::Zygote.var"#359#back#177"{Zygote.var"#175#176"{typeof(∂(#solve#460)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::VectorOfArray{Float64,2,Array{Array{Float64,1},1}}) at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [11] (::typeof(∂(solve##kw)))(::VectorOfArray{Float64,2,Array{Array{Float64,1},1}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [12] predict_rd_dde at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:25 [inlined]
 [13] (::typeof(∂(predict_rd_dde)))(::VectorOfArray{Float64,2,Array{Array{Float64,1},1}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [14] loss_rd_dde at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:28 [inlined]
 [15] (::typeof(∂(loss_rd_dde)))(::Float64) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [16] #175 at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
 [17] #359#back at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [18] #17 at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
 [19] (::Zygote.var"#50#51"{Zygote.Params,Zygote.Context,typeof(∂(#17))})(::Float64) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:177
 [20] gradient(::Function, ::Zygote.Params) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:54
 [21] macro expansion at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
 [22] macro expansion at /home/qualium/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [23] train!(::typeof(loss_rd_dde), ::Zygote.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM; cb::var"#19#20") at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [24] top-level scope at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:39
 [25] include(::String) at ./client.jl:457
 [26] top-level scope at REPL[2]:1
in expression starting at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:39

The only thing I can think of is that the data used to calculate the loss should be given in another format. I should also note that if I keep only one of the outputs of solve, like

solve(prob, MethodOfSteps(Tsit5()), p=p, sensealg=ReverseDiffAdjoint(), saveat=0.1)[1,:]

and the loss function is given by

loss_rd_dde() = sum(abs2, x-1 for x in predict_rd_dde())

it works perfectly with both TrackerAdjoint and ReverseDiffAdjoint. That case is just not very interesting, because it only asks a population to stabilize at 1; that’s why I tried to match a population that changes over time instead. But that doesn’t work, with either a DDE or an SDE.
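
The difference between the two objectives can be seen on plain arrays. In the sketch below, `pred` and `target` are made-up numbers (not solver output) and the function names are hypothetical: slicing `[1, :]` keeps only the first species, as `sol[1,:]` does, and the trajectory-matching loss compares that series point by point with a time-varying target instead of with the constant 1.

```julia
# Hypothetical prediction and target data: 2 species × 5 time points.
pred   = [1.0 1.2 1.4 1.2 1.0;
          1.0 0.8 0.6 0.8 1.0]
target = [1.0 1.3 1.4 1.1 1.0;
          1.0 0.9 0.6 0.7 1.0]

x = pred[1, :]   # first species only, analogous to sol[1,:]

# Constant target: squared distance of the population from 1 at every time point.
loss_const(v) = sum(abs2, xi - 1 for xi in v)

# Trajectory matching: squared error against a time-varying target series.
loss_traj(v, d) = sum(abs2, v .- d)

println(loss_const(x))
println(loss_traj(x, target[1, :]))
```

The trajectory loss is zero only when the prediction reproduces the target at every saved time point, which is what fitting an oscillating population requires.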

It seems that converting data_sol to an Array solves the problem. I don’t understand why the real solution and the prediction need to be of different types, but it works.
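
One plausible reading of why the conversion helps, based only on the second stack trace: the saved solution is stored as a vector of state vectors (`VectorOfArray{Float64,2,Array{Array{Float64,1},1}}` in the trace), and the first line of that trace shows `increment_deriv!` being called with a single tracked number and a whole `Array{Float64,1}`. That is what you get if each scalar is paired with an element of a vector of state vectors. In the DiffEq code the conversion would be something like `Array(data_sol)`, which yields a 2×N matrix. The pure-Julia sketch below only mimics the data layout; `saved_states`, `data`, `pred` and `loss` are hypothetical names with made-up values.

```julia
# Hypothetical saved states: one 2-element state vector per saved time point,
# i.e. the same nesting as Array{Array{Float64,1},1} inside the VectorOfArray.
saved_states = [[1.0, 1.0], [1.2, 0.9], [1.4, 0.7]]

# Iterating the nested container yields whole state vectors, not numbers.
first_item = first(saved_states)

# Convert to a dense 2×N matrix (rows = species, columns = time points),
# analogous to what Array(data_sol) gives for a DiffEq solution.
data = reduce(hcat, saved_states)

# Iterating a Matrix yields scalars, so an elementwise loss is well defined.
pred = [1.0 1.1 1.5;
        1.0 1.0 0.7]
loss = sum(abs2, pred .- data)

println(typeof(first_item), " vs ", eltype(data))
println(size(data))
```

With both arguments as ordinary arrays, the incoming gradient of the loss is a plain matrix of numbers, which matches the reply below that newer releases handle the original case directly.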

You’re just on versions from before that was fixed. ReverseDiffAdjoint works fine here on the latest releases.