I have an ODE system tracking the densities of a number of species. If certain criteria are fulfilled I add more species to the system and run the new slightly bigger ODE system, and repeat this process until certain other criteria are fulfilled.
The general problem: I’m using ComponentArrays to access different variables of the state vector, and while it’s an excellent library, it stores the indexing information for the views in the type domain, which means that every time I add species, every function has to be recompiled. Since the run time of the differential equations is quite short, I often wait through several minutes of compilation for a problem that is computed in much less than a second on the second run. To avoid an XY problem: if anyone knows a good way of getting around this other than what I’ve tried below, I’d be grateful.
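One possible way around the recompilation, sketched under assumptions rather than tested against the real model: keep the state as a plain Vector{Float64} and carry the index ranges as ordinary values inside the parameter object, building views in the right-hand side. The names Layout, unpack and rhs! are hypothetical, and the sketch assumes the state consists of a vector N, a matrix x (stored column-major) and a vector E, in that order.

```julia
using LinearAlgebra  # stdlib; used for matrix products and the identity I

# Hypothetical helper type: the index ranges are ordinary field values,
# so changing the number of species does not change any type.
struct Layout
    N_idxs::UnitRange{Int}
    x_idxs::UnitRange{Int}
    E_idxs::UnitRange{Int}
    x_shape::Tuple{Int, Int}
end

# NS species, NT rows in x (so x is NT-by-NS), NE entries in E.
function Layout(NS::Int, NT::Int, NE::Int)
    N_idxs = 1:NS
    x_idxs = (NS + 1):(NS + NS * NT)
    E_idxs = (NS + NS * NT + 1):(NS + NS * NT + NE)
    return Layout(N_idxs, x_idxs, E_idxs, (NT, NS))
end

# Build views into a flat vector; works for any element type.
function unpack(u::AbstractVector, l::Layout)
    N = view(u, l.N_idxs)
    x = reshape(view(u, l.x_idxs), l.x_shape)
    E = view(u, l.E_idxs)
    return (; N, x, E)
end

# Right-hand side in the usual in-place (du, u, p, t) form.
function rhs!(du, u, p, t)
    (; N, x, E) = unpack(u, p.layout)
    d = unpack(du, p.layout)
    d.N .= p.AN .* N
    d.x .= p.Ax * x
    d.E .= p.AE * E
    return nothing
end

# Small check with identity matrices so the result is easy to read off.
layout = Layout(2, 3, 4)
u = collect(1.0:12.0)
du = zero(u)
p = (; AN = 2.0, Ax = Matrix(1.0I, 3, 3), AE = Matrix(1.0I, 4, 4), layout)
rhs!(du, u, p, 0.0)
```

Because a Layout for three species has the same type as a Layout for two, rhs! and unpack do not need new specializations when species are added, as long as the element types of u and the fields of p stay the same. The cost is that the components are reached through unpack(u, p.layout) instead of u.N.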
Attempted solution: I’m trying to implement my own (much more limited) version of a ComponentVector that stores the indexing information in the value domain rather than the type domain. It’s made specifically for my system and should be able to reference a vector N, a matrix x and another vector E. Since this custom vector type needs to work with DifferentialEquations.jl, it needs to correctly subtype AbstractVector and define some broadcasting behavior. I mostly write application code, and I feel somewhat in over my head with this, but this is what I have thus far:
module MCV
# Define componentvector-like struct
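struct MyComponentVector{T} <: AbstractVector{T}
    # Field names follow the getfield calls below; the field types are assumed
    data::Vector{T}
    N_idxs::UnitRange{Int}
    x_idxs::UnitRange{Int}
    E_idxs::UnitRange{Int}
    x_shape::Tuple{Int, Int}
end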
# Define stuff to make MyComponentVector behave like a vector
Base.IndexStyle(::Type{<:MyComponentVector}) = IndexLinear()
Base.size(u::MyComponentVector) = size(getfield(u, :data))
Base.zero(u::MyComponentVector) = MyComponentVector(zero(getfield(u, :data)), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.getindex(u::MyComponentVector, i::Integer) = getindex(getfield(u, :data), i)
Base.setindex!(u::MyComponentVector, v, i::Integer) = setindex!(getfield(u, :data), v, i)
Base.copy(u::MyComponentVector) = MyComponentVector(copy(getfield(u, :data)), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector) = MyComponentVector(similar(getfield(u, :data)), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector, ::Type{T}) where T = MyComponentVector(similar(getfield(u, :data), T), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector, ::Type{T}, dims::Dims{1}) where {T} = MyComponentVector(similar(getfield(u, :data), T, dims), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector, ::Type{T}, dims::Dims) where {T} = similar(getfield(u, :data), T, dims)
# Define basic broadcasting behavior
struct MyComponentVectorStyle <: Broadcast.AbstractArrayStyle{1} end
MyComponentVectorStyle(::Val{0}) = MyComponentVectorStyle()
MyComponentVectorStyle(::Val{1}) = MyComponentVectorStyle()
MyComponentVectorStyle(::Val{N}) where N = Base.Broadcast.DefaultArrayStyle{N}()
Base.BroadcastStyle(::Type{<:MyComponentVector}) = MyComponentVectorStyle()
# Adapted from the Julia manual
function Base.similar(bc::Broadcast.Broadcasted{MyComponentVectorStyle}, ::Type{ElType}) where ElType
# Scan the inputs for the MyComponentVector:
mcv = find_mcv(bc)
# Use the MyComponentVector to create output
find_mcv(bc::Base.Broadcast.Broadcasted) = find_mcv(bc.args)
find_mcv(args::Tuple) = find_mcv(find_mcv(args[1]), Base.tail(args))
find_mcv(x) = x
find_mcv(::Tuple{}) = nothing
find_mcv(mcv::MyComponentVector, rest) = mcv
find_mcv(::Any, rest) = find_mcv(rest)
# Define main constructor and index arithmetic
function MyComponentVector(; N, x, E)
    NS = length(N)   # number of species
    NT = size(x, 1)  # number of rows in x
    NE = length(E)
    data = vcat(N, x[:], E)  # flat storage: N, then x column by column, then E
    N_idxs = get_N_idxs(NS)
    x_idxs = get_x_idxs(NS, NT)
    E_idxs = get_E_idxs(NS, NT, NE)
    x_shape = size(x)
    MyComponentVector(data, N_idxs, x_idxs, E_idxs, x_shape)
end
get_N_idxs(NS) = 1:NS
get_x_idxs(NS, NT) = NS + 1 : NS + NS*NT
get_E_idxs(NS, NT, NE) = NS + NS*NT + 1 : NS + NS*NT + NE
# Define behavior to get the three components via . access
Base.getproperty(u::MyComponentVector, ::Val{:N}) = view(getfield(u, :data), getfield(u, :N_idxs))
Base.getproperty(u::MyComponentVector, ::Val{:x}) = reshape(view(getfield(u, :data), getfield(u, :x_idxs)), getfield(u, :x_shape))
Base.getproperty(u::MyComponentVector, ::Val{:E}) = view(getfield(u, :data), getfield(u, :E_idxs))
Base.getproperty(u::MyComponentVector, sym::Symbol) = getproperty(u, Val(sym))
end # module
Specific problem with attempted solution: While I’ve gotten this to work with explicit solvers (e.g., Tsit5()), I cannot get it to work with implicit solvers (I’d like to use Rodas5P()). Here is a minimal example reproducing the error (requires loading the code above too):
using DifferentialEquations
function mwe_ode_system!(du, u, p, t)
    (; N, x, E) = u
    dN = du.N; dx = du.x; dE = du.E
    dN .= p.AN*N
    dx .= p.Ax*x
    dE .= p.AE*E
    return nothing
end
N0 = [1.0, 2.0]
x0 = [3.0; 4.0; 5.0;; 6.0; 7.0; 8.0]
E0 = [9.0, 10.0, 11.0, 12.0]
data0 = vcat(N0, x0[:], E0)
mcv = MCV.MyComponentVector(; N = N0, x = x0, E = E0)
AN = 0.346183
Ax = [0.227105 -0.175687 -0.345761
0.277979 0.0391113 -0.107145
-0.278168 0.0757189 -0.0320952]
AE = [0.235112 0.0252358 0.390814 -0.22144
-0.368158 -0.329906 -0.244081 -0.105781
-0.0677509 0.00664437 0.18053 -0.143473
-0.0678699 -0.317413 0.30792 -0.137691]
p = (; AN, Ax, AE)
prob = ODEProblem{true}(mwe_ode_system!, mcv, (0.0, 5.0), p)
sol = solve(prob, Rodas5P(); abstol = 1e-6, reltol = 1e-6)
This produces the following error:
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1})
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
@ Base rounding.jl:207
(::Type{T})(::T) where T<:Number
@ Core boot.jl:792
@ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:112
[1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1})
@ Base ./number.jl:7
[2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 1}, i1::Int64)
@ Base ./array.jl:1021
[3] setindex!(u::Main.MCV.MyComponentVector{Float64}, v::ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 1}, i::Int64)
@ Main.MCV ~/Files/Documents/Dropbox/Projects/Implicit trade-offs/Julia/scratches_components.jl:23
[4] macro expansion
@ ./broadcast.jl:1004 [inlined]
[5] macro expansion
@ ./simdloop.jl:77 [inlined]
[6] copyto!
@ ./broadcast.jl:1003 [inlined]
[7] copyto!
@ ./broadcast.jl:956 [inlined]
[8] copy
@ ./broadcast.jl:928 [inlined]
[9] materialize
@ ./broadcast.jl:903 [inlined]
[10] build_grad_config(alg::Rodas5P{…}, f::ODEFunction{…}, tf::SciMLBase.TimeGradientWrapper{…}, du1::Main.MCV.MyComponentVector{…}, t::Float64)
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/derivative_wrappers.jl:375
[11] alg_cache(alg::Rodas5P{…}, u::Main.MCV.MyComponentVector{…}, rate_prototype::Main.MCV.MyComponentVector{…}, ::Type{…}, ::Type{…}, ::Type{…}, uprev::Main.MCV.MyComponentVector{…}, uprev2::Main.MCV.MyComponentVector{…}, f::ODEFunction{…}, t::Float64, dt::Float64, reltol::Float64, p::@NamedTuple{…}, calck::Bool, ::Val{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/caches/rosenbrock_caches.jl:1114
[12] __init(prob::ODEProblem{…}, alg::Rodas5P{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:350
[13] __init (repeats 5 times)
@ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:11 [inlined]
[14] #__solve#670
@ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:6 [inlined]
[15] __solve
@ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:1 [inlined]
[16] solve_call(_prob::ODEProblem{…}, args::Rodas5P{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:612
[17] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Main.MCV.MyComponentVector{…}, p::@NamedTuple{…}, args::Rodas5P{…}; kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1080
[18] solve_up
@ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1066 [inlined]
[19] solve(prob::ODEProblem{…}, args::Rodas5P{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003
[20] top-level scope
@ REPL[11]:1
Some type information was truncated. Use `show(err)` to see complete types.
I tried consulting the SciML documentation but couldn’t find anything regarding how to implement a custom vector type. There must be some missing or incorrectly specified piece of my vector type that makes it not work properly with the DifferentialEquations machinery, but I cannot figure out what it is.
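One reading of the trace, offered as a guess rather than a confirmed diagnosis: frames [3] to [9] show a broadcast allocating its output and then writing a ForwardDiff.Dual into a MyComponentVector{Float64}. That is what would happen if the Base.similar method for Broadcast.Broadcasted returned a container with the input's element type instead of the requested ElType. The sketch below uses a hypothetical stand-in type, Wrapped, with a single index field (it is not the module above), and allocates the output the way the broadcasting example in the Julia manual does, with similar(Array{ElType}, axes(bc)). Whether this matches the actual cause in MCV is an assumption.

```julia
# Hypothetical stand-in for MyComponentVector: data plus one value-domain field.
struct Wrapped{T} <: AbstractVector{T}
    data::Vector{T}
    N_idxs::UnitRange{Int}
end

Base.IndexStyle(::Type{<:Wrapped}) = IndexLinear()
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int) = w.data[i]
Base.setindex!(w::Wrapped, v, i::Int) = (w.data[i] = v)

# Broadcast style, following the same pattern as in the post.
struct WrappedStyle <: Broadcast.AbstractArrayStyle{1} end
WrappedStyle(::Val{0}) = WrappedStyle()
WrappedStyle(::Val{1}) = WrappedStyle()
WrappedStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
Base.BroadcastStyle(::Type{<:Wrapped}) = WrappedStyle()

# Locate the first Wrapped among the (possibly nested) broadcast arguments.
find_wrapped(bc::Broadcast.Broadcasted) = find_wrapped(bc.args)
find_wrapped(args::Tuple) = find_wrapped(find_wrapped(args[1]), Base.tail(args))
find_wrapped(x) = x
find_wrapped(::Tuple{}) = nothing
find_wrapped(w::Wrapped, rest) = w
find_wrapped(::Any, rest) = find_wrapped(rest)

function Base.similar(bc::Broadcast.Broadcasted{WrappedStyle}, ::Type{ElType}) where {ElType}
    w = find_wrapped(bc)
    # Allocate storage with the element type the broadcast asks for (ElType),
    # not eltype(w); only the index metadata is copied from the input.
    return Wrapped(similar(Array{ElType}, axes(bc)), w.N_idxs)
end

# An Int-valued input times a Float64 scalar needs Float64 output storage.
w = Wrapped([1, 2, 3], 1:2)
out = w .* 0.5
```

In this sketch the Int input multiplied by 0.5 needs Float64 storage, in the same way that a Float64 state multiplied by a Dual needs Dual storage. Allocating the output with eltype(w) instead would fail inside setindex!, which is the shape of the failure in frames [2] and [3].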