Optimizing matrix-valued parameters of an ODE

Basically, I want to fit the matrix p of a linear ODE of the form

\dot{u}(t) = p \cdot u(t), \quad p \in \mathrm{Mat}_{n\times n}(\mathbb{R}), \quad u \in \mathbb{R}^n
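
Since p is constant in time, the exact solution is just the matrix exponential applied to the initial condition,

u(t) = e^{t\,p}\, u_0,

so fitting p amounts to finding a matrix whose exponential reproduces the observed trajectories.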

I start by defining the matrix:

using Plots, Flux, Zygote, DiffEqFlux, DiffEqSensitivity
using LinearAlgebra, OrdinaryDiffEq

A = [  0   0   0   0;
       1   0   5   0;
       0   5   0   0;
       0   0   1   0];
A = (A' ./ (sum(A,dims = 1) .+ .05)')' - I
4×4 Matrix{Float64}:
 -1.0        0.0        0.0        0.0
  0.952381  -1.0        0.826446   0.0
  0.0        0.990099  -1.0        0.0
  0.0        0.0        0.165289  -1.0
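
(As a quick sanity check of the normalization, and only my reading of it: each column of the original A is scaled so it sums to a bit less than 1 before the identity is subtracted, so the columns of the final A sum to small negative numbers, and the empty last column sums to -1.)

sum(A, dims = 1)   # ≈ [-0.048  -0.0099  -0.0083  -1.0]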

Then I define the problem with f(u, p, t) = p \cdot u.

function f(u, p, t)
    return p * u   # out-of-place right-hand side: du/dt = p*u
end
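
(I kept the out-of-place form because the Zygote-based vjp used later generally prefers it, but for plain forward solves an in-place right-hand side avoids allocating on every call. A sketch, using mul! from LinearAlgebra:)

function f!(du, u, p, t)
    mul!(du, p, u)   # writes p*u into du in place
end
# prob_ip = ODEProblem(f!, u0, tspan, p) would then use the in-place form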

u0 = ones(4)

tspan = (0.0, 50.0)
p = A

prob = ODEProblem(f, u0, tspan, p)

And solve it with the initial conditions u0 = ones(4):

sol = solve(prob, Tsit5())
plot(sol)

I save the solution for this A:

real_sol = sol(1:10);
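
(sol(1:10) evaluates the dense interpolant at t = 1, 2, …, 10 and returns a DiffEqArray; if you want a plain matrix to compare against instead, the conversion would look roughly like this:)

real_sol = Array(sol(1:10))   # 4×10 Matrix, column k is u(k)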

And I use it to define the loss function as the squared difference between this reference solution and the solution for a given p:

function loss(p)
    sol = solve(prob, Tsit5(),
            sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()),
            p = p, saveat = 0.1)

    sum(abs2.(sol(1:10) .- real_sol))
end

Let me check whether the loss function works.

The correct A gives a small loss:

loss(A)
3.7152693264590607e-31

The loss seems to work (it scales with the ‘difference’ between the correct A and the given parameter):

loss(A .+ rand(4,4)*.1)
36.981767062772995
loss(A .+ rand(4,4)*.01)
0.18012156062961843

However, trying to optimize it gives me the following error:

p = A .+ rand(4,4)*.05
result_ode = DiffEqFlux.sciml_train(loss,p,
                                    ADAM(0.1),
                                    maxiters = 2)
MethodError: no method matching +(::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, ::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  +(::ChainRulesCore.AbstractThunk, ::Any) at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:160
  +(::ChainRulesCore.Composite{P, T} where T, ::P) where P at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:184
  ...



Stacktrace:
  [1] add_sum(x::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, y::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
    @ Base ./reduce.jl:24
  [2] _mapreduce
    @ ./reduce.jl:408 [inlined]
  [3] _mapreduce_dim
    @ ./reducedim.jl:318 [inlined]
  [4] #mapreduce#672
    @ ./reducedim.jl:310 [inlined]
  [5] mapreduce
    @ ./reducedim.jl:310 [inlined]
  [6] #_sum#682
    @ ./reducedim.jl:878 [inlined]
  [7] _sum
    @ ./reducedim.jl:878 [inlined]
  [8] #_sum#681
    @ ./reducedim.jl:877 [inlined]
  [9] _sum
    @ ./reducedim.jl:877 [inlined]
 [10] #sum#679
    @ ./reducedim.jl:873 [inlined]
 [11] sum
    @ ./reducedim.jl:873 [inlined]
 [12] #adjoint#598
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:263 [inlined]
 [13] adjoint
    @ ./none:0 [inlined]
 [14] _pullback(__context__::Zygote.Context, 533::typeof(sum), xs::Vector{Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57
 [15] _pullback
    @ ./In[24]:6 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(loss), args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] _pullback
    @ ~/.julia/packages/DiffEqFlux/alPQ3/src/train.jl:3 [inlined]
 [18] _pullback(::Zygote.Context, ::DiffEqFlux.var"#69#70"{typeof(loss)}, ::Matrix{Float64}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [19] _apply
    @ ./boot.jl:804 [inlined]
 [20] adjoint
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [22] _pullback
    @ ~/.julia/packages/SciMLBase/EFFG1/src/problems/basic_problems.jl:107 [inlined]
 [23] _pullback(::Zygote.Context, ::OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Matrix{Float64}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [24] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [25] adjoint
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
 [26] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [27] _pullback
    @ ~/.julia/packages/SciMLBase/EFFG1/src/problems/basic_problems.jl:107 [inlined]
 [28] _pullback(::Zygote.Context, ::OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, ::Matrix{Float64}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [29] _apply
    @ ./boot.jl:804 [inlined]
 [30] adjoint
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:191 [inlined]
 [31] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [32] _pullback
    @ ~/.julia/packages/GalacticOptim/Ts0Bu/src/solve.jl:91 [inlined]
 [33] _pullback(::Zygote.Context, ::GalacticOptim.var"#8#13"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, Matrix{Float64}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}}}, Matrix{Float64}, GalacticOptim.NullData})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [34] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:247
 [35] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
 [36] __solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#148#158"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#151#161"{GalacticOptim.var"#147#157"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#69#70"{typeof(loss)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#156#166", Nothing, Nothing, Nothing}, Matrix{Float64}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/Ts0Bu/src/solve.jl:90
 [37] #solve#468
    @ ~/.julia/packages/SciMLBase/EFFG1/src/solve.jl:3 [inlined]
 [38] sciml_train(::typeof(loss), ::Matrix{Float64}, ::ADAM, ::GalacticOptim.AutoZygote; lower_bounds::Nothing, upper_bounds::Nothing, kwargs::Base.Iterators.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:maxiters,), Tuple{Int64}}})
    @ DiffEqFlux ~/.julia/packages/DiffEqFlux/alPQ3/src/train.jl:6
 [39] top-level scope
    @ In[28]:2
 [40] eval
    @ ./boot.jl:360 [inlined]
 [41] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1094

I can't even compute the gradient directly:

du1 = Zygote.gradient(loss,p)
MethodError: no method matching +(::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, ::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  +(::ChainRulesCore.AbstractThunk, ::Any) at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:160
  +(::ChainRulesCore.Composite{P, T} where T, ::P) where P at /Users/hurt0jan/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:184
  ...



Stacktrace:
  [1] add_sum(x::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}, y::Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}})
    @ Base ./reduce.jl:24
  [2] _mapreduce
    @ ./reduce.jl:408 [inlined]
  [3] _mapreduce_dim
    @ ./reducedim.jl:318 [inlined]
  [4] #mapreduce#672
    @ ./reducedim.jl:310 [inlined]
  [5] mapreduce
    @ ./reducedim.jl:310 [inlined]
  [6] #_sum#682
    @ ./reducedim.jl:878 [inlined]
  [7] _sum
    @ ./reducedim.jl:878 [inlined]
  [8] #_sum#681
    @ ./reducedim.jl:877 [inlined]
  [9] _sum
    @ ./reducedim.jl:877 [inlined]
 [10] #sum#679
    @ ./reducedim.jl:873 [inlined]
 [11] sum
    @ ./reducedim.jl:873 [inlined]
 [12] #adjoint#598
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:263 [inlined]
 [13] adjoint
    @ ./none:0 [inlined]
 [14] _pullback(__context__::Zygote.Context, 533::typeof(sum), xs::Vector{Tuple{Float64, Zygote.var"#1844#back#248"{Zygote.var"#246#247"{Float64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57
 [15] _pullback
    @ ./In[24]:6 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(loss), args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] _pullback(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:34
 [18] pullback(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:40
 [19] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
 [20] top-level scope
    @ In[34]:1
 [21] eval
    @ ./boot.jl:360 [inlined]
 [22] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1094

You don’t need to interpolate in the loss function if you already solve with saveat. The issue is that it didn’t align with real_sol, though, so I don’t quite get what you were doing there. But here’s an aligned version:

using Plots, Flux, Zygote, DiffEqFlux, DiffEqSensitivity
using LinearAlgebra, OrdinaryDiffEq
A = [  0   0   0   0;
       1   0   5   0;
       0   5   0   0;
       0   0   1   0];
A = (A' ./ (sum(A,dims = 1) .+ .05)')' - I

function f(u, p, t)
    return p * u   # out-of-place right-hand side: du/dt = p*u
end

u0 = ones(4)

tspan = (0.0, 10.0)
p = A

prob = ODEProblem(f, u0, tspan, p)

sol = solve(prob, Tsit5(), saveat = 1.0)
plot(sol)
real_sol = Array(sol);
function loss(p)
    sol = solve(prob,Tsit5(),
            sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP()),
            p=p,saveat=1.0)
    sum(abs2,sol .- real_sol)
end

p = A .+ rand(4,4)*.05
loss(p)
result_ode = DiffEqFlux.sciml_train(loss,p,
                                    ADAM(0.1),
                                    maxiters = 2)
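
With the alignment fixed this runs, though in practice you'd use many more iterations and watch the loss with a callback (sciml_train takes a cb keyword; the callback below is just a sketch and the name is mine):

callback = function (p, l)
    println("loss = ", l)   # current loss value
    return false            # returning true would halt the optimization early
end

result_ode = DiffEqFlux.sciml_train(loss, p, ADAM(0.1),
                                    cb = callback, maxiters = 300)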

du1 = Zygote.gradient(loss,p)
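
Zygote.gradient returns a tuple, so the 4×4 gradient matrix itself is du1[1]. If you want to sanity-check the adjoint gradient, a crude central-difference probe of a single entry is enough (a sketch; the (2,3) entry and step size are arbitrary choices):

h = 1e-6
dP = zeros(4, 4); dP[2, 3] = h
fd = (loss(p .+ dP) - loss(p .- dP)) / (2h)   # finite-difference estimate of ∂loss/∂p[2,3]
fd, du1[1][2, 3]                              # the two numbers should agree to a few digits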

saveat can take a vector.
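
So if you want the comparison points at t = 1, …, 10 like in the original real_sol, you can ask the solver for exactly those times instead of interpolating afterwards:

sol = solve(prob, Tsit5(), saveat = 1.0:1.0:10.0)   # saves only at t = 1, 2, …, 10
# or, equivalently, saveat = collect(1.0:10.0)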