Problem with ForwardDiff and a custom vector type in OrdinaryDiffEq

I have an ODE system tracking the densities of a number of species. If certain criteria are fulfilled I add more species to the system and run the new slightly bigger ODE system, and repeat this process until certain other criteria are fulfilled.

The general problem: I’m using ComponentArrays to access different variables of the state vector, and while it’s an excellent library, it stores the indexing information for the views in the type domain which means that every time I add species, every function has to be recompiled. Since the run time of the differential equations is quite short, it means that I often wait through several minutes of compile time for a problem that on the second go is computed in much less than a second. To avoid an XY problem, if anyone knows any good way of getting around this that is not what I’ve tried below, I’d be grateful.
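
To make the type-domain versus value-domain distinction concrete, here is a minimal sketch. Both `TypeLayout` and `ValueLayout` are hypothetical names made up for illustration (they are not ComponentArrays types); the first carries the species count as a type parameter, the second as an ordinary field value.

```julia
# Hypothetical layout that stores the species count in the type domain.
struct TypeLayout{NS} end

# Hypothetical layout that stores the same information in the value domain.
struct ValueLayout
    N_idxs::UnitRange{Int}
end

n_species(::TypeLayout{NS}) where {NS} = NS
n_species(layout::ValueLayout) = length(layout.N_idxs)

# Adding a species changes the type in the first case but not in the second.
same_type_tl = typeof(TypeLayout{2}()) == typeof(TypeLayout{3}())     # false
same_type_vl = typeof(ValueLayout(1:2)) == typeof(ValueLayout(1:3))   # true
```

Julia compiles a separate specialization of a method for each concrete argument type, so code taking a `TypeLayout` would be compiled again for every species count, while code taking a `ValueLayout` would be compiled once.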

Attempted solution: I’m trying to implement my own (much more limited) version of a ComponentVector that stores the indexing information in the value domain rather than the type domain. It’s made specifically for my system and should be able to reference a vector N, a matrix x and another vector E. Since this custom vector type needs to work with DifferentialEquations.jl, it needs to correctly subtype an AbstractVector and define some broadcasting behavior. I mostly write application code, and I feel somewhat in over my head with this, but this is what I have thus far:

module MCV

# Define componentvector-like struct
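struct MyComponentVector{T} <: AbstractVector{T}
    # Field names follow the getfield calls below; field types are inferred from the constructor
    data::Vector{T}
    N_idxs::UnitRange{Int}
    x_idxs::UnitRange{Int}
    E_idxs::UnitRange{Int}
    x_shape::Tuple{Int, Int}
end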

# Define stuff to make MyComponentVector behave like a vector
Base.IndexStyle(::Type{<:MyComponentVector}) = IndexLinear()
Base.size(u::MyComponentVector) = size(getfield(u, :data))
Base.zero(u::MyComponentVector) = MyComponentVector(zero(getfield(u, :data)), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.getindex(u::MyComponentVector, i::Integer) = getindex(getfield(u, :data), i)
Base.setindex!(u::MyComponentVector, v, i::Integer) = setindex!(getfield(u, :data), v, i)
Base.copy(u::MyComponentVector) = MyComponentVector(copy(getfield(u, :data)), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector) = MyComponentVector(similar(getfield(u, :data)), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector, ::Type{T}) where T = MyComponentVector(similar(getfield(u, :data), T), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector, ::Type{T}, dims::Dims{1}) where {T} = MyComponentVector(similar(getfield(u, :data), T, dims), getfield(u, :N_idxs), getfield(u, :x_idxs), getfield(u, :E_idxs), getfield(u, :x_shape))
Base.similar(u::MyComponentVector, ::Type{T}, dims::Dims) where {T} = similar(getfield(u, :data), T, dims)

# Define basic broadcasting behavior
struct MyComponentVectorStyle <: Broadcast.AbstractArrayStyle{1} end
MyComponentVectorStyle(::Val{0}) = MyComponentVectorStyle()
MyComponentVectorStyle(::Val{1}) = MyComponentVectorStyle()
MyComponentVectorStyle(::Val{N}) where N = Base.Broadcast.DefaultArrayStyle{N}()

Base.BroadcastStyle(::Type{<:MyComponentVector}) = MyComponentVectorStyle()

# Adapted from the Julia manual
function Base.similar(bc::Broadcast.Broadcasted{MyComponentVectorStyle}, ::Type{ElType}) where ElType
    # Scan the inputs for the MyComponentVector:
    mcv = find_mcv(bc)
    # Use the MyComponentVector to create output
    similar(mcv)
end

find_mcv(bc::Base.Broadcast.Broadcasted) = find_mcv(bc.args)
find_mcv(args::Tuple) = find_mcv(find_mcv(args[1]), Base.tail(args))
find_mcv(x) = x
find_mcv(::Tuple{}) = nothing
find_mcv(mcv::MyComponentVector, rest) = mcv
find_mcv(::Any, rest) = find_mcv(rest)

# Define main constructor and index arithmetic
function MyComponentVector(; N, x, E)
    NS = length(N)
    NT = size(x, 1)
    NE = length(E)
    data = vcat(N, x[:], E)
    N_idxs = get_N_idxs(NS)
    x_idxs = get_x_idxs(NS, NT)
    E_idxs = get_E_idxs(NS, NT, NE)
    x_shape = size(x)
    MyComponentVector(data, N_idxs, x_idxs, E_idxs, x_shape)
end

get_N_idxs(NS) = 1:NS 
get_x_idxs(NS, NT) = NS + 1 : NS + NS*NT 
get_E_idxs(NS, NT, NE) = NS + NS*NT + 1 : NS + NS*NT + NE 

# Define behavior to get the three components via . access
Base.getproperty(u::MyComponentVector, ::Val{:N}) = view(getfield(u, :data), getfield(u, :N_idxs))
Base.getproperty(u::MyComponentVector, ::Val{:x}) = reshape(view(getfield(u, :data), getfield(u, :x_idxs)), getfield(u, :x_shape))
Base.getproperty(u::MyComponentVector, ::Val{:E}) = view(getfield(u, :data), getfield(u, :E_idxs))
Base.getproperty(u::MyComponentVector, sym::Symbol) = getproperty(u, Val(sym))

end # module 

Specific problem with attempted solution: While I’ve gotten this to work with explicit solvers (e.g., Tsit5()), I cannot get it to work with implicit solvers (I’d like to use Rodas5P()). Here is a minimal example reproducing the error (requires loading the code above too):

using DifferentialEquations

function mwe_ode_system!(du, u, p, t)

    (; N, x, E) = u; 
    dN = du.N; dx = du.x; dE = du.E

    dN .= p.AN*N 
    dx .= p.Ax*x 
    dE .= p.AE*E 

    return nothing
end

N0 = [1.0, 2.0]
x0 = [3.0; 4.0; 5.0;; 6.0; 7.0; 8.0]
E0 = [9.0, 10.0, 11.0, 12.0]
data0 = vcat(N0, x0[:], E0)

mcv = MCV.MyComponentVector(; N = N0, x = x0, E = E0)

AN = 0.346183

Ax = [0.227105  -0.175687   -0.345761
      0.277979   0.0391113  -0.107145
     -0.278168   0.0757189  -0.0320952]

AE = [0.235112    0.0252358    0.390814  -0.22144
     -0.368158   -0.329906    -0.244081  -0.105781
     -0.0677509   0.00664437   0.18053   -0.143473
     -0.0678699  -0.317413     0.30792   -0.137691]

p = (; AN, Ax, AE)

prob = ODEProblem{true}(mwe_ode_system!, mcv, (0.0, 5.0), p)
sol = solve(prob, Rodas5P(); abstol = 1e-6, reltol = 1e-6)

This produces the following error:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:112

  [1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 1}, i1::Int64)
    @ Base ./array.jl:1021
  [3] setindex!(u::Main.MCV.MyComponentVector{Float64}, v::ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 1}, i::Int64)
    @ Main.MCV ~/Files/Documents/Dropbox/Projects/Implicit trade-offs/Julia/scratches_components.jl:23
  [4] macro expansion
    @ ./broadcast.jl:1004 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] copyto!
    @ ./broadcast.jl:1003 [inlined]
  [7] copyto!
    @ ./broadcast.jl:956 [inlined]
  [8] copy
    @ ./broadcast.jl:928 [inlined]
  [9] materialize
    @ ./broadcast.jl:903 [inlined]
 [10] build_grad_config(alg::Rodas5P{…}, f::ODEFunction{…}, tf::SciMLBase.TimeGradientWrapper{…}, du1::Main.MCV.MyComponentVector{…}, t::Float64)
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/derivative_wrappers.jl:375
 [11] alg_cache(alg::Rodas5P{…}, u::Main.MCV.MyComponentVector{…}, rate_prototype::Main.MCV.MyComponentVector{…}, ::Type{…}, ::Type{…}, ::Type{…}, uprev::Main.MCV.MyComponentVector{…}, uprev2::Main.MCV.MyComponentVector{…}, f::ODEFunction{…}, t::Float64, dt::Float64, reltol::Float64, p::@NamedTuple{…}, calck::Bool, ::Val{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/caches/rosenbrock_caches.jl:1114
 [12] __init(prob::ODEProblem{…}, alg::Rodas5P{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:350
 [13] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:11 [inlined]
 [14] #__solve#670
    @ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:6 [inlined]
 [15] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/Knuk0/src/solve.jl:1 [inlined]
 [16] solve_call(_prob::ODEProblem{…}, args::Rodas5P{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:612
 [17] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Main.MCV.MyComponentVector{…}, p::@NamedTuple{…}, args::Rodas5P{…}; kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1080
 [18] solve_up
    @ ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1066 [inlined]
 [19] solve(prob::ODEProblem{…}, args::Rodas5P{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DS1sd/src/solve.jl:1003
 [20] top-level scope
    @ REPL[11]:1
Some type information was truncated. Use `show(err)` to see complete types.

I tried consulting the SciML documentation but couldn’t find anything regarding how to implement a custom vector type. There must be some missing or incorrectly specified piece of my vector type that makes it not work properly with the DifferentialEquations machinery, but I cannot figure out what it is.

About the general problem, excessive type-specialization of ComponentArrays.jl is also something I have run into, so it’s great that you’re looking into potential solutions!

The specific bug you encounter here is autodiff-related. To compute derivatives, ForwardDiff.jl propagates so-called Dual numbers through your code in lieu of normal numbers. In particular, it requires any differentiated code to accept generic number types.
What the stack trace reveals is that somewhere in your code (or in the ODE solver), storage is created with element type Float64, and then a Dual number is written into it, which fails because a Dual cannot be converted to Float64. Normally this should not happen because OrdinaryDiffEq.jl is clever enough to propagate the correct number types.
My best guess is that there is something wrong with your similar or zero methods: perhaps you missed a signature and it defaults to Float64 somewhere, the same way zeros defaults to zeros(Float64, ...).
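
To illustrate the failure mode without loading ForwardDiff.jl, here is a minimal sketch using a made-up `ToyDual` type. It is only a stand-in: a real `ForwardDiff.Dual` also carries a tag and a tuple of partials. The point is just that a `Real` that is not a `Float64` cannot be stored in `Float64` storage unless a conversion exists.

```julia
# Hypothetical stand-in for a dual number: a Real that is not a Float64.
struct ToyDual <: Real
    value::Float64
    partial::Float64
end
Base.zero(::Type{ToyDual}) = ToyDual(0.0, 0.0)

d = ToyDual(1.0, 1.0)

# zeros defaults to Float64, so storing d requires Float64(::ToyDual),
# which is not defined and therefore throws a MethodError.
float_storage = zeros(3)
stored_in_float = try
    float_storage[1] = d
    true
catch err
    err isa MethodError ? false : rethrow()
end

# Storage whose element type is chosen by the caller accepts d.
generic_storage = zeros(ToyDual, 3)
generic_storage[1] = d
```

Any method that allocates output (`similar`, `zero`, `copy`, or the broadcast version of `similar`) is a candidate for this kind of silent `Float64` default.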

Yes, sorry for not being explicit about this in the question, and thanks for pointing it out. Something goes wrong when the stiff solvers in DifferentialEquations.jl try to set something up using autodiff, but I cannot figure out what. The cause is not obvious, as I can use ForwardDiff.jl to differentiate the function just fine:

using ForwardDiff

function add_stuff(mcv)

    s = 0.0

    s = s + mcv.N[1]
    s = s + mcv.x[1,1]
    s = s + mcv.E[1]

    return s
end

ForwardDiff.gradient(add_stuff, mcv) # Works, returns a MyComponentVector

dmcv = similar(mcv)
ForwardDiff.jacobian((du, u) -> mwe_ode_system!(du, u, p, 0.0), dmcv, mcv) 
# Also works fine, returns Matrix{Float64}

where I’ve used the same variables as defined in my previous post.

I’m not familiar enough with ODEs to help, but I renamed your post to something that will catch more expert eyes.

Do @show typeof(t), typeof(u), typeof(du). What is the element type when it errors? My guess is that it’s Float64 when it should be Dual.

The generic type construction is done via:

If du is not the expected type, I would check the action of ArrayInterface.restructure on your type.
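
One way to test a vector type outside the solver is sketched below. The helper `check_eltype_propagation` is a hypothetical name introduced here, not part of SciML or ArrayInterface; it only checks whether `similar` and a simple broadcast honor a requested element type, using `BigFloat` as a stand-in for a non-Float64 number.

```julia
# Hypothetical helper: checks that a vector type honors a requested element type.
function check_eltype_propagation(u::AbstractVector, ::Type{T}) where {T}
    # similar(u, T) should allocate storage with element type T
    similar_ok = eltype(similar(u, T)) == T
    # broadcasting against a T should promote the element type of the result
    broadcast_ok = eltype(u .* one(T)) == promote_type(eltype(u), T)
    return (; similar_ok, broadcast_ok)
end

# A plain Vector passes both checks.
report = check_eltype_propagation([1.0, 2.0], BigFloat)
```

For a custom type whose broadcast output ignores the requested element type, the second check would be expected to come back `false`, since the result would keep the element type of the input.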

Thanks Chris! That did help me figure out what was wrong. I made a mistake in how I implemented broadcasting, so the element type could not be changed. If anyone is looking at this in the future, this is the old incorrect code:

function Base.similar(bc::Broadcast.Broadcasted{MyComponentVectorStyle}, ::Type{ElType}) where ElType
    # Scan the inputs for the MyComponentVector:
    mcv = find_mcv(bc)
    # Use the MyComponentVector to create output
    similar(mcv)  # ElType is ignored, so the output keeps the element type of mcv
end

and this is what it should look like:

function Base.similar(bc::Broadcast.Broadcasted{MyComponentVectorStyle}, ::Type{ElType}) where ElType
    # Scan the inputs for the MyComponentVector:
    mcv = find_mcv(bc)
    # Use the MyComponentVector to create output
    similar(mcv, ElType)
end

It seems like it’s all working, so now I can move on to test the compile-time/run-time trade-off for my actual problem.
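
For future readers, here is a self-contained sketch of the corrected pattern on a stripped-down wrapper. `Wrapped`, `WrappedStyle` and `find_wrapped` are hypothetical names standing in for the definitions in the MCV module above, and `BigFloat` is used in place of a Dual number so the example needs only Base.

```julia
# Hypothetical stand-in for MyComponentVector: a thin wrapper around a Vector.
struct Wrapped{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.IndexStyle(::Type{<:Wrapped}) = IndexLinear()
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Integer) = w.data[i]
Base.setindex!(w::Wrapped, v, i::Integer) = setindex!(w.data, v, i)
Base.similar(w::Wrapped, ::Type{T}) where {T} = Wrapped(similar(w.data, T))

# Broadcast style, following the same pattern as MyComponentVectorStyle.
struct WrappedStyle <: Broadcast.AbstractArrayStyle{1} end
WrappedStyle(::Val{0}) = WrappedStyle()
WrappedStyle(::Val{1}) = WrappedStyle()
WrappedStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
Base.BroadcastStyle(::Type{<:Wrapped}) = WrappedStyle()

# Locate the first Wrapped among the (possibly nested) broadcast arguments.
find_wrapped(bc::Broadcast.Broadcasted) = find_wrapped(bc.args)
find_wrapped(args::Tuple) = find_wrapped(find_wrapped(args[1]), Base.tail(args))
find_wrapped(x) = x
find_wrapped(::Tuple{}) = nothing
find_wrapped(w::Wrapped, rest) = w
find_wrapped(::Any, rest) = find_wrapped(rest)

# Pass ElType on to similar so the output storage matches the broadcast result.
function Base.similar(bc::Broadcast.Broadcasted{WrappedStyle}, ::Type{ElType}) where {ElType}
    similar(find_wrapped(bc), ElType)
end

w = Wrapped([1.0, 2.0, 3.0])
promoted = w .* big"2.0"    # element type is promoted to BigFloat
unchanged = w .+ 1.0        # element type stays Float64
```

With `ElType` dropped from the last method, the output would be allocated with the element type of `w`, which is the situation that produced the `Float64(::Dual)` MethodError above.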