Alright, improving… The Mtype was not isbits, which obviously hurt performance a lot. Now it’s only 1.5x slower than the SVector version (when skipping get_name_p). There is an ugly hack with the _copy methods. If anyone can tell me how to improve on that, that would be awesome.
I’ll try to put a PR for the docs together.
using DiffEqBase, OrdinaryDiffEq, StaticArrays
using Base: @propagate_inbounds
using Base.Broadcast: Broadcasted
using Setfield
using Test, BenchmarkTools
"""
    Mtype{N,T} <: AbstractVector{T}

Immutable vector of `N` elements of type `T`, stored in an `SVector{N,T}`, that
additionally carries two metadata fields: an integer `name` and a `Float64` `p`.
"""
struct Mtype{N,T} <: AbstractVector{T}
    data::SVector{N,T}
    name::Int
    p::Float64
end

"""
    Mtype(data::AbstractVector, name::Int, p)

Build an `Mtype` from any `AbstractVector`; the length and element-type parameters
are taken from `data`.
"""
function Mtype(data::AbstractVector{T}, name::Int, p) where {T}
    n = length(data)
    return Mtype{n,T}(SVector{n}(data), name, p)
end
# `Mtype` is an immutable struct wrapping an immutable `SVector`, so there is no
# mutable state a copy would need to protect; returning `m` itself satisfies `copy`
# without the recursive traversal that `deepcopy` performs.
Base.copy(m::Mtype) = m
# Zero element for an Mtype: all-zero data, with `name` and `p` copied from `m`.
function Base.zero(m::Mtype{N,T}) where {N,T}
    zero_data = zero(SVector{N,T})
    return Mtype{N,T}(zero_data, m.name, m.p)
end
# NOTE: for subtypes of AbstractVector only the array elements are hashed, so two
# Mtypes with equal data but different metadata (here: different `p`) get the same
# hash. The @test_broken below records that known limitation.
tmp1 = Mtype(1:2, -1, 2.0)
tmp2 = Mtype(1:2, -1, 3.0)
@test_broken hash(tmp1)!=hash(tmp2)
# Guard against regressing to a non-isbits layout.
@test isbits(tmp1)
# AbstractVector interface
# The length is the type parameter `N`, so `size` needs no field access.
Base.size(var::Mtype{N}) where {N} = (N,)
# Scalar indexing is forwarded to the wrapped SVector; `@propagate_inbounds` lets a
# caller's `@inbounds` reach the inner `data[i]` access.
@propagate_inbounds Base.getindex(var::Mtype, i::Int) = var.data[i]
# `IndexStyle` is a trait of the array *type*: define it on `Type{<:Mtype}` so that
# both `IndexStyle(typeof(m))` and `IndexStyle(m)` (which Base forwards to the type
# method) report linear indexing.
Base.IndexStyle(::Type{<:Mtype}) = IndexLinear()
# broadcast interface
"""
    MtypeStyle{N,T}

Broadcast style selected for `Mtype{N,T}` arguments. It is one-dimensional and
records the length `N` and element type `T` so that `copy(::Broadcasted)` can
rebuild an `Mtype{N,T}` from the broadcast result.
"""
struct MtypeStyle{N,T} <: Broadcast.AbstractArrayStyle{1} end
# Whenever you subtype AbstractArrayStyle, you also need to define rules for combining
# dimensionalities, by creating a constructor for your style that takes a Val(N) argument.
# Zero- and one-dimensional results stay in MtypeStyle ...
MtypeStyle{N,T}(::Val{0}) where {N,T} = MtypeStyle{N,T}()
MtypeStyle{N,T}(::Val{1}) where {N,T} = MtypeStyle{N,T}()
# ... while any other dimensionality (e.g. an Mtype broadcast against a matrix) cannot be
# represented by an Mtype, so hand over to the default array style instead of
# failing with a MethodError.
MtypeStyle{N,T}(::Val{D}) where {N,T,D} = Broadcast.DefaultArrayStyle{D}()
Base.BroadcastStyle(::Type{<:Mtype{N,T}}) where {N,T} = MtypeStyle{N,T}()
# Materialize a broadcast whose style is MtypeStyle: the argument tuple is splatted
# so that dispatch on the individual argument types picks one of the `_copy` methods.
Base.copy(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T} = _copy(bc, bc.args...)
# Hand-written `_copy` methods for the two-argument cases. The author measured them
# to be about 2x faster than going through the generic method, at the price of
# some repetition.
#
# Two Mtypes: both must agree in `name` and `p`; the operation is then applied to
# the wrapped SVectors directly.
@inline function _copy(
    bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T}, a2::Mtype{N,T}
) where {N,T}
    same_meta = a1.name == a2.name && a1.p == a2.p
    same_meta || error("Broadcasting of Mtypes with different name, p not supported.")
    data = broadcast(bc.f, a1.data, a2.data)
    return Mtype{N,T}(data, a1.name, a1.p)
end
# Mtype as first argument, anything else as second: broadcast over the wrapped
# SVector and keep the metadata of `a1`.
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T}, a2) where {N,T}
    data = broadcast(bc.f, a1.data, a2)
    return Mtype{N,T}(data, a1.name, a1.p)
end
# Mirror image: Mtype as second argument; the metadata comes from `a2`.
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1, a2::Mtype{N,T}) where {N,T}
    data = broadcast(bc.f, a1, a2.data)
    return Mtype{N,T}(data, a2.name, a2.p)
end
# Generic fallback for every argument pattern that the specialised `_copy` methods do
# not cover (e.g. nested `Broadcasted` arguments or a different number of arguments).
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, args...) where {N,T}
    # Take `name` and `p` from the Mtype arguments of the broadcast tree: the result
    # must carry the metadata of its inputs rather than fixed placeholder values.
    name, p = get_name_p(bc)
    # An SVector{N} holds exactly N values; reject any other result length instead
    # of silently truncating.
    length(bc) == N || throw(DimensionMismatch(
        "broadcast over Mtype{$N} produced $(length(bc)) elements, expected $N"))
    # `bc[i]` evaluates the fused broadcast at index `i`. `ntuple` with `Val(N)`
    # builds the tuple with a statically known length, so neither a placeholder
    # SVector nor element-by-element rebuilding is needed.
    data = SVector{N,T}(ntuple(i -> bc[i], Val(N)))
    return Mtype{N,T}(data, name, p)
end
"""
    get_name_p(bc::Broadcasted) -> (name::Int, p::Float64)

Retrieve the fields `name` and `p` from the `Mtype` arguments of `bc`, descending
into nested `Broadcasted` arguments, and make sure they all match.

Throws an error if two `Mtype`s disagree in `name` or `p`, or if `bc` contains no
`Mtype` at all. Nested `Broadcasted` arguments that contain no `Mtype` (e.g. purely
scalar sub-expressions) are skipped.
"""
function get_name_p(bc::Broadcasted)
    found = _fold_name_p(nothing, bc.args)
    found === nothing &&
        error("No Mtype found in Broadcasted; cannot determine name and p.")
    return found
end
# Walk an argument tuple recursively, threading the `(name, p)` found so far
# (`nothing` until the first Mtype is seen). Recursing over the tuple instead of
# looping lets dispatch handle each argument type separately.
@inline _fold_name_p(found, ::Tuple{}) = found
@inline _fold_name_p(found, args::Tuple) =
    _fold_name_p(_merge_name_p(found, first(args)), Base.tail(args))
# Arguments that are neither Mtype nor Broadcasted carry no metadata.
@inline _merge_name_p(found, ::Any) = found
# Nested broadcast: continue the walk through its arguments.
@inline _merge_name_p(found, a::Broadcasted) = _fold_name_p(found, a.args)
# First Mtype encountered: it defines the reference metadata.
@inline _merge_name_p(::Nothing, a::Mtype) = (a.name, a.p)
# Every further Mtype must agree with the reference metadata.
@inline function _merge_name_p(found::Tuple{Int,Float64}, a::Mtype)
    found[1] == a.name || error("Broadcasting of Mtypes with different name, p not supported.")
    found[2] == a.p || error("Broadcasting of Mtypes with different name, p not supported.")
    return found
end
# Test it
# Right-hand side that scales the state by the `p` field carried by the Mtype itself
# (the ODE parameter argument `p` is not used).
rhs_test(f::Mtype, p, t) = f*f.p
f0 = rand(10)
sf0 = SVector{length(f0)}(f0)
foo = Mtype(sf0, -1, 2.0)
tspan = (0.0, 100)
prob = ODEProblem(rhs_test, foo, tspan)
sol = solve(prob, RK4())
@btime solve(prob, RK4()) # 568.960 μs (6553 allocations: 1.32 MiB)
# test against straight SVector version
# (the factor 2 equals `foo.p == 2.0`, so both problems have the same right-hand side)
rhs_test2(f, p, t) = f*2
prob2 = ODEProblem(rhs_test2, sf0, tspan)
sol2 = solve(prob2, RK4())
@btime solve(prob2, RK4()) # 366.731 μs (3615 allocations: 426.84 KiB)
# against Vector version
# NOTE: `prob2` and `sol2` are reassigned here, so the comparison below is between
# the Mtype solution `sol` and the Vector solution.
prob2 = ODEProblem(rhs_test2, f0, tspan)
sol2 = solve(prob2, RK4())
@btime solve(prob2, RK4()) # 2.246 ms (52423 allocations: 7.91 MiB)
# Both solutions must have the same number of saved steps and identical values at each step.
@test length(sol)==length(sol2)
for i = 1:length(sol)
@test sol[i]==sol2[i]
end