# Optimizing matrix-valued parameters of an ODE

I want to fit the matrix of an ODE of the form

$$\dot{u}(t) = p \cdot u(t), \quad p \in \mathrm{Mat}_{n \times n}(\mathbb{R}), \quad u \in \mathbb{R}^n$$

I start by defining the matrix:

using Plots, Flux, Zygote, DiffEqFlux, DiffEqSensitivity

A = [0 0 0 0;
     1 0 5 0;
     0 5 0 0;
     0 0 1 0];
A = (A' ./ (sum(A, dims = 1) .+ .05)')' - I

4×4 Matrix{Float64}:
-1.0        0.0        0.0        0.0
0.952381  -1.0        0.826446   0.0
0.0        0.990099  -1.0        0.0
0.0        0.0        0.165289  -1.0


Then I define the problem with f(u,p,t) = p \cdot u:

function f(u,p,t)
    du = p*u
end

u0 = ones(4)

tspan = (0.0, 50.0)
p = A

prob = ODEProblem(f, u0, tspan, p)


And solve it with the initial conditions u0 = ones(4):

sol = solve(prob, Tsit5())
plot(sol)


I save the solution for this A:

real_sol = sol(1:10);


And use it to define the loss function as the sum of squared differences between this solution and the one for a different p (that is, a different A):

function loss(p)
    sol = solve(prob, Tsit5(),
                p=p, saveat=0.1)

    sum(abs2.(sol(1:10) .- real_sol))
end
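Written out, the loss above compares the two trajectories at the integer times t = 1, ..., 10 that `sol(1:10)` interpolates. A sketch, where the notation u(t; p) for the solution with parameter matrix p is introduced here only for illustration and is not part of the original post:

```latex
L(p) = \sum_{k=1}^{10} \left\| u(k; p) - u(k; A) \right\|_2^2,
\qquad
u(t; p) = e^{t p}\, u_0, \quad u_0 = (1, 1, 1, 1)^\top .
```

The closed form on the right holds because p is constant in time; the numerical solve with `Tsit5()` approximates it.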


Checking that the loss function works:

Correct A gives small loss:

loss(A)

3.7152693264590607e-31


The loss seems to work (it scales with the ‘difference’ between the correct A and the given parameter):

loss(A .+ rand(4,4)*.1)

36.981767062772995

loss(A .+ rand(4,4)*.01)

0.18012156062961843


However, trying to optimize it gives me the following error:

p = A .+ rand(4,4)*.05
result_ode = DiffEqFlux.sciml_train(loss, p,
                                    maxiters = 2)

MethodError: no method matching +(::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, ::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
+(::ChainRulesCore.AbstractThunk, ::Any) at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:160
+(::ChainRulesCore.Composite{P, T} where T, ::P) where P at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:184
...

Stacktrace:

@ Base ./reduce.jl:24

[2] _mapreduce

@ ./reduce.jl:408 [inlined]

[3] _mapreduce_dim

@ ./reducedim.jl:318 [inlined]

[4] #mapreduce#672

@ ./reducedim.jl:310 [inlined]

[5] mapreduce

@ ./reducedim.jl:310 [inlined]

[6] #_sum#682

@ ./reducedim.jl:878 [inlined]

[7] _sum

@ ./reducedim.jl:878 [inlined]

[8] #_sum#681

@ ./reducedim.jl:877 [inlined]

[9] _sum

@ ./reducedim.jl:877 [inlined]

[10] #sum#679

@ ./reducedim.jl:873 [inlined]

[11] sum

@ ./reducedim.jl:873 [inlined]

@ ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:263 [inlined]

@ ./none:0 [inlined]

[14] _pullback(__context__::Zygote.Context, 533::typeof(sum), xs::Vector{Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}})

[15] _pullback

@ ./In[24]:6 [inlined]

[16] _pullback(ctx::Zygote.Context, f::typeof(loss), args::Matrix{Float64})

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0

[17] _pullback

@ ~/.julia/packages/DiffEqFlux/alPQ3/src/train.jl:3 [inlined]

[18] _pullback(::Zygote.Context, ::DiffEqFlux.var"#69#70"{typeof(loss)}, ::Matrix{Float64}, ::SciMLBase.NullParameters)

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0

[19] _apply

@ ./boot.jl:804 [inlined]

@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]

[21] _pullback

[22] _pullback

@ ~/.julia/packages/SciMLBase/EFFG1/src/problems/basic_problems.jl:107 [inlined]

[23] _pullback(::Zygote.Context, ::OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Matrix{Float64}, ::SciMLBase.NullParameters)

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0

[24] _apply(::Function, ::Vararg{Any, N} where N)

@ Core ./boot.jl:804

@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]

[26] _pullback

[27] _pullback

@ ~/.julia/packages/SciMLBase/EFFG1/src/problems/basic_problems.jl:107 [inlined]

[28] _pullback(::Zygote.Context, ::OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, ::Matrix{Float64}, ::SciMLBase.NullParameters)

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0

[29] _apply

@ ./boot.jl:804 [inlined]

@ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]

[31] _pullback

[32] _pullback

@ ~/.julia/packages/GalacticOptim/Ts0Bu/src/solve.jl:91 [inlined]

[33] _pullback(::Zygote.Context, ::GalacticOptim.var"#8#13"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, Matrix{Float64}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}}}, Matrix{Float64}, GalacticOptim.NullData})

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0

[34] pullback(f::Function, ps::Params)

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:247

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58

[36] __solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, Matrix{Float64}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})

@ GalacticOptim ~/.julia/packages/GalacticOptim/Ts0Bu/src/solve.jl:90

[37] #solve#468

@ ~/.julia/packages/SciMLBase/EFFG1/src/solve.jl:3 [inlined]

[38] sciml_train(::typeof(loss), ::Matrix{Float64}, ::ADAM, ::GalacticOptim.AutoZygote; lower_bounds::Nothing, upper_bounds::Nothing, kwargs::Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}})

@ DiffEqFlux ~/.julia/packages/DiffEqFlux/alPQ3/src/train.jl:6

[39] top-level scope

@ In[28]:2

[40] eval

@ ./boot.jl:360 [inlined]

[41] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)



du1 = Zygote.gradient(loss,p)

MethodError: no method matching +(::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, ::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
+(::ChainRulesCore.AbstractThunk, ::Any) at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:160
+(::ChainRulesCore.Composite{P, T} where T, ::P) where P at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:184
...

Stacktrace:

@ Base ./reduce.jl:24

[2] _mapreduce

@ ./reduce.jl:408 [inlined]

[3] _mapreduce_dim

@ ./reducedim.jl:318 [inlined]

[4] #mapreduce#672

@ ./reducedim.jl:310 [inlined]

[5] mapreduce

@ ./reducedim.jl:310 [inlined]

[6] #_sum#682

@ ./reducedim.jl:878 [inlined]

[7] _sum

@ ./reducedim.jl:878 [inlined]

[8] #_sum#681

@ ./reducedim.jl:877 [inlined]

[9] _sum

@ ./reducedim.jl:877 [inlined]

[10] #sum#679

@ ./reducedim.jl:873 [inlined]

[11] sum

@ ./reducedim.jl:873 [inlined]

@ ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:263 [inlined]

@ ./none:0 [inlined]

[14] _pullback(__context__::Zygote.Context, 533::typeof(sum), xs::Vector{Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}})

[15] _pullback

@ ./In[24]:6 [inlined]

[16] _pullback(ctx::Zygote.Context, f::typeof(loss), args::Matrix{Float64})

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0

[17] _pullback(f::Function, args::Matrix{Float64})

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:34

[18] pullback(f::Function, args::Matrix{Float64})

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:40

@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58

[20] top-level scope

@ In[34]:1

[21] eval

@ ./boot.jl:360 [inlined]

[22] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)

@ Base ./loading.jl:1094

You don’t need to interpolate in the loss function if you already did saveat. The issue is that it didn’t align with real_sol though, so I don’t quite get what you were doing there. But here’s an aligned version:

using Plots, Flux, Zygote, DiffEqFlux, DiffEqSensitivity
using LinearAlgebra, OrdinaryDiffEq
A = [0 0 0 0;
     1 0 5 0;
     0 5 0 0;
     0 0 1 0];
A = (A' ./ (sum(A, dims = 1) .+ .05)')' - I

function f(u,p,t)
    du = p*u
end

u0 = ones(4)

tspan = (0.0, 10.0)
p = A

prob = ODEProblem(f, u0, tspan, p)

sol = solve(prob, Tsit5(), saveat = 1.0)
plot(sol)
real_sol = Array(sol);
function loss(p)
    sol = solve(prob, Tsit5(),
                p=p, saveat=1.0)
    sum(abs2, sol .- real_sol)
end

p = A .+ rand(4,4)*.05
loss(p)
result_ode = DiffEqFlux.sciml_train(loss, p,
                                    maxiters = 2) # closing argument assumed from the question's call

saveat can take a vector.
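As a minimal sketch of that last point, the snippet below passes a vector of times to `saveat`, so the reference data and the loss are evaluated on the same grid without any interpolation. The times in `ts` are hypothetical and chosen only for illustration, and the name `A0` for the unnormalized matrix is introduced here; everything else follows the code above.

```julia
using OrdinaryDiffEq, LinearAlgebra

# Same normalized matrix and linear ODE as in the thread.
A0 = [0 0 0 0;
      1 0 5 0;
      0 5 0 0;
      0 0 1 0]
A = (A0' ./ (sum(A0, dims = 1) .+ 0.05)')' - I

f(u, p, t) = p * u
prob = ODEProblem(f, ones(4), (0.0, 10.0), A)

# Hypothetical, unevenly spaced observation times (not from the original post).
ts = [0.5, 1.0, 2.5, 7.0, 10.0]

# Passing the vector as `saveat` stores the solution at these times.
real_sol = Array(solve(prob, Tsit5(), saveat = ts))

function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = ts)
    # Both arrays are 4 x length(ts), so they can be subtracted directly.
    sum(abs2, Array(sol) .- real_sol)
end
```

Calling `loss(A)` then compares a solution with itself on the grid `ts`, and a perturbed matrix such as `A .+ 0.01` gives a positive loss.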