JuliaDiffEq with custom types

Alright, improving… The `Mtype` was not `isbits`, which obviously hurt performance a lot. Now it’s only 1.5x slower than the `SVector` version (when skipping `get_name_p`). There is an ugly hack with the `_copy` methods. If anyone can tell me how to improve on that, that would be awesome.

I’ll try to put a PR for the docs together.

using DiffEqBase, OrdinaryDiffEq, StaticArrays
using Base: @propagate_inbounds
using Base.Broadcast: Broadcasted
using Setfield
using Test, BenchmarkTools

"""
    Mtype{L,E} <: AbstractVector{E}

Immutable vector of `L` elements of type `E`, stored in an `SVector`, that additionally
carries an `Int` field `name` and a `Float64` field `p`. Only `data` holds the array
elements; `name` and `p` are extra tags.
"""
struct Mtype{L,E} <: AbstractVector{E}
    data::SVector{L,E}  # the vector elements
    name::Int           # integer tag
    p::Float64          # scalar value travelling with the data
end
"""
    Mtype(data::AbstractVector, name::Int, p)

Construct an `Mtype` whose length and element type are taken from `data`. The
elements of `data` are copied into an `SVector`.
"""
function Mtype(data::AbstractVector{T}, name::Int, p) where {T}
    len = length(data)
    return Mtype{len,T}(SVector{len}(data), name, p)
end
# `Mtype` is an immutable struct and its `data` is an immutable `SVector`, so the value
# is already independent of any later "modification" (there is none possible) and can
# be returned as its own copy. `deepcopy` would walk the object graph and allocate an
# `IdDict` on every call for no benefit. Like `copy` for Base containers, this is a
# shallow copy: mutable element objects (if `T` were mutable) are shared.
Base.copy(m::Mtype) = m
# All-zero `Mtype` of the same length and element type that keeps `m`'s `name` and `p`.
function Base.zero(m::Mtype{N,T}) where {N,T}
    zero_data = zero(SVector{N,T})
    return Mtype{N,T}(zero_data, m.name, m.p)
end
# NOTE: for <:AbstractVector only the array parts (the elements) are hashed, so two
# Mtypes that differ only in `name`/`p` get the same hash; hence `@test_broken` below.
tmp1   = Mtype(1:2, -1, 2.0)
tmp2   = Mtype(1:2, -1, 3.0)
@test_broken hash(tmp1)!=hash(tmp2)
# Mtype must be isbits (immutable, no references); a non-isbits Mtype was slow.
@test isbits(tmp1)

# AbstractVector interface
# Length is the type parameter `N`; elements are read straight from the SVector.
Base.size(var::Mtype{N}) where {N} = (N,)
@propagate_inbounds Base.getindex(var::Mtype, i::Int) = var.data[i]
# The AbstractArray interface defines `IndexStyle` on the *type*. Base's fallback
# `IndexStyle(A::AbstractArray) = IndexStyle(typeof(A))` makes calls on instances
# keep working, while code that queries the type (e.g. wrappers such as views or
# reshapes) now also sees `IndexLinear`.
Base.IndexStyle(::Type{<:Mtype}) = IndexLinear()

# broadcast interface
"""
    MtypeStyle{N,T}

Broadcast style used for `Mtype{N,T}` arguments. It is a one-dimensional
`AbstractArrayStyle`, so it wins over scalars and plain vectors.
"""
struct MtypeStyle{N,T} <: Broadcast.AbstractArrayStyle{1} end
# Subtypes of AbstractArrayStyle must be constructible from `Val(dim)` so that Base can
# combine dimensionalities. 0-d (scalar) and 1-d partners keep the Mtype style.
MtypeStyle{N,T}(::Val{0}) where {N,T} = MtypeStyle{N,T}()
MtypeStyle{N,T}(::Val{1}) where {N,T} = MtypeStyle{N,T}()
# A result with two or more dimensions cannot be an Mtype; defer to Base's default
# array style instead of failing with a MethodError (pattern recommended by the Julia
# broadcasting-interface documentation).
MtypeStyle{N,T}(::Val{M}) where {N,T,M} = Broadcast.DefaultArrayStyle{M}()

# Every `Mtype{N,T}` participates in broadcasting with the matching `MtypeStyle{N,T}`.
function Base.BroadcastStyle(::Type{<:Mtype{N,T}}) where {N,T}
    return MtypeStyle{N,T}()
end
# Materialize an `MtypeStyle` broadcast: hand the broadcast arguments to `_copy` so
# that dispatch on their types can select a specialized implementation.
Base.copy(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T} = _copy(bc, bc.args...)
# `_copy(bc, args...)` materializes an `MtypeStyle` broadcast; dispatch on the
# broadcast arguments selects the implementation:
#   * two `Mtype`s                -> broadcast over their `data` SVectors directly;
#   * one `Mtype` and a plain arg -> broadcast over `data` and that argument;
#   * anything else, including nested `Broadcasted` trees -> `_copy_generic`.
# The two-argument fast paths were measured to be roughly 2x faster. Nested
# `Broadcasted` arguments are excluded from them on purpose: a nested tree may contain
# further `Mtype`s whose `name`/`p` must be checked, which the fast paths cannot do.
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T},
                       a2::Mtype{N,T}) where {N,T}
    a1.name == a2.name ||
        error("Broadcasting of Mtypes with different name, p not supported.")
    a1.p == a2.p ||
        error("Broadcasting of Mtypes with different name, p not supported.")
    return Mtype{N,T}(broadcast(bc.f, a1.data, a2.data), a1.name, a1.p)
end
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T}, a2) where {N,T} =
    Mtype{N,T}(broadcast(bc.f, a1.data, a2), a1.name, a1.p)
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1, a2::Mtype{N,T}) where {N,T} =
    Mtype{N,T}(broadcast(bc.f, a1, a2.data), a2.name, a2.p)
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, ::Mtype{N,T}, ::Broadcasted) where {N,T} =
    _copy_generic(bc)
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, ::Broadcasted, ::Mtype{N,T}) where {N,T} =
    _copy_generic(bc)
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, args...) where {N,T} = _copy_generic(bc)

# Generic path: evaluate the (possibly nested) broadcast element by element and attach
# the `name`/`p` shared by all `Mtype`s found anywhere in it.
@inline function _copy_generic(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T}
    name, p = get_name_p(bc)
    # `ntuple` with `Val(N)` builds the elements without a mutable buffer or the
    # Setfield trick; `SVector{N,T}` converts each element to `T`, matching the
    # conversion the fast paths get from the `Mtype{N,T}` constructor.
    data = SVector{N,T}(ntuple(i -> bc[i], Val(N)))
    return Mtype{N,T}(data, name, p)
end


"""
    get_name_p(bc::Broadcasted) -> (name::Int, p::Float64)

Retrieve the `name` and `p` fields shared by all `Mtype`s in the (possibly nested)
broadcast tree `bc`. Arguments that are neither an `Mtype` nor a `Broadcasted` are
ignored.

Throws an `ErrorException` if two `Mtype`s disagree on `name` or `p`, and an
`ArgumentError` if `bc` contains no `Mtype` at all.

The traversal recurses over the argument tuples so that each step dispatches on a
concrete argument type and the whole walk is type-stable; nested `Broadcasted` nodes
without any `Mtype` are allowed.
"""
function get_name_p(bc::Broadcasted)
    found = _name_p(bc)
    found === nothing && throw(ArgumentError("no Mtype found in Broadcasted"))
    return found
end

# `(name, p)` carried by one broadcast argument, or `nothing` if it carries none.
_name_p(a::Mtype) = (a.name, a.p)
_name_p(bc::Broadcasted) = _name_p_args(bc.args)
_name_p(_) = nothing

# Fold `_name_p` over an argument tuple, checking that all pairs found agree.
_name_p_args(::Tuple{}) = nothing
_name_p_args(args::Tuple) =
    _merge_name_p(_name_p(first(args)), _name_p_args(Base.tail(args)))

# Combine two optional `(name, p)` pairs; error if both are present but differ.
_merge_name_p(::Nothing, ::Nothing) = nothing
_merge_name_p(a::Tuple, ::Nothing) = a
_merge_name_p(::Nothing, b::Tuple) = b
function _merge_name_p(a::Tuple, b::Tuple)
    a[1] == b[1] ||
        error("Broadcasting of Mtypes with different name, p not supported.")
    a[2] == b[2] ||
        error("Broadcasting of Mtypes with different name, p not supported.")
    return a
end

# Test it: solve du/dt = u * p with an Mtype state (p == 2.0 for `foo`) and compare
# against the same problem solved with SVector and Vector states and rate 2.
rhs_test(f::Mtype, p, t) = f*f.p

f0    = rand(10)
sf0   = SVector{length(f0)}(f0)
foo   = Mtype(sf0, -1, 2.0)

# Both ends Float64, so the time span is a homogeneous tuple.
tspan = (0.0, 100.0)

# `$` interpolates the (non-const) globals into @btime so their dynamic lookup is not
# part of the measurement. Timings in the comments are the originally reported ones.
prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())
@btime solve($prob, RK4()) #  568.960 μs (6553 allocations: 1.32 MiB)

# test against straight SVector version
rhs_test2(f, p, t) = f*2
prob2  = ODEProblem(rhs_test2, sf0, tspan)
sol2   = solve(prob2, RK4())
@btime solve($prob2, RK4()) # 366.731 μs (3615 allocations: 426.84 KiB)

# against Vector version (prob2/sol2 are reassigned here)
prob2  = ODEProblem(rhs_test2, f0, tspan)
sol2   = solve(prob2, RK4())
@btime solve($prob2, RK4()) # 2.246 ms (52423 allocations: 7.91 MiB)


# Compare the Mtype solution with the Vector solution (the latest `sol2`).
@test length(sol)==length(sol2)
for i in 1:length(sol)
    @test sol[i]==sol2[i]
end