Basically, i want to fit the matrix of an ODE of the form
I start by defining the matrxi:
using Plots, Flux, Zygote, Flux, DiffEqFlux, DiffEqSensitivity, Plots
A = [ 0 0 0 0;
1 0 5 0;
0 5 0 0;
0 0 1 0];
A = (A' ./ (sum(A,dims = 1) .+ .05)')' - I
4×4 Matrix{Float64}:
-1.0 0.0 0.0 0.0
0.952381 -1.0 0.826446 0.0
0.0 0.990099 -1.0 0.0
0.0 0.0 0.165289 -1.0
Then define the problem with f(u,t,p) = p \cdot u.
function f(u,p,t)
du = p*u
end
u0 = ones(4)
tspan = (0.0, 50.0)
p = A
prob = ODEProblem(f, u0, tspan, p)
And solve it with the initial conditions u0 = ones(4)
:
sol = solve(prob, Tsit5())
plot(sol)
I save the solution for this A
:
real_sol = sol(1:10);
And use it to define the loss function as the difference between this solution and the one for a different p/A
function loss(u0)
sol = solve(prob,Tsit5(),
sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()),
p=p,saveat=0.1)
sum(abs2.(sol(1:10) .- real_sol))
end
Try the loss function if it works:
Correct A gives small loss:
loss(A)
3.7152693264590607e-31
Loss seems to works (scales with the ‘difference’ between correct A
and given parameter:
loss(A .+ rand(4,4)*.1)
36.981767062772995
loss(A .+ rand(4,4)*.01)
0.18012156062961843
However, trying to optimize it gives me the following error:
p = A .+ rand(4,4)*.05
result_ode = DiffEqFlux.sciml_train(loss,p,
ADAM(0.1),
maxiters = 2)
MethodError: no method matching +(::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, ::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
+(::ChainRulesCore.AbstractThunk, ::Any) at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:160
+(::ChainRulesCore.Composite{P, T} where T, ::P) where P at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:184
...
Stacktrace:
[1] add_sum(x::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, y::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
@ Base ./reduce.jl:24
[2] _mapreduce
@ ./reduce.jl:408 [inlined]
[3] _mapreduce_dim
@ ./reducedim.jl:318 [inlined]
[4] #mapreduce#672
@ ./reducedim.jl:310 [inlined]
[5] mapreduce
@ ./reducedim.jl:310 [inlined]
[6] #_sum#682
@ ./reducedim.jl:878 [inlined]
[7] _sum
@ ./reducedim.jl:878 [inlined]
[8] #_sum#681
@ ./reducedim.jl:877 [inlined]
[9] _sum
@ ./reducedim.jl:877 [inlined]
[10] #sum#679
@ ./reducedim.jl:873 [inlined]
[11] sum
@ ./reducedim.jl:873 [inlined]
[12] #adjoint#598
@ ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:263 [inlined]
[13] adjoint
@ ./none:0 [inlined]
[14] _pullback(__context__::Zygote.Context, 533::typeof(sum), xs::Vector{Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57
[15] _pullback
@ ./In[24]:6 [inlined]
[16] _pullback(ctx::Zygote.Context, f::typeof(loss), args::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[17] _pullback
@ ~/.julia/packages/DiffEqFlux/alPQ3/src/train.jl:3 [inlined]
[18] _pullback(::Zygote.Context, ::DiffEqFlux.var"#69#70"{typeof(loss)}, ::Matrix{Float64}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[19] _apply
@ ./boot.jl:804 [inlined]
[20] adjoint
@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
[21] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[22] _pullback
@ ~/.julia/packages/SciMLBase/EFFG1/src/problems/basic_problems.jl:107 [inlined]
[23] _pullback(::Zygote.Context, ::OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Matrix{Float64}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[24] _apply(::Function, ::Vararg{Any, N} where N)
@ Core ./boot.jl:804
[25] adjoint
@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
[26] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[27] _pullback
@ ~/.julia/packages/SciMLBase/EFFG1/src/problems/basic_problems.jl:107 [inlined]
[28] _pullback(::Zygote.Context, ::OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, ::Matrix{Float64}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[29] _apply
@ ./boot.jl:804 [inlined]
[30] adjoint
@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
[31] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[32] _pullback
@ ~/.julia/packages/GalacticOptim/Ts0Bu/src/solve.jl:91 [inlined]
[33] _pullback(::Zygote.Context, ::GalacticOptim.var"#8#13"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, Matrix{Float64}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}}}, Matrix{Float64}, GalacticOptim.NullData})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[34] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:247
[35] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
[36] __solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, Matrix{Float64}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ GalacticOptim ~/.julia/packages/GalacticOptim/Ts0Bu/src/solve.jl:90
[37] #solve#468
@ ~/.julia/packages/SciMLBase/EFFG1/src/solve.jl:3 [inlined]
[38] sciml_train(::typeof(loss), ::Matrix{Float64}, ::ADAM, ::GalacticOptim.AutoZygote; lower_bounds::Nothing, upper_bounds::Nothing, kwargs::Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}})
@ DiffEqFlux ~/.julia/packages/DiffEqFlux/alPQ3/src/train.jl:6
[39] top-level scope
@ In[28]:2
[40] eval
@ ./boot.jl:360 [inlined]
[41] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1094
Cant even calculate the gradient:
du1 = Zygote.gradient(loss,p)
MethodError: no method matching +(::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, ::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
+(::ChainRulesCore.AbstractThunk, ::Any) at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:160
+(::ChainRulesCore.Composite{P, T} where T, ::P) where P at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:184
...
Stacktrace:
[1] add_sum(x::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, y::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
@ Base ./reduce.jl:24
[2] _mapreduce
@ ./reduce.jl:408 [inlined]
[3] _mapreduce_dim
@ ./reducedim.jl:318 [inlined]
[4] #mapreduce#672
@ ./reducedim.jl:310 [inlined]
[5] mapreduce
@ ./reducedim.jl:310 [inlined]
[6] #_sum#682
@ ./reducedim.jl:878 [inlined]
[7] _sum
@ ./reducedim.jl:878 [inlined]
[8] #_sum#681
@ ./reducedim.jl:877 [inlined]
[9] _sum
@ ./reducedim.jl:877 [inlined]
[10] #sum#679
@ ./reducedim.jl:873 [inlined]
[11] sum
@ ./reducedim.jl:873 [inlined]
[12] #adjoint#598
@ ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:263 [inlined]
[13] adjoint
@ ./none:0 [inlined]
[14] _pullback(__context__::Zygote.Context, 533::typeof(sum), xs::Vector{Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57
[15] _pullback
@ ./In[24]:6 [inlined]
[16] _pullback(ctx::Zygote.Context, f::typeof(loss), args::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[17] _pullback(f::Function, args::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:34
[18] pullback(f::Function, args::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:40
[19] gradient(f::Function, args::Matrix{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
[20] top-level scope
@ In[34]:1
[21] eval
@ ./boot.jl:360 [inlined]
[22] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1094