I am learning to use the DiffEqFlux package by following the tutorial and experimenting with it, and I have run into a strange problem that I cannot figure out. In the tutorial, the model is trained by adjusting the parameters of the Lotka–Volterra system so that both populations stabilize at 1, using this loss function:
# Tutorial loss: sum of squared deviations of the values returned by predict_rd() from 1 (the target level described above).
loss_rd() = sum(abs2, x-1 for x in predict_rd())
I wanted to instead train the model to fit oscillating populations, using the following code:
using DifferentialEquations, Plots, Flux, DiffEqFlux
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra predator–prey ODE.

`u = [x, y]` holds the prey (`x`, plotted as "conejos") and predator
(`y`, plotted as "lobos") populations, and `p = [α, β, δ, γ]` the parameters.
The derivatives are written into `du`:

    dx/dt = α*x - β*x*y
    dy/dt = -δ*y + γ*x*y

`t` is unused because the system is autonomous. Returns `nothing`: in-place
right-hand sides communicate only through `du`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end
# Initial populations [prey, predator] and integration window.
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
# "True" parameters [α, β, δ, γ] used to generate the reference data.
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra,u0,tspan,p)
# Reference trajectory saved every 0.1 time units; this is the training target.
sol = solve(prob, Tsit5(), saveat=0.1)
A1 = sol[1,:]   # prey series (plotted as "conejos")
A2 = sol[2,:]   # predator series (plotted as "lobos")
t = 0:0.1:10.0  # time grid matching saveat=0.1 over tspan
# Rebind `p` to the initial guess that training will adjust. `prob` still holds the
# original parameter array, so predictions must pass `p=p` explicitly.
p = [4.0,1.0,2.0,0.4]
# Register the array `p` as the trainable parameters for Flux.train!.
params = Flux.params(p)
# Solve the ODE with the current trainable parameters `p` (overriding those stored
# in `prob`), saving on the same 0.1 grid as the reference solution `sol`.
predict_rd() = solve(prob, Tsit5(); p = p, saveat = 0.1)
# Squared-error loss between the prediction and the reference trajectory.
loss_rd() = sum(abs2, predict_rd() .- sol)
# 1000 training steps over an empty data tuple: each step just evaluates loss_rd().
data = Iterators.take(Iterators.repeated(()), 1000)
opt = ADAM(0.1)
# Callback: print the loss next to the current parameters, then overlay the
# current model solution on the reference data points.
cb = function ()
    display(hcat(loss_rd(), params))
    scatter(t, A1; color = [1], label = "conejos")
    scatter!(t, A2; color = [2], label = "lobos")
    current = solve(remake(prob; p = p), Tsit5(); saveat = 0.1)
    display(plot!(current; ylim = (0, 7)))
end
cb()  # show the initial guess before training starts
Flux.train!(loss_rd, params, data, opt; cb = cb)
This works well. The problem appears when I try the same thing with a DDE (delay differential equation); the code is almost identical:
using DifferentialEquations, Plots, Flux, DiffEqFlux, DiffEqSensitivity
"""
    delay_lotka_volterra(du, u, h, p, t)

In-place right-hand side of a delayed Lotka–Volterra model for `DDEProblem`.

`u = [x, y]` holds the prey and predator populations, `p = [α, β, δ, γ]` the
parameters, and `h(p, t)` the history/interpolation function. The prey growth
term uses the prey population at the lagged time `t - 0.1`:

    dx/dt = (α - β*y) * x(t - 0.1)
    dy/dt = (-γ + δ*x) * y

Here `γ` is the predator death rate and `δ` scales the prey–predator
interaction. The lag `0.1` must match the `constant_lags` passed to
`DDEProblem`. Returns `nothing`: results are written into `du`.
"""
function delay_lotka_volterra(du, u, h, p, t)
    x, y = u
    α, β, δ, γ = p
    x_lag = h(p, t - 0.1)[1]  # prey population one lag (0.1) in the past
    du[1] = (α - β * y) * x_lag
    du[2] = (-γ + δ * x) * y
    return nothing
end
# Initial populations [prey, predator] and integration window.
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
# "True" parameters [α, β, δ, γ] used to generate the reference data.
p = [1.5, 1.0, 3.0, 1.0]
# Constant history: both populations equal 1 before t = 0. The element type follows
# `p`, presumably so the history matches the parameters' number type during AD.
h(p, t) = ones(eltype(p),2)
# No parameters are stored in the problem; every solve passes `p=p` explicitly.
# The lag here must match the `t-0.1` used inside delay_lotka_volterra.
prob = DDEProblem(delay_lotka_volterra,u0,h,tspan,constant_lags=[0.1])
# Reference data generated with the true parameters on a 0.1 grid.
# NOTE(review): `sensealg` presumably only matters when a solve is differentiated,
# which this one is not; confirm it can be dropped here.
data_sol = solve(prob,MethodOfSteps(Tsit5()), p=p, sensealg=TrackerAdjoint(), saveat = 0.1)
A1 = data_sol[1,:]   # prey series (plotted as "conejos")
A2 = data_sol[2,:]   # predator series (plotted as "lobos")
t = 0:0.1:10.0       # time grid matching saveat = 0.1 over tspan
# Rebind `p` to the initial guess that training will adjust; `data_sol` is
# unaffected because it was already computed with the previous array.
p = [2.2,1.0,2.0,0.4]
# Register the array `p` as the trainable parameters for Flux.train!.
params = Flux.params(p)
# Solve the DDE with the current trainable parameters `p`, saving on the same
# 0.1 grid as `data_sol`. TrackerAdjoint is the sensitivity method for the DDE solve.
function predict_rd_dde()
    solve(prob, MethodOfSteps(Tsit5()), p=p, sensealg=TrackerAdjoint(), saveat=0.1)
end

# Squared-error loss between the prediction and the reference data.
# Both solutions are converted to plain matrices with `Array` before subtracting,
# as in the DiffEqFlux DDE example, instead of broadcasting over solution objects.
# NOTE(review): the reported MethodError shows a solution-shaped tracked cotangent
# (TrackedArray of VectorOfArray{TrackedReal}) reaching the TrackerAdjoint backward
# pass. Confirm that this conversion resolves it with the installed package versions;
# if it does not, try another `sensealg` in predict_rd_dde (e.g. ReverseDiffAdjoint()).
loss_rd_dde() = sum(abs2, Array(predict_rd_dde()) .- Array(data_sol))
# 100 training steps over an empty data tuple: each step just evaluates loss_rd_dde().
data = Iterators.take(Iterators.repeated(()), 100)
opt = ADAM(0.1)
# Callback: print the current loss, then overlay the current model solution on
# the reference data points.
cb = function ()
    display(loss_rd_dde())
    scatter(t, A1; color = [1], label = "conejos")
    scatter!(t, A2; color = [2], label = "lobos")
    current = solve(remake(prob; p = p), MethodOfSteps(Tsit5()); saveat = 0.1)
    display(plot!(current; ylim = (0, 6)))
end
cb()  # show the initial guess before training starts
Flux.train!(loss_rd_dde, params, data, opt; cb = cb)
But training fails with the following error:
ERROR: LoadError: MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
Float64(::T) where T<:Number at boot.jl:716
Float64(::Irrational{:sqrtπ}) at irrationals.jl:189
...
Stacktrace:
[1] convert(::Type{Float64}, ::Tracker.TrackedReal{Float64}) at ./number.jl:7
[2] setindex!(::Array{Float64,1}, ::Tracker.TrackedReal{Float64}, ::Int64) at ./array.jl:847
[3] (::Tracker.var"#374#376"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},TrackedArray{…,Array{Float64,1}},Tuple{Int64}})(::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/lib/array.jl:105
[4] back_(::Tracker.Grads, ::Tracker.Call{Tracker.var"#374#376"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},TrackedArray{…,Array{Float64,1}},Tuple{Int64}},Tuple{Tracker.Tracked{Array{Float64,1}},Nothing}}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:110
[5] back(::Tracker.Grads, ::Tracker.Tracked{Float64}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:125
[6] #16 at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:113 [inlined]
[7] foreach at ./abstractarray.jl:2010 [inlined]
[8] back_(::Tracker.Grads, ::Tracker.Call{Tracker.var"#201#202",Tuple{Tracker.Tracked{Float64}}}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:113
... (the last 4 lines are repeated 6 more times)
[33] back(::Tracker.Grads, ::Tracker.Tracked{Float64}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:125
[34] (::Tracker.var"#369#370"{Tracker.Grads})(::Tracker.Tracked{Float64}, ::Tracker.TrackedReal{Tracker.TrackedReal{Float64}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/lib/real.jl:156
[35] foreach(::Function, ::Array{Tracker.Tracked{Float64},2}, ::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at ./abstractarray.jl:2010
[36] back_(::Tracker.Grads, ::Tracker.Call{typeof(Tracker.collect),Tuple{Array{Tracker.Tracked{Float64},2}}}, ::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/lib/real.jl:156
[37] back(::Tracker.Grads, ::Tracker.Tracked{Array{Float64,2}}, ::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:125
[38] #18 at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:140 [inlined]
[39] (::Tracker.var"#21#23"{Tracker.var"#18#19"{Tracker.Params,TrackedArray{…,Array{Float64,2}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Tracker/OuWUu/src/back.jl:149
[40] (::DiffEqSensitivity.var"#tracker_adjoint_backpass#164"{Tuple{},Tracker.var"#21#23"{Tracker.var"#18#19"{Tracker.Params,TrackedArray{…,Array{Float64,2}}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/DiffEqSensitivity/ZdaQE/src/local_sensitivity/concrete_solve.jl:298
[41] #694#back at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65 [inlined]
[42] #175 at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
[43] (::Zygote.var"#359#back#177"{Zygote.var"#175#176"{DiffEqBase.var"#694#back#474"{DiffEqSensitivity.var"#tracker_adjoint_backpass#164"{Tuple{},Tracker.var"#21#23"{Tracker.var"#18#19"{Tracker.Params,TrackedArray{…,Array{Float64,2}}}}}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[44] #solve#460 at /home/qualium/.julia/packages/DiffEqBase/V7P18/src/solve.jl:102 [inlined]
[45] (::typeof(∂(#solve#460)))(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[46] (::Zygote.var"#175#176"{typeof(∂(#solve#460)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182
[47] (::Zygote.var"#359#back#177"{Zygote.var"#175#176"{typeof(∂(#solve#460)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[48] (::typeof(∂(solve##kw)))(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[49] predict_rd_dde at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:25 [inlined]
[50] (::typeof(∂(predict_rd_dde)))(::TrackedArray{…,VectorOfArray{Tracker.TrackedReal{Float64},2,Array{Array{Tracker.TrackedReal{Float64},1},1}}}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[51] loss_rd_dde at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:28 [inlined]
[52] (::typeof(∂(loss_rd_dde)))(::Tracker.TrackedReal{Float64}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
[53] #175 at /home/qualium/.julia/packages/Zygote/1GXzF/src/lib/lib.jl:182 [inlined]
[54] #359#back at /home/qualium/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[55] #17 at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
[56] (::Zygote.var"#50#51"{Zygote.Params,Zygote.Context,typeof(∂(#17))})(::Tracker.TrackedReal{Float64}) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:177
[57] gradient(::Function, ::Zygote.Params) at /home/qualium/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:54
[58] macro expansion at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
[59] macro expansion at /home/qualium/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[60] train!(::typeof(loss_rd_dde), ::Zygote.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM; cb::var"#15#16") at /home/qualium/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
[61] top-level scope at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:39
[62] include(::String) at ./client.jl:457
in expression starting at /home/qualium/Documentos/julia_things/lotka_volterra_DDE_neural.jl:39
I’m almost a beginner in Julia, so I don’t understand the error messages very well, but I checked the types of the variables with `typeof` and they are the same in both scripts: the numbers are all Float64, etc.
I suppose it has to do with the newly added package DiffEqSensitivity, or with how it works together with TrackerAdjoint, but I can’t make sense of it.
Any thoughts? Thank you