How to solve an ODEProblem whose u0 is a Vector{Vector} and still differentiate it with Zygote

Hello, I am working on an ODE problem in which u0 is a Vector{Vector}, for example [[1,2], [1,2,3], [1]]. What makes this u0 special is that the inner vectors have different lengths, so I cannot easily convert it into a matrix. I first tried using the Vector{Vector} directly as u0:

```julia
using OrdinaryDiffEq

# out-of-place RHS that returns a Vector of Vectors
function foo_1(u, p, t)
    p1, p2 = p[1], p[2]
    [u[1] .* p1, u[2] .* p2]
end

prob = ODEProblem(foo_1, [[2.0, 3.0], [2.1, 3.2, 4.2]], (0.0, 5.0), [1.01, 0.99])
sol = solve(prob, Tsit5(), saveat=1.0)
```

Running this code produces a message suggesting RecursiveArrayTools.jl or ComponentArrays.jl for u0, so I refactored the code as follows:

```julia
using OrdinaryDiffEq, RecursiveArrayTools, ComponentArrays

# based on RecursiveArrayTools
function foo_2(u, p, t)
    p1, p2 = p[1], p[2]
    # u.u[i] is the i-th inner vector of the VectorOfArray
    VectorOfArray([u.u[1] .* p1, u.u[2] ./ p2])
end

prob = ODEProblem(foo_2, VectorOfArray([[2.0, 3.0], [2.1, 3.2, 4.2]]), (0.0, 5.0), [1.01, 0.99])
sol = solve(prob, Tsit5(), saveat=1.0)

# based on ComponentArrays
function foo_3(u, p, t)
    p1, p2 = p[1], p[2]
    ComponentVector(a=u.a .* p1, b=u.b ./ p2)
end

prob = ODEProblem(foo_3, ComponentVector(a=[2.0, 3.0], b=[2.1, 3.2, 4.2]), (0.0, 5.0), [1.01, 0.99])
sol = solve(prob, Tsit5(), saveat=1.0)
```
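To see why the ComponentArrays version is accepted, here is a minimal sketch, assuming ComponentArrays.jl's exported `getdata` accessor. It shows that a `ComponentVector` keeps both ragged blocks in one contiguous `Vector{Float64}` and exposes them as named components, so the solver sees an ordinary flat vector of numbers.

```julia
using ComponentArrays

u0 = ComponentVector(a = [2.0, 3.0], b = [2.1, 3.2, 4.2])

# One contiguous buffer of 2 + 3 = 5 numbers backs both components
println(length(u0))
println(getdata(u0))   # the underlying Vector{Float64}

# Named access returns each block; plain integer indexing sees the flat layout
println(u0.a)          # first block
println(u0.b)          # second block
println(u0[3])         # first entry of b, i.e. 2.1
```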

Both approaches solve the ODE problem (although the results seem to differ). Next, I tried using Zygote to compute the gradients for each approach:

```julia
using OrdinaryDiffEq, RecursiveArrayTools, SciMLSensitivity, Zygote

function g2(p)
    function foo_2(u, p, t)
        p1, p2 = p[1], p[2]
        VectorOfArray([u.u[1] .* p1, u.u[2] ./ p2])
    end

    prob = ODEProblem(foo_2, VectorOfArray([[2.0, 3.0], [2.1, 3.2, 4.2]]), (0.0, 5.0), p)
    sol = solve(prob, Tsit5(), saveat=1.0)
    sum(sol)
end

gradient(g2, [1.01, 0.99])
```

The RecursiveArrayTools version throws an error:

```
ERROR: Can only convert non-ragged VectorOfArray to Array
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] convert(::Type{Array}, VA::VectorOfArray{Float64, 2, Vector{Vector{Float64}}})
    @ RecursiveArrayTools D:\Julia\Julia-1.10.4\packages\packages\RecursiveArrayTools\xGKIm\src\vector_of_array.jl:760
  [3] vec
    @ D:\Julia\Julia-1.10.4\packages\packages\RecursiveArrayTools\xGKIm\src\vector_of_array.jl:756 [inlined]
  [4] (::SciMLSensitivity.var"#330#339"{…})()
    @ SciMLSensitivity D:\Julia\Julia-1.10.4\packages\packages\SciMLSensitivity\PstNN\src\concrete_solve.jl:953
  [5] unthunk
    @ D:\Julia\Julia-1.10.4\packages\packages\ChainRulesCore\I1EbV\src\tangent_types\thunks.jl:204 [inlined]
  [6] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:110 [inlined]
  [7] map
    @ .\tuple.jl:293 [inlined]
  [8] map (repeats 3 times)
    @ .\tuple.jl:294 [inlined]
  [9] wrap_chainrules_output
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:111 [inlined]
 [10] ZBack
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211 [inlined]
 [11] kw_zpullback
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\chainrules.jl:237 [inlined]
 [12] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [13] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [14] #solve#51
    @ D:\Julia\Julia-1.10.4\packages\packages\DiffEqBase\V6SCE\src\solve.jl:1003 [inlined]
 [15] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [16] #291
    @ D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [17] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [18] solve
    @ D:\Julia\Julia-1.10.4\packages\packages\DiffEqBase\V6SCE\src\solve.jl:993 [inlined]
 [19] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [20] g2
    @ e:\JlCode\LumpedHydro\test\tmp_test.jl:42 [inlined]
 [21] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:91
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote D:\Julia\Julia-1.10.4\packages\packages\Zygote\nsBv0\src\compiler\interface.jl:148
 [24] top-level scope
    @ e:\JlCode\LumpedHydro\test\tmp_test.jl:46
Some type information was truncated. Use `show(err)` to see complete types.
```
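Frames [2] to [4] of the trace show where this fails: SciMLSensitivity's adjoint calls `vec` on the state, and RecursiveArrayTools implements that by converting the `VectorOfArray` into a dense `Array`. A `VectorOfArray` behaves like a matrix whose columns are the inner vectors, and that matrix only exists when all inner vectors have the same length. The sketch below, assuming a recent RecursiveArrayTools.jl, contrasts the equal-length and ragged cases and also shows `ArrayPartition` from the same package, which stores the blocks back to back as one flat vector. Whether `ArrayPartition` gets through Zygote in this setup is an open assumption; I have not checked it.

```julia
using RecursiveArrayTools

# Equal-length inner vectors: the VectorOfArray acts like a 2×2 matrix
va_square = VectorOfArray([[1.0, 2.0], [3.0, 4.0]])
println(size(Array(va_square)))

# Ragged inner vectors: no dense matrix exists, so the conversion can fail
va_ragged = VectorOfArray([[2.0, 3.0], [2.1, 3.2, 4.2]])
try
    convert(Array, va_ragged)
catch err
    # in the RecursiveArrayTools version from the trace this is the
    # "Can only convert non-ragged VectorOfArray to Array" error
    println("conversion failed: ", sprint(showerror, err))
end

# ArrayPartition concatenates the same blocks into one flat AbstractVector
ap = ArrayPartition([2.0, 3.0], [2.1, 3.2, 4.2])
println(length(ap))    # 2 + 3 entries
println(ap.x[2])       # the second block
println(Vector(ap))    # flattened copy
```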

Using ComponentArrays instead seems to work:

```julia
using OrdinaryDiffEq, ComponentArrays, SciMLSensitivity, Zygote

function g3(p)
    function foo_3(u, p, t)
        p1, p2 = p[1], p[2]
        ComponentVector(a=u.a .* p1, b=u.b ./ p2)
    end

    prob = ODEProblem(foo_3, ComponentVector(a=[2.0, 3.0], b=[2.1, 3.2, 4.2]), (0.0, 5.0), p)
    sol = solve(prob, Tsit5(), saveat=1.0)
    sum(sol)
end

gradient(g3, [1.01, 0.99])
```

The result is: ([5436.3132778556665, -10543.639885459668],)
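One way to sanity-check that number: the right-hand side in `foo_3` is linear and decoupled, so block `a` follows u(t) = u(0) exp(p1 t) and block `b` follows u(t) = u(0) exp(t / p2). With `saveat=1.0` on (0, 5) the states are saved at t = 0, 1, ..., 5, so `sum(sol)` and its gradient can be written in closed form. The sketch below is plain Julia with no packages; `analytic_loss` and `analytic_grad` are hypothetical helper names, and the Zygote values are copied from the result above.

```julia
# Closed-form check of the Zygote gradient for g3 (hypothetical helper names).
# foo_3 gives du_a = p1 .* u_a and du_b = u_b ./ p2, so
# u_a(t) = u_a(0) * exp(p1 * t) and u_b(t) = u_b(0) * exp(t / p2).

const A0 = [2.0, 3.0]
const B0 = [2.1, 3.2, 4.2]
const TS = 0.0:1.0:5.0   # time points produced by saveat = 1.0 on (0, 5)

# sum(sol): add up every component of every saved state
function analytic_loss(p)
    p1, p2 = p
    sum(sum(A0) * exp(p1 * t) + sum(B0) * exp(t / p2) for t in TS)
end

# Gradient from differentiating analytic_loss term by term
function analytic_grad(p)
    p1, p2 = p
    d1 = sum(sum(A0) * t * exp(p1 * t) for t in TS)
    d2 = sum(-sum(B0) * t / p2^2 * exp(t / p2) for t in TS)
    [d1, d2]
end

zygote_result = [5436.3132778556665, -10543.639885459668]  # copied from the post
g = analytic_grad([1.01, 0.99])
println(g)
println(g ./ zygote_result .- 1)   # relative differences per component
```

The relative differences come out on the order of 1e-4 or smaller, which is consistent with Tsit5's default tolerances (reltol = 1e-3), so the ComponentArrays gradient looks right.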

Vector{Vector} does not satisfy the documented array interface for u0, so it won't work. That's a given. ComponentArrays is probably the way to go here. I don't quite see what the question is, if there is one.
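For intuition about why ComponentArrays fits where Vector{Vector} does not: roughly, the solver expects u to behave like an array of numbers, not an array of arrays. The bookkeeping that ComponentArrays automates can be sketched by hand: concatenate the ragged blocks into one flat vector, remember each block's index range, and take views when the blocks are needed. `block_ranges` and `blocks` are hypothetical helpers for illustration only, not part of any package.

```julia
# Hypothetical helpers: ragged blocks stored as one flat vector plus index ranges.

# Index range of each block inside the flat vector, e.g. lengths [2, 3] -> 1:2, 3:5
function block_ranges(lengths)
    stops = cumsum(lengths)
    [(s - n + 1):s for (s, n) in zip(stops, lengths)]
end

# Views of the blocks; writing through a view updates the flat vector in place
blocks(u, ranges) = [view(u, r) for r in ranges]

ragged = [[1.0, 2.0], [1.0, 2.0, 3.0], [1.0]]   # the example shape from the question
ranges = block_ranges(length.(ragged))          # index ranges 1:2, 3:5, 6:6
u = reduce(vcat, ragged)                        # flat Vector{Float64} of length 6

println(ranges)
println(blocks(u, ranges))                      # recovers the original blocks
```

In an out-of-place right-hand side one could then return something like `vcat(u[r1] .* p1, u[r2] ./ p2)`, keeping the state a plain `Vector{Float64}`; ComponentArrays does the same job with named fields instead of manual ranges.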