DifferentialEquations Package Kills Performance Everywhere when TruncatedStacktraces is used

I think I figured out how to improve the printing of ODEFunction in stacktraces. The fact that the number of non-nothing fields in ODEFunction is highly variable indicates that what you really need is a collection, not a struct. So I’ve created an example type called MyFunction that wraps a named tuple. The named tuple contains only the options that are actually passed to the MyFunction constructor. For comparison, I also define a SciMLFunction type that is essentially the same as ODEFunction.

Here is the implementation for SciMLFunction and MyFunction:

```julia
struct SciMLFunction{iip, specialize, F, TMM, Ta, Tt,
                     TJ, JVP, VJP, JP, SP, TW, TWt,
                     TPJ, S, S2, S3, O, TCV, SYS}
    f::F
    mass_matrix::TMM
    analytic::Ta
    tgrad::Tt
    jac::TJ
    jvp::JVP
    vjp::VJP
    jac_prototype::JP
    sparsity::SP
    Wfact::TW
    Wfact_t::TWt
    paramjac::TPJ
    syms::S
    indepsym::S2
    paramsyms::S3
    observed::O
    colorvec::TCV
    sys::SYS
end

function SciMLFunction(f; mass_matrix=nothing, analytic=nothing, tgrad=nothing,
                       jac=nothing, jvp=nothing, vjp=nothing, jac_prototype=nothing,
                       sparsity=nothing, Wfact=nothing, Wfact_t=nothing,
                       paramjac=nothing, syms=nothing, indepsym=nothing,
                       paramsyms=nothing, observed=nothing, colorvec=nothing,
                       sys=nothing)
    # 1 and 2 are arbitrary placeholder values for the iip and specialize parameters.
    SciMLFunction{1, 2, typeof(f), typeof(mass_matrix), typeof(analytic),
                  typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
                  typeof(jac_prototype), typeof(sparsity), typeof(Wfact),
                  typeof(Wfact_t), typeof(paramjac), typeof(syms),
                  typeof(indepsym), typeof(paramsyms), typeof(observed),
                  typeof(colorvec), typeof(sys)}(
        f, mass_matrix, analytic, tgrad, jac, jvp, vjp, jac_prototype,
        sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms,
        observed, colorvec, sys)
end

function foo(funcs::SciMLFunction, x)
    # Fall back to identity for any function that was not supplied.
    g = isnothing(funcs.jac) ? identity : funcs.jac
    h = isnothing(funcs.jvp) ? identity : funcs.jvp
    k = isnothing(funcs.vjp) ? identity : funcs.vjp
    funcs.f(g(h(k(x))))
end

struct MyFunction{F,N,T}
    f::F
    more_funcs::NamedTuple{N,T}
end

function MyFunction(f; kwargs...)
    # Extract the named tuple from the Base.Pairs to reduce
    # the type complexity.
    MyFunction(f, values(kwargs))
end

function foo(funcs::MyFunction, x)
    # Options that were not passed are simply absent from the NamedTuple,
    # so `get` falls back to identity for them.
    g = get(funcs.more_funcs, :jac, identity)
    h = get(funcs.more_funcs, :jvp, identity)
    k = get(funcs.more_funcs, :vjp, identity)
    funcs.f(g(h(k(x))))
end
```
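As a quick check of how MyFunction behaves when an option is supplied, here is a small self-contained sketch (the `jac = x -> 2x` choice is an arbitrary example, not from the post). Only `jac` ends up in the wrapped NamedTuple, so only `jac` appears in the type, and `foo` falls back to `identity` for `jvp` and `vjp`:

```julia
# Self-contained copy of MyFunction and foo from above.
struct MyFunction{F,N,T}
    f::F
    more_funcs::NamedTuple{N,T}
end

MyFunction(f; kwargs...) = MyFunction(f, values(kwargs))

function foo(funcs::MyFunction, x)
    g = get(funcs.more_funcs, :jac, identity)
    h = get(funcs.more_funcs, :jvp, identity)
    k = get(funcs.more_funcs, :vjp, identity)
    funcs.f(g(h(k(x))))
end

# Only `jac` is supplied, so the NamedTuple (and therefore the type) records only `jac`.
mf = MyFunction(x -> x + 1; jac = x -> 2x)

keys(mf.more_funcs)   # (:jac,)
foo(mf, 3)            # vjp and jvp fall back to identity, jac(3) = 6, then f(6) = 7
```

With no keyword arguments, `values(kwargs)` is the empty NamedTuple `NamedTuple()`, so the type carries just `()` and `Tuple{}` for `N` and `T`.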

Now let’s define an instance of each of the two types where none of the optional arguments are passed in:

```julia
sciml_funcs = SciMLFunction(x -> x + 1)
my_funcs = MyFunction(x -> x + 1)
```
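The post does not say which call produced the stacktraces in the screenshots. One hypothetical way to get such a stacktrace is to pass `foo` an argument that `x + 1` cannot handle, such as a string. The resulting MethodError's stacktrace should then include a `foo(funcs::MyFunction{...}, x::String)` frame. A self-contained sketch for the MyFunction case (the same call can be made with `sciml_funcs` once SciMLFunction is defined):

```julia
struct MyFunction{F,N,T}
    f::F
    more_funcs::NamedTuple{N,T}
end

MyFunction(f; kwargs...) = MyFunction(f, values(kwargs))

function foo(funcs::MyFunction, x)
    g = get(funcs.more_funcs, :jac, identity)
    h = get(funcs.more_funcs, :jvp, identity)
    k = get(funcs.more_funcs, :vjp, identity)
    funcs.f(g(h(k(x))))
end

my_funcs = MyFunction(x -> x + 1)

# Hypothetical trigger: there is no method for `"a" + 1`, so this throws a MethodError.
# Printing it with the backtrace shows the frame for foo, including the full printed type.
try
    foo(my_funcs, "a")
catch e
    showerror(stdout, e, catch_backtrace())
    println()
end
```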

Here is a screenshot of what the SciMLFunction stacktrace looks like:

Here is a screenshot of what the MyFunction stacktrace looks like:

If you compare frame 2 of the two stacktraces, you can see that the printed MyFunction type is much shorter than the printed SciMLFunction type. I think that’s a considerable improvement.
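A rough, text-only way to quantify the difference is to count type parameters and compare the length of each printed type name. The sketch below assumes plain Base type printing (no TruncatedStacktraces) and uses a compact copy of SciMLFunction. It builds the all-`nothing` instance directly instead of going through the keyword constructor, which should be equivalent to calling `SciMLFunction(f)` with no options.

```julia
# Compact copies of the two types from the post (same fields and type parameters).
struct SciMLFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP,
                     TW, TWt, TPJ, S, S2, S3, O, TCV, SYS}
    f::F; mass_matrix::TMM; analytic::Ta; tgrad::Tt; jac::TJ; jvp::JVP
    vjp::VJP; jac_prototype::JP; sparsity::SP; Wfact::TW; Wfact_t::TWt
    paramjac::TPJ; syms::S; indepsym::S2; paramsyms::S3; observed::O
    colorvec::TCV; sys::SYS
end

struct MyFunction{F,N,T}
    f::F
    more_funcs::NamedTuple{N,T}
end

MyFunction(f; kwargs...) = MyFunction(f, values(kwargs))

f = x -> x + 1

# All 17 optional fields are `nothing`; 1 and 2 are the post's placeholder iip/specialize values.
sciml_funcs = SciMLFunction{1, 2, typeof(f), ntuple(_ -> Nothing, 17)...}(f, ntuple(_ -> nothing, 17)...)
my_funcs = MyFunction(f)

# Number of type parameters that get printed in a stacktrace frame.
n_sciml = length(typeof(sciml_funcs).parameters)   # 20
n_my    = length(typeof(my_funcs).parameters)      # 3

# Length of each printed type name, a rough proxy for stacktrace noise.
println(length(string(typeof(sciml_funcs))), " vs ", length(string(typeof(my_funcs))))
```

The SciMLFunction count stays at 20 no matter how many options are set, while MyFunction always has three parameters; only the contents of `N` and `T` grow with the options actually passed.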

The same technique could be applied to other types that contain many optional fields, such as DEOptions.
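As an illustration only, the same pattern applied to a generic options container might look like the sketch below. `MyOptions`, `getopt`, and the option names `abstol` and `reltol` are hypothetical and are not DEOptions' actual fields or API. Defaults are supplied at lookup time, so options that were never set do not appear in the type at all.

```julia
# Hypothetical options container that stores only the options the caller set.
struct MyOptions{N,T}
    opts::NamedTuple{N,T}
end

MyOptions(; kwargs...) = MyOptions(values(kwargs))

# Look up an option, falling back to a caller-supplied default when it was not set.
getopt(o::MyOptions, name::Symbol, default) = get(o.opts, name, default)

opts = MyOptions(; abstol = 1e-8)

getopt(opts, :abstol, 1e-6)   # 1.0e-8 (explicitly set)
getopt(opts, :reltol, 1e-3)   # 0.001 (not set, so the default is returned)
```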
