Nested struct as state variable in DifferentialEquations

I’m learning to use the DifferentialEquations ecosystem, and my goal is to simulate dynamic systems in which the state variable is a custom struct.

In the simplest case, my state variable is just the x and y coordinates of a particle in space, as in the following example:

```julia
using DifferentialEquations

# State variable: A point in Cartesian space
mutable struct Point2D{T}
    x::T
    y::T
end

function f!(du::Point2D, u::Point2D, _, __)
    du.x = u.x + 0.1u.y
    du.y = -0.2u.x + u.y
end

let u0 = Point2D(25.0, 30.0)
    prob = ODEProblem(f!, u0, (0, 100))
    sol = solve(prob, AutoTsit5(Rosenbrock23()))
end
```

Now, if you run this example, you get the following error, because DifferentialEquations doesn’t know how to initialize `du` inside of the solver:

```
ERROR: MethodError: no method matching oneunit(::Type{Any})
Closest candidates are:
  oneunit(::Type{Union{Missing, T}}) where T at missing.jl:105
  oneunit(::Type{T}) where T at number.jl:358
  oneunit(::T) where T at number.jl:357
```
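The `Type{Any}` in that message is a useful hint. A quick check in plain Julia (no solver involved) shows that Base's fallback `eltype` for a struct that is not a collection is `Any`, so there is no numeric element type to work from. That this is exactly where the solver trips is a reading of the error message, not something confirmed from the DifferentialEquations source. `PlainPoint2D` is just the first `Point2D` renamed so it doesn't clash with later definitions.

```julia
# Same layout as the first Point2D, renamed to avoid clashing with later examples
mutable struct PlainPoint2D{T}
    x::T
    y::T
end

u0 = PlainPoint2D(25.0, 30.0)

# Base falls back to eltype(::Type) = Any for types that are not collections,
# so no numeric type can be recovered from u0 by itself.
eltype(u0)              # Any
u0 isa AbstractArray    # false

# The same two numbers stored in an ordinary array report a concrete eltype.
eltype([u0.x, u0.y])    # Float64
```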

The solution, as pointed out by @lmiq in this other thread of mine, is to define `Point2D` as a subtype of the abstract type `FieldVector{2, T}` from the `StaticArrays` package. Then it automatically inherits all the algebra necessary for the DE solvers to run. The following example works properly:

Same code but with `using StaticArrays` and `Point2D{T} <: FieldVector{2, T}`
```julia
using DifferentialEquations
using StaticArrays

# State variable: A point in Cartesian space
mutable struct Point2D{T} <: FieldVector{2, T}
    x::T
    y::T
end

function f!(du::Point2D, u::Point2D, _, __)
    du.x = u.x + 0.1u.y
    du.y = -0.2u.x + u.y
end

let u0 = Point2D(25.0, 30.0)
    prob = ODEProblem(f!, u0, (0, 100))
    sol = solve(prob, AutoTsit5(Rosenbrock23()))
end
```
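To see roughly what subtyping `FieldVector` buys, here is a hand-rolled equivalent in plain Julia: a mutable two-field struct that subtypes `AbstractVector{T}` and implements the core methods of Julia's `AbstractArray` interface (`size`, `getindex`, `setindex!`, `similar`). `MyPoint2D` is a hypothetical name and this is only a sketch of the idea. StaticArrays' `FieldVector` does considerably more (static sizes, arithmetic that returns the same type, etc.), and this minimal version has not been checked against the solvers.

```julia
# Hypothetical stand-in for a FieldVector: named fields plus the array interface
mutable struct MyPoint2D{T} <: AbstractVector{T}
    x::T
    y::T
end

# Core AbstractArray methods: fixed length 2, linear indexing into the fields
Base.size(::MyPoint2D) = (2,)
Base.IndexStyle(::Type{<:MyPoint2D}) = IndexLinear()

function Base.getindex(p::MyPoint2D, i::Int)
    @boundscheck checkbounds(p, i)
    return getfield(p, i)              # field 1 is x, field 2 is y
end

function Base.setindex!(p::MyPoint2D, v, i::Int)
    @boundscheck checkbounds(p, i)
    setfield!(p, i, convert(eltype(p), v))
    return p
end

# similar() is how generic code allocates a buffer "like u" (e.g. for du).
# Only a length-2 request can be a MyPoint2D; it is zero-filled for simplicity.
function Base.similar(::MyPoint2D, ::Type{S}, dims::Dims{1}) where {S}
    return dims == (2,) ? MyPoint2D{S}(zero(S), zero(S)) : Vector{S}(undef, dims)
end

p = MyPoint2D(25.0, 30.0)
p[1] = 1.0        # writes through to the x field
p.x               # 1.0
2 .* p            # generic broadcasting works; here it returns a plain Vector{Float64}
du = similar(p)   # a MyPoint2D{Float64}
```

Note that out-of-place broadcasting falls back to a plain `Vector`; keeping the custom type through arithmetic is part of what a package like StaticArrays handles for you.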

My question is, essentially, how do I extend this approach for cases where my state variable is not a simple field vector/matrix/array?

For example, consider the following code, in which the state variable is a particle with both a position and a velocity, which themselves are given by `Point2D` structs:

Example where state variable is `PointWithVelocity{T}` having fields `pos::Point2D{T}` and `vel::Point2D{T}`
```julia
using DifferentialEquations
using StaticArrays

mutable struct Point2D{T} <: FieldVector{2, T}
    x::T
    y::T
end

# State variable: A point in Cartesian space and its velocity
# I want to be able to access fields by name, e.g. p.pos.x
mutable struct PointWithVelocity{T} # <: WhatGoesHere?{T}
    pos::Point2D{T}
    vel::Point2D{T}
end

function g!(du::PointWithVelocity, u::PointWithVelocity, _, __)
    du.pos.x = u.vel.x
    du.pos.y = u.vel.y
    du.vel.x = u.vel.x
    du.vel.y = u.vel.y - 9.8
end

let u0 = PointWithVelocity(
        Point2D(25.0, 30.0),
        Point2D(4.0, 0.0)
    )
    prob = ODEProblem(g!, u0, (0, 100))
    sol = solve(prob, AutoTsit5(Rosenbrock23()))
end
```

Running this example produces the same `no method matching oneunit(::Type{Any})` error as above.

Needless to say, inheriting from `FieldMatrix{2, 2, T}` does not work (and even if it did, this wouldn’t really solve my larger problem because not all of the situations I want to model have a “rectangular” state space that can be arranged as the entries of a matrix):

Same code but with `PointWithVelocity{T} <: FieldMatrix{2, 2, T}`; produces a `BoundsError`
```julia
using DifferentialEquations
using StaticArrays

mutable struct Point2D{T} <: FieldVector{2, T}
    x::T
    y::T
end

mutable struct PointWithVelocity{T} <: FieldMatrix{2, 2, T}
    pos::Point2D{T}
    vel::Point2D{T}
end

function g!(du::PointWithVelocity, u::PointWithVelocity, _, __)
    du.pos.x = u.vel.x
    du.pos.y = u.vel.y
    du.vel.x = u.vel.x
    du.vel.y = u.vel.y - 9.8
end

let u0 = PointWithVelocity(
        Point2D(25.0, 30.0),
        Point2D(4.0, 0.0)
    )
    prob = ODEProblem(g!, u0, (0, 100))
    sol = solve(prob, AutoTsit5(Rosenbrock23()))
end
```

yields

```
ERROR: BoundsError: attempt to access 2×2 PointWithVelocity{Float64} with indices SOneTo(2)×SOneTo(2) at index [3]
```

As a workaround, you can use

```julia
mutable struct PointWithVelocity{T} <: FieldVector{4, T}
    pos_x::T
    pos_y::T
    vel_x::T
    vel_y::T
end
```

but this approach does not scale well to a more complicated system where you have multiple particles, etc.
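If you do stay with a flat vector, one common way to keep it manageable for many particles is to fix a layout and hide the index arithmetic behind small accessor functions that return views. The sketch below assumes 4 entries per particle (`[pos_x, pos_y, vel_x, vel_y]`). `NSTATE`, `pos`, `vel`, and `g_flat!` are hypothetical names, and the right-hand side simply transcribes `g!` from above.

```julia
# Hypothetical layout: particle i occupies u[4i-3:4i] as [pos_x, pos_y, vel_x, vel_y]
const NSTATE = 4

# Views share memory with u, so writing through them updates u in place
pos(u, i) = view(u, NSTATE * (i - 1) .+ (1:2))
vel(u, i) = view(u, NSTATE * (i - 1) .+ (3:4))

# Same equations as g! above, for any number of particles;
# (du, u, p, t) is the usual in-place right-hand-side signature
function g_flat!(du, u, p, t)
    for i in 1:(length(u) ÷ NSTATE)
        v = vel(u, i)
        dpos = pos(du, i)
        dvel = vel(du, i)
        dpos[1] = v[1]
        dpos[2] = v[2]
        dvel[1] = v[1]
        dvel[2] = v[2] - 9.8
    end
    return nothing
end

# Two particles stored back to back
u0 = [25.0, 30.0, 4.0, 0.0,
      0.0, 0.0, 1.0, 2.0]
du = similar(u0)
g_flat!(du, u0, nothing, 0.0)
```

Because `u0` is an ordinary `Vector{Float64}`, `ODEProblem(g_flat!, u0, (0.0, 100.0))` should need no extra array interface; the cost is that the naming lives in the accessor functions rather than in the type.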

What’s the best way to use nested structs as the state variable in DifferentialEquations? Is this a wrongheaded approach?

That’s the best way.

The key is that you need to define enough of the array interface for it to “work”. ComponentArrays is probably the nicest structure for this, and it is designed to add no overhead compared to just using a flat array.
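To make “enough of the array interface” concrete, here is a minimal hand-rolled sketch: a wrapper around a flat `Vector` that looks like an `AbstractVector` to generic code but offers named access through `getproperty`. `ParticleState` and its layout are hypothetical, and this is not how ComponentArrays is implemented; for example, out-of-place broadcasting on this type returns a plain `Vector`, which is one of the gaps a real package closes. ComponentArrays can also be built from nested named tuples (e.g. `ComponentArray(pos = (x = 25.0, y = 30.0), vel = (x = 4.0, y = 0.0))`), giving `u.pos.x`-style access closer to what the question asks for; check its documentation for details.

```julia
# Hypothetical wrapper: a flat buffer laid out as [pos_x, pos_y, vel_x, vel_y]
struct ParticleState{T} <: AbstractVector{T}
    data::Vector{T}
end

# The AbstractArray interface, forwarded to the underlying buffer
Base.size(u::ParticleState) = size(getfield(u, :data))
Base.IndexStyle(::Type{<:ParticleState}) = IndexLinear()
Base.getindex(u::ParticleState, i::Int) = getfield(u, :data)[i]
Base.setindex!(u::ParticleState, v, i::Int) = (getfield(u, :data)[i] = v)

# similar() keeps the wrapper type, so buffers allocated "like u" stay ParticleStates
Base.similar(::ParticleState, ::Type{S}, dims::Dims{1}) where {S} =
    ParticleState(Vector{S}(undef, dims))

# Named access: u.pos and u.vel are views, so writes go straight into the buffer
function Base.getproperty(u::ParticleState, name::Symbol)
    data = getfield(u, :data)
    name === :pos && return view(data, 1:2)
    name === :vel && return view(data, 3:4)
    return getfield(u, name)
end
Base.propertynames(::ParticleState) = (:pos, :vel)

u = ParticleState([25.0, 30.0, 4.0, 0.0])
u.vel[2] = -1.0   # updates u[4]
du = similar(u)   # a ParticleState{Float64} with uninitialized contents
```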

The other thing is to figure out whether this is a concern of representation or of computation. Are you building a hierarchy to model, or is it to achieve some performance end? The two can be different. In that sense, a symbolic modeling system like ModelingToolkit.jl could be useful for building and interacting with a hierarchical model without enforcing that hierarchy on the computed structure.

https://mtk.sciml.ai/dev/
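The representation/computation split can be illustrated without any symbolic tooling: keep a nested, human-friendly description of the model and convert it to and from the flat vector the solver actually integrates. This is a plain-Julia analogy only, not ModelingToolkit's API; `Vec2`, `Particle`, `pack`, and `unpack` are hypothetical names.

```julia
# Hierarchical description of the system (immutable, for building and inspecting)
struct Vec2
    x::Float64
    y::Float64
end

struct Particle
    pos::Vec2
    vel::Vec2
end

# Representation -> computation: one flat Vector{Float64}, 4 entries per particle
function pack(ps::AbstractVector{Particle})
    u = Float64[]
    for p in ps
        push!(u, p.pos.x, p.pos.y, p.vel.x, p.vel.y)
    end
    return u
end

# Computation -> representation: rebuild the hierarchy from a state vector
unpack(u::AbstractVector) =
    [Particle(Vec2(u[k], u[k+1]), Vec2(u[k+2], u[k+3])) for k in 1:4:length(u)]

ps = [Particle(Vec2(25.0, 30.0), Vec2(4.0, 0.0)),
      Particle(Vec2(0.0, 0.0), Vec2(1.0, 2.0))]
u0 = pack(ps)       # [25.0, 30.0, 4.0, 0.0, 0.0, 0.0, 1.0, 2.0]
unpack(u0) == ps    # true
```

In this arrangement the solver would only ever see `u0`; the nested structs serve for setting up the model and for reading results back (e.g. applying `unpack` to saved states).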

Thank you. ComponentArrays.jl is quite slick indeed. Would be neat if it could use StaticArrays.jl as a backend to achieve similar performance to a flat static array (or does this matter?).

LabelledArrays.jl has SLArray for that. No hierarchy, but it’s not like you can have a huge hierarchy given that static arrays will be less performant by the time you get to about 8 elements.