Spline fit from DataInterpolations erroring out with ModelingToolkit symbolic derivative

I’m trying to take a ModelingToolkit symbolic derivative of a variable whose values come from a DataInterpolations spline. The generated function returns an unevaluated derivative term (Differential(z)(0.0)) instead of a number, so I can’t solve the system. MWE:


using ModelingToolkit
using OrdinaryDiffEq
using DataInterpolations
using DataDrivenDiffEq

itp_method = InterpolationMethod(CubicSpline)
xr = 1.0:5.0
@register spl(z)
# originally had just spl = CubicSpline(xr,xr), with the same results
spl(z) = itp_method(xr, xr)(z)

@variables z
@variables v_terminal(z)
Dz = Differential(z)
@variables coll_rate(z)

@named system = ODESystem(
    [
        v_terminal ~ spl(z),
        Dz(coll_rate) ~ Dz(v_terminal)
        # this example is contrived, in the real system I do need Dz(v_terminal)
        # and anything simpler than this gets transformed away by structural_simplify
    ],
    z, [v_terminal, coll_rate], [],
)
system = structural_simplify(system)
prob = ODEProblem(system, [coll_rate => 1.0], (0.0, 1.0))
prob.f(prob.u0, prob.p, 0.0) 
# gives 1-element Vector{Term{Float64, Nothing}}: Differential(z)(0.0)
solve(prob, Tsit5()) # errors out because it can't convert Term{Float64,Nothing} to Float64

I attempted to use the wrapper from DataDrivenDiffEq.jl but it doesn’t seem to have done anything. I tried to fix this example by adding in a variable like dv_dz that I explicitly describe with derivative(spl, z), but that throws MethodError: isless(::Float64, ::Num) is ambiguous. In case that does work, I still have a more complex function for v_terminal whose derivative I need, and I don’t want to have to take its derivative by hand to plug it into the symbolic derivative terms. Is there a better way to handle this situation?

This looks tricky. I thought I had a solution, but it fails in a place I don’t know how to fix, with ERROR: Differentiation with array expressions is not yet supported.

Not defining a new function and instead using the CubicSpline struct directly gets a little farther. This allows a symbolic derivative to be defined for the spline, and the derivative function to be registered as opaque as well. Unfortunately, <:AbstractArray means something different to Symbolics than it does to DataInterpolations, so I’m not sure where to go next.

using ModelingToolkit
using OrdinaryDiffEq
using DataInterpolations
using Symbolics

xr = 1.0:5.0
spl = CubicSpline(xr,xr)

# Define a symbolic derivative of the cubic spline
Symbolics.derivative(s::typeof(spl), args::NTuple{1, Any}, ::Val{1}) = DataInterpolations.derivative(s, args[1])

# register the derivative function so Symbolics does not attempt to trace it.
@register_symbolic DataInterpolations.derivative(itp::DataInterpolations.AbstractInterpolation, val)

@variables z
@variables v_terminal(z)
Dz = Differential(z)
@variables coll_rate(z)

# Now differentiating works...
expand_derivatives(Dz(spl(z)))
# DataInterpolations.derivative([1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0], z)
substitute(expand_derivatives(Dz(spl(z))), z=>1.0)
# 1.0

@named system = ODESystem(
    [
        v_terminal ~ spl(z),
        Dz(coll_rate) ~ Dz(v_terminal)
        # this example is contrived, in the real system I do need Dz(v_terminal)
        # and anything simpler than this gets transformed away by structural_simplify
    ],
    z, [v_terminal, coll_rate], [],
)

# But now errors here because CubicSpline <: AbstractArray
system = structural_simplify(system)

prob = ODEProblem(system, [coll_rate => 1.0], (0.0, 1.0))

prob.f(prob.u0, prob.p, 0.0)

solve(prob, Tsit5())

The new error:

julia> system = structural_simplify(system)
ERROR: Differentiation with array expressions is not yet supported
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] occursin_info
@ ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:59 [inlined]
[3] (::Symbolics.var"#210#212"{Term{Real, Nothing}, Term{Real, Nothing}})(a::CubicSpline{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}, Vector{Float64}, true, Float64})
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:92
[4] iterate
@ ./generator.jl:47 [inlined]
[5] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, Symbolics.var"#210#212"{Term{Real, Nothing}, Term{Real, Nothing}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base ./array.jl:807
[6] collect_similar(cont::Vector{Any}, itr::Base.Generator{Vector{Any}, Symbolics.var"#210#212"{Term{Real, Nothing}, Term{Real, Nothing}}})
@ Base ./array.jl:716
[7] map(f::Function, A::Vector{Any})
@ Base ./abstractarray.jl:2933
[8] occursin_info(x::Term{Real, Nothing}, expr::Term{Real, Nothing}, fail::Bool)
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:92
[9] (::Symbolics.var"#210#212"{Term{Real, Nothing}, SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}})(a::Term{Real, Nothing})
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:92
[10] iterate
@ ./generator.jl:47 [inlined]
[11] collect_to!(dest::Vector{Term{Real, Nothing}}, itr::Base.Generator{Vector{SymbolicUtils.Symbolic{Real}}, Symbolics.var"#210#212"{Term{Real, Nothing}, SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}}}, offs::Int64, st::Int64)
@ Base ./array.jl:845
[12] collect_to_with_first!(dest::Vector{Term{Real, Nothing}}, v1::Term{Real, Nothing}, itr::Base.Generator{Vector{SymbolicUtils.Symbolic{Real}}, Symbolics.var"#210#212"{Term{Real, Nothing}, SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}}}, st::Int64)
@ Base ./array.jl:823
[13] _collect(c::Vector{SymbolicUtils.Symbolic{Real}}, itr::Base.Generator{Vector{SymbolicUtils.Symbolic{Real}}, Symbolics.var"#210#212"{Term{Real, Nothing}, SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base ./array.jl:817
[14] collect_similar(cont::Vector{SymbolicUtils.Symbolic{Real}}, itr::Base.Generator{Vector{SymbolicUtils.Symbolic{Real}}, Symbolics.var"#210#212"{Term{Real, Nothing}, SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}}})
@ Base ./array.jl:716
[15] map(f::Function, A::Vector{SymbolicUtils.Symbolic{Real}})
@ Base ./abstractarray.jl:2933
[16] occursin_info(x::Term{Real, Nothing}, expr::SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}, fail::Bool)
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:92
[17] occursin_info(x::Term{Real, Nothing}, expr::SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing})
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:57
[18] expand_derivatives(O::Term{Real, Nothing}, simplify::Bool; occurances::Nothing)
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:169
[19] expand_derivatives(O::Term{Real, Nothing}, simplify::Bool)
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:163
[20] jacobian(ops::Vector{SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}}, vars::Vector{Any}; simplify::Bool)
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:443
[21] jacobian(ops::Vector{SymbolicUtils.Add{Real, Int64, Dict{Any, Number}, Nothing}}, vars::Vector{Any})
@ Symbolics ~/.julia/packages/Symbolics/UrqtQ/src/diff.jl:440
[22] (::ModelingToolkit.StructuralTransformations.var"#141#144"{TearingState{ODESystem}})(eqs::Vector{Int64}, vars::Vector{Int64})
@ ModelingToolkit.StructuralTransformations ~/.julia/packages/ModelingToolkit/jCQlF/src/structural_transformation/symbolics_tearing.jl:729
[23] dummy_derivative_graph!(structure::SystemStructure, var_eq_matching::ModelingToolkit.BipartiteGraphs.Matching{ModelingToolkit.BipartiteGraphs.Unassigned, Vector{Union{ModelingToolkit.BipartiteGraphs.Unassigned, Int64}}}, jac::ModelingToolkit.StructuralTransformations.var"#141#144"{TearingState{ODESystem}}, ::Tuple{ModelingToolkit.AliasGraph, Nothing}, state_priority::Function)
@ ModelingToolkit.StructuralTransformations ~/.julia/packages/ModelingToolkit/jCQlF/src/structural_transformation/partial_state_selection.jl:220
[24] dummy_derivative_graph!(state::TearingState{ODESystem}, jac::Function, ::Tuple{ModelingToolkit.AliasGraph, Nothing}; state_priority::Function, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ModelingToolkit.StructuralTransformations ~/.julia/packages/ModelingToolkit/jCQlF/src/structural_transformation/partial_state_selection.jl:159
[25] dummy_derivative(sys::ODESystem, state::TearingState{ODESystem}, ag::ModelingToolkit.AliasGraph; simplify::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ModelingToolkit.StructuralTransformations ~/.julia/packages/ModelingToolkit/jCQlF/src/structural_transformation/symbolics_tearing.jl:747
[26] _structural_simplify!(state::TearingState{ODESystem}, io::Nothing; simplify::Bool, check_consistency::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ModelingToolkit.SystemStructures ~/.julia/packages/ModelingToolkit/jCQlF/src/systems/systemstructure.jl:542
[27] structural_simplify!(state::TearingState{ODESystem}, io::Nothing; simplify::Bool, check_consistency::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ModelingToolkit.SystemStructures ~/.julia/packages/ModelingToolkit/jCQlF/src/systems/systemstructure.jl:496
[28] structural_simplify(sys::ODESystem, io::Nothing; simplify::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ModelingToolkit ~/.julia/packages/ModelingToolkit/jCQlF/src/systems/systems.jl:39
[29] structural_simplify (repeats 2 times)
@ ~/.julia/packages/ModelingToolkit/jCQlF/src/systems/systems.jl:19 [inlined]
[30] top-level scope
@ REPL[51]:1

Maybe wrapping the interpolation in a new struct that isn’t <: AbstractArray would help? It wouldn’t be very elegant, but if it works I don’t mind. I can try this out in a bit.

Unfortunately, that doesn’t work yet either. There is a PR, though.

For now I’ve been able to work around this by explicitly constructing the derivative as its own spline:

using ModelingToolkit
using OrdinaryDiffEq
using DataInterpolations
using Symbolics

xr = 1.0:5.0
spl = CubicSpline(xr,xr)
spl_d = QuadraticSpline(map(x -> DataInterpolations.derivative(spl, x), xr), xr)

@variables z
@variables v_terminal(z)
Dz = Differential(z)
@variables coll_rate(z)

@named system = ODESystem(
    [
        Dz(coll_rate) ~ spl_d(z)
    ],
    z, [coll_rate], [],
)

system = structural_simplify(system)
prob = ODEProblem(system, [coll_rate => 1.0], (0.0, 1.0))
prob.f(prob.u0, prob.p, 0.0)
solve(prob, Tsit5()) # runs as expected, gives coll_rate(z) = z + 1
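
One caveat worth checking with this workaround: a quadratic spline through the knot derivatives reproduces the cubic spline's derivative at the knots, but it may differ from it between knots. For the linear data in the MWE both are simply 1 everywhere, so the difference does not show up there. The sketch below uses made-up sample values (an assumption, not data from the thread) and only the DataInterpolations calls already used above to measure the gap:

```julia
using DataInterpolations

# Illustrative data (not from the thread): samples of x^2 on the same grid.
t = collect(1.0:5.0)
u = t .^ 2

spl = CubicSpline(u, t)

# Workaround from above: interpolate the knot derivatives with a second spline.
du = [DataInterpolations.derivative(spl, ti) for ti in t]
spl_d = QuadraticSpline(du, t)

# At the knots the two agree, because spl_d interpolates those exact values.
knot_gap = maximum(abs(spl_d(ti) - DataInterpolations.derivative(spl, ti)) for ti in t)

# Between the knots they can differ; measure the gap on a fine grid.
zs = range(1.0, 5.0; length = 401)
grid_gap = maximum(abs(spl_d(z) - DataInterpolations.derivative(spl, z)) for z in zs)

println("max gap at knots:     ", knot_gap)
println("max gap on fine grid: ", grid_gap)
```

If the gap on the fine grid is too large for the application, sampling the derivative on a denser grid before building spl_d should reduce it.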

This makes it difficult to describe complicated functions of splines like my main use case, but it should at least be possible by taking a lot of derivatives by hand.
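
As a rough illustration of what "by hand" means here, the chain rule for a function of the spline can be written out and checked numerically before it goes into the system. This is only a sketch: the outer function g and its derivative dg are hypothetical stand-ins for the real v_terminal expression, the sample values are made up, and the check uses DataInterpolations.derivative directly rather than anything symbolic.

```julia
using DataInterpolations

# Same grid as in the thread; the sample values are made up for illustration.
t = collect(1.0:5.0)
u = [1.0, 4.0, 9.0, 16.0, 25.0]
spl = CubicSpline(u, t)

# Hypothetical outer function and its hand-written derivative.
g(v) = sqrt(v) + v^2
dg(v) = 1 / (2 * sqrt(v)) + 2 * v

# Composite quantity and its derivative via the chain rule:
# d/dz g(spl(z)) = g'(spl(z)) * spl'(z)
w(z) = g(spl(z))
dw(z) = dg(spl(z)) * DataInterpolations.derivative(spl, z)

# Sanity check against a central finite difference at an interior point.
z0 = 2.3
h = 1e-6
fd = (w(z0 + h) - w(z0 - h)) / (2h)
println("chain rule: ", dw(z0), "   finite difference: ", fd)
```

In the ODESystem, the same product would then appear on the right-hand side with spl_d(z) in place of the direct derivative call, following the workaround above.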