JuliaDiffEq with custom types

You might want to make this <: StaticArray. That might make it have some better dispatches. There are likely some other ways to speed this up too.

The problem with subtyping StaticVector is that I then need to define similar_type (https://juliaarrays.github.io/StaticArrays.jl/stable/pages/api.html#Implementing-your-own-types-1). But that assumes, as far as I can tell, that all the fields are part of the array, i.e. there cannot be any metadata fields such as name and p.

Otherwise that would be perfect (or similarly working with FieldVector).

Edit: x-ref https://github.com/JuliaArrays/StaticArrays.jl/issues/592

This recovers the speed of StaticArrays, using Setfield.jl

function Base.copy(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T}
    # traverse Broadcasted tree to find Mtype instances
    name, p = get_name_p(bc)
    # This trick works:
    # https://discourse.julialang.org/t/constructing-svector-with-a-loop/15372/7
    s = SVector{N,T}(1:N)
    for (i,v) in enumerate(bc)
        @set s[i] = v
    end
    Mtype{N,T}(s, name, p)
end

x-ref Constructing SVector with a loop - #7 by tkf

Would something like this be a good addition to the docs? Where would it go?

I think you need to use s = @set s[i] = v or its shortcut @set! s[i] = v.
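
To illustrate the difference, here is a minimal sketch that uses a plain tuple in place of the SVector (an assumption made only to keep the example small; Setfield.jl is assumed to be installed). `@set` builds a modified copy and leaves the original binding alone, while `@set!` rebinds the variable to that copy.

```julia
using Setfield  # assumed to be installed

t = (1, 2, 3)

# `@set` returns a modified copy; `t` itself is unchanged.
t2 = @set t[2] = 20

# `@set!` rebinds the variable to the modified copy.
u = (1, 2, 3)
@set! u[2] = 20
```

In the loop from the earlier post, a bare `@set s[i] = v` therefore discards its result on every iteration, which is why `s` never changes.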

It would be a good addition to the docs, probably an FAQ on how to define state types, or a discussion on the problem interface page.

Oops, you’re right. This makes timing worse, but only about a factor 20x slower instead of 200x (and still ~2x slower than vector).

Alright, improving… The Mtype was not isbits, which obviously hurt performance a lot. Now it’s only 1.5x slower than the SVector version (when skipping get_name_p). There is an ugly hack with the _copy methods. If anyone can tell me how to improve on that, that would be awesome.

I’ll try to put a PR for the docs together.

using DiffEqBase, OrdinaryDiffEq, StaticArrays
using Base: @propagate_inbounds
using Base.Broadcast: Broadcasted
using Setfield
using Test, BenchmarkTools

struct Mtype{N,T} <: AbstractVector{T}
    data::SVector{N,T}
    name::Int
    p::Float64
end
Mtype(data::AbstractVector{T}, name::Int, p) where {T} =
    Mtype{length(data), T}(SVector{length(data)}(data), name, p)
Base.copy(m::Mtype) = deepcopy(m)
Base.zero(m::Mtype{N,T}) where {N,T} = Mtype{N,T}(zero(SVector{N,T}), m.name, m.p)
# NOTE: only the array parts are hashed for <:AbstractVector
tmp1   = Mtype(1:2, -1, 2.0)
tmp2   = Mtype(1:2, -1, 3.0)
@test_broken hash(tmp1)!=hash(tmp2)
@test isbits(tmp1)

# AbstractVector interface
Base.size(var::Mtype{N}) where N = (N,)
@propagate_inbounds Base.getindex(var::Mtype, i::Int) = var.data[i]
Base.IndexStyle(::Mtype) = IndexLinear()

# broadcast interface
struct MtypeStyle{N,T} <: Broadcast.AbstractArrayStyle{1} end
# Whenever you subtype AbstractArrayStyle, you also need to define rules for combining
# dimensionalities, by creating a constructor for your style that takes a Val(N) argument.
MtypeStyle{N,T}(::Val{0}) where {N,T} = MtypeStyle{N,T}()
MtypeStyle{N,T}(::Val{1}) where {N,T} = MtypeStyle{N,T}()

Base.BroadcastStyle(::Type{<:Mtype{N,T}}) where {N,T} = MtypeStyle{N,T}()
function Base.copy(bc::Broadcasted{<:MtypeStyle{N,T}}) where {N,T}
    return _copy(bc, bc.args...)
end
# these separate _copy functions speed things up by about 2x.
# But they sure are ugly.
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T}, a2::Mtype{N,T}) where {N,T}
    f = bc.f
    a1.name==a2.name || error("Broadcasting of Mtypes with different name, p not supported.")
    a1.p==a2.p || error("Broadcasting of Mtypes with different name, p not supported.")
    Mtype{N,T}(broadcast(f, a1.data, a2.data), a1.name, a1.p)
end
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1::Mtype{N,T}, a2) where {N,T} =
    Mtype{N,T}(broadcast(bc.f, a1.data, a2), a1.name, a1.p)
@inline _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, a1, a2::Mtype{N,T}) where {N,T} =
    Mtype{N,T}(broadcast(bc.f, a1, a2.data), a2.name, a2.p)
@inline function _copy(bc::Broadcasted{<:MtypeStyle{N,T}}, args...) where {N,T}
    # traverse Broadcasted tree to find Mtype instances
    name, p = -1, 2.0 # get_name_p(bc) # this uses about 120μs and 0.4MiB

    # This is slow:
    # Mtype{N,T}(SVector(tuple((v for v in bc)...)), name, p)
    # thus use Setfield.jl trick:
    # https://discourse.julialang.org/t/constructing-svector-with-a-loop/15372/7
    s = SVector{N,T}(1:N)
    for (i,v) in enumerate(bc)
        @set! s[i] = v
    end
    Mtype{N,T}(s, name, p)
end


"""
Retrieve the fields name and p from the Broadcasted
instance and make sure they all match

TODO: a wee bit slow, this needs to be done better.
"""
function get_name_p(bc::Broadcasted)
    name::Union{Int, Nothing}=nothing
    p::Union{Float64, Nothing}=nothing
    for a in bc.args
        if a isa Mtype
            if name===nothing
                name = a.name
                p = a.p
            else
                name==a.name || error("Broadcasting of Mtypes with different name, p not supported.")
                p==a.p || error("Broadcasting of Mtypes with different name, p not supported.")
            end
        elseif a isa Broadcasted
            nn, pp = get_name_p(a)
            if name===nothing
                name = nn
                p = pp
            elseif nn!==nothing
                name==nn || error("Broadcasting of Mtypes with different name, p not supported.")
                p==pp || error("Broadcasting of Mtypes with different name, p not supported.")
            end
        end
    end
    return name::Int, p::Float64
end

# Test it
rhs_test(f::Mtype, p, t) = f*f.p

f0    = rand(10)
sf0   = SVector{length(f0)}(f0)
foo   = Mtype(sf0, -1, 2.0)

tspan = (0.0, 100)

prob  = ODEProblem(rhs_test, foo, tspan)
sol   = solve(prob, RK4())
@btime solve(prob, RK4()) #  568.960 μs (6553 allocations: 1.32 MiB)

# test against straight SVector version
rhs_test2(f, p, t) = f*2
prob2  = ODEProblem(rhs_test2, sf0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4()) # 366.731 μs (3615 allocations: 426.84 KiB)

# against Vector version
prob2  = ODEProblem(rhs_test2, f0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4()) # 2.246 ms (52423 allocations: 7.91 MiB)


@test length(sol)==length(sol2)
for i = 1:length(sol)
    @test sol[i]==sol2[i]
end

On this topic, there’s currently an issue that ArrayPartition is not compatible with some solvers, especially stiff ones, since linear algebra is not implemented for it. (It’s mentioned that PartitionedODEProblem uses them internally, so maybe it’s also affected? By the way, PartitionedODEProblem appears to be undefined; it looks like the code is absent.)

Is there a roadmap for what needs to be done if one wishes to help with it (overloading the functions for ArrayPartition, or copying it in a linearized way to a cache)? It can be very useful for a set of coupled ODEs that include matrices, vectors and numbers, e.g.

v’ = fv(m_i,v,s)
m_i’ = fm(m_j,v,s)
s’ = fs(m_i,v,s)
where v is a vector, m_{1…n} are matrices and s is a number.

Right now, I’m handling it by reshaping v, m and s to 1d and vcat-ing them together. Speed is okay for small systems, but for bigger systems it may require preallocating and passing the preallocated arrays via the p parameter (which may still generate some allocations).
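
A rough sketch of that flatten-and-unpack approach, using only Base. The size `n`, the helper name `unpack`, and the layout (vector first, then the matrix, then the scalar) are assumptions for illustration, with a single matrix standing in for m_{1…n}. Views are used so that unpacking does not copy the vector or matrix parts.

```julia
n = 3                      # hypothetical problem size
v = rand(n)                # vector part
m = rand(n, n)             # one matrix part
s = 0.5                    # scalar part

# Pack everything into one flat state vector.
u = vcat(v, vec(m), s)

# Unpack into views on the flat vector (the scalar is returned by value).
function unpack(u, n)
    vv = view(u, 1:n)
    mm = reshape(view(u, (n + 1):(n + n^2)), n, n)
    ss = u[end]
    return vv, mm, ss
end

vv, mm, ss = unpack(u, n)
```

Because `vv` and `mm` alias `u`, writing into them inside an in-place right-hand side updates the flat vector directly.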

By the way, I suppose s = SVector{N,T}(ntuple(i -> (@inbounds bc[i]), Val(N))) is equivalent to the @set! solution in terms of type stability?
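
As a small Base-only sketch of this pattern: the helper name `materialize_tuple` is hypothetical, and a plain Vector stands in for the Mtype arguments. Passing `Val(N)` makes the tuple length a compile-time constant, so the result has a concrete NTuple type that could then be handed to an SVector constructor.

```julia
using Base.Broadcast: broadcasted, instantiate

# Evaluate a lazy 1-d broadcast element by element into a tuple of static length N.
function materialize_tuple(bc, ::Val{N}) where {N}
    return ntuple(i -> (@inbounds bc[i]), Val(N))
end

x  = [1.0, 2.0, 3.0]
bc = instantiate(broadcasted(+, x, 1.0))   # lazy representation of x .+ 1.0
t  = materialize_tuple(bc, Val(3))
```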

For this case we would really want to do blocked linear algebra, but I am not sure of a good solution without something like Swizzles. We may just have to densify until we have more linear algebra tools.

Nice! I didn’t know about the Val variant of ntuple. This is the fastest so far (along with unrolled tuple construction). This gets it to 1.2x of StaticArray and to 1.4x without those special-case _copy methods.

Interesting! I thought it would be just equivalent to @set!.

Yep, ntuple seems a little faster. For larger input vectors the difference increases. I think it is mainly due to the fact that ntuple unrolls the loop. If I manually unroll the Setfield loop, it’s the same speed.

Sorry for resurrecting this old thread, but I noticed that my custom array type was performing rather badly (when compared with standard arrays) with DifferentialEquations. I tried to define the broadcasting operations following the docs, but there’s something I’m not doing right… Here’s what I have:

struct MyArr{A,N} <: AbstractArray{A,N}
    data  :: Array{A,N}
    name  :: String
end
MyArr(data::AbstractArray{T,N}, name::String) where {T,N} =
    MyArr{eltype(data),N}(data, name)

@inline Base.size(arr::MyArr) = size(arr.data)
@inline Base.length(arr::MyArr) = length(arr.data)

@inline Base.getindex(arr::MyArr, i::Int) = arr.data[i]
@inline Base.getindex(arr::MyArr, I::Vararg{Int,N}) where {N} = arr.data[I...]
@inline Base.getindex(arr::MyArr, ::Colon) = arr.data[:]
@inline Base.getindex(arr::MyArr, kr::AbstractRange) = arr.data[kr]

@inline Base.setindex!(arr::MyArr, v, i::Int) = (arr.data[i] = v)
@inline Base.setindex!(arr::MyArr, v, I::Vararg{Int,N}) where {N} = (arr.data[I...] = v)
@inline Base.setindex!(arr::MyArr, v, ::Colon) = (arr.data[:] .= v)
@inline Base.setindex!(arr::MyArr, v, kr::AbstractRange) = (arr.data[kr] .= v)

@inline Base.iterate(arr::MyArr) = iterate(arr.data)
@inline Base.iterate(arr::MyArr, state) = iterate(arr.data, state)

Base.show(io::IO, arr::MyArr) = show(io, arr.data)

Base.IndexStyle(arr::MyArr) = Base.IndexStyle(arr.data)

function Base.similar(arr::MyArr)
    data   = similar(arr.data)
    MyArr(data, arr.name)
end
function Base.similar(arr::MyArr, ::Type{T}) where {T}
    # allocate directly with the requested element type
    MyArr(similar(arr.data, T), arr.name)
end

Base.:*(x::Number, f::MyArr) = MyArr(x*f.data, f.name)
Base.:*(f::MyArr, x::Number) = MyArr(x*f.data, f.name)
Base.:/(f::MyArr, x::Number) = MyArr(f.data/x, f.name)
Base.:\(x::Number, f::MyArr) = MyArr(f.data/x, f.name)

# let's take the left name. they should match, but i'm not sure if checking
# for that is costly
Base.:+(f::MyArr, g::MyArr) = MyArr(f.data+g.data, f.name)
Base.:-(f::MyArr, g::MyArr) = MyArr(f.data-g.data, f.name)

Base.BroadcastStyle(::Type{<:MyArr}) = Broadcast.ArrayStyle{MyArr}()

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}},
    ::Type{T}) where T
    # Scan the inputs for MyArr:
    arr = find_arr(bc)
    # Use the name field of MyArr to create the output
    MyArr(similar(Array{T}, axes(bc)), arr.name)
end

find_arr(bc::Base.Broadcast.Broadcasted) = find_arr(bc.args)
find_arr(args::Tuple) = find_arr(find_arr(args[1]), Base.tail(args))
find_arr(x) = x
find_arr(a::MyArr, rest) = a
find_arr(::Any, rest) = find_arr(rest)

@inline function Base.copyto!(dest::Array,
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}})
    arr = find_arr(bc)
    copyto!(dest, arr.data)
end

@inline function Base.copyto!(dest::MyArr,
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}})
    arr = find_arr(bc)
    copyto!(dest.data, arr.data)
end

This seems to work fine when I run the following:

function mysum!(A, B, C)
    A .= B .+ C
end

xmin        =  -5.0
xmax        =   5.0
ymin        = -10.0
ymax        =  10.0
xnodes      =   64
ynodes      =  128

xx  = range(xmin, xmax, length=xnodes)
yy  = range(ymin, ymax, length=ynodes)

f0   = [ exp( -0.5 * xi^2 ) * exp( -0.5 * yi^2 )
         for xi in xx, yi in yy]
f1   = 2 * f0
f2   = similar(f0)

g0   = MyArr(f0, "bla")
g1   = 2 * g0
g2   = similar(g0)

@btime mysum!($f2, $f0, $f1);
@btime mysum!($g2, $g0, $g1);
@btime mysum!($f2, $g0, $g1);

which returns

julia> include("test_arrays.jl");
  1.906 μs (0 allocations: 0 bytes)
  1.139 μs (0 allocations: 0 bytes)
  1.112 μs (0 allocations: 0 bytes)

However, I run into trouble with DifferentialEquations:

function rhs_test(df, f, p, t)
    df .= f
    nothing
end

tspan = (0.0, 1.0)

prob  = ODEProblem(rhs_test, f0, tspan)
sol   = solve(prob, RK4())
@btime solve(prob, RK4())

# MyArr version
prob2  = ODEProblem(rhs_test, g0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4())

which returns an error:

julia> include("test_ode.jl")
  977.501 μs (186 allocations: 4.01 MiB)
ERROR: LoadError: BoundsError: attempt to access ()
  at index [1]
Stacktrace:
 [1] getindex(::Tuple, ::Int64) at ./tuple.jl:24
 [2] find_arr(::Tuple{}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65
 [3] find_arr(::Int64, ::Tuple{}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:68
 [4] find_arr(::Tuple{Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65
 [5] find_arr(::Float64, ::Tuple{Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:68
 [6] find_arr(::Tuple{Float64,Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65
 [7] find_arr(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(/),Tuple{Float64,Int64}}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:64

I guess I haven’t properly implemented the copyto! functions (which is the only piece I didn’t copy from the docs)? What am I missing?

Those don’t seem to take into account that broadcast expressions can have more than 1 value?
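
A hedged sketch of what this could look like, following the ArrayAndChar pattern from the Julia manual's interfaces chapter. `NamedArr` is a hypothetical, cut-down stand-in for MyArr. Two things differ from the code above: the search gets a terminating method for the empty tuple, so nested scalar-only sub-expressions such as `1.0 ./ 2` no longer hit `()[1]`, and no `copyto!` methods are defined, so Base's generic `copyto!` evaluates the whole fused expression instead of copying only the first array's data.

```julia
struct NamedArr{T,N} <: AbstractArray{T,N}   # hypothetical stand-in for MyArr
    data::Array{T,N}
    name::String
end
Base.size(a::NamedArr) = size(a.data)
Base.getindex(a::NamedArr{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]
Base.setindex!(a::NamedArr{T,N}, v, I::Vararg{Int,N}) where {T,N} = (a.data[I...] = v)

Base.BroadcastStyle(::Type{<:NamedArr}) = Broadcast.ArrayStyle{NamedArr}()

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArr}},
                      ::Type{T}) where {T}
    a = find_arr(bc)                      # first NamedArr anywhere in the expression
    NamedArr(similar(Array{T}, axes(bc)), a.name)
end

# Recursive search through (possibly nested) broadcast arguments.
find_arr(bc::Broadcast.Broadcasted) = find_arr(bc.args)
find_arr(args::Tuple) = find_arr(find_arr(args[1]), Base.tail(args))
find_arr(x) = x
find_arr(::Tuple{}) = nothing             # terminates the recursion
find_arr(a::NamedArr, rest) = a
find_arr(::Any, rest) = find_arr(rest)

a = NamedArr([1.0, 2.0], "bla")
b = (1.0 ./ 2) .+ a                       # nested scalar-only sub-expression comes first
c = NamedArr(zeros(2), "out")
c .= a .+ a                               # in-place, handled by Base's generic copyto!
```

Whether this matches the speed of a plain Array inside the solvers is not something this sketch establishes; it only shows a traversal that handles expressions with more than one argument.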

Oh, what would be the proper way of implementing those, then? I couldn’t find an explicit example for the copyto! functions in the docs, and I’m afraid I don’t completely understand how broadcast expressions work… Is there a similar example you could point me to?

For your case, is there a reason you cannot use DEDataArray? That will already have the broadcast overloads.

I guess I could, yes. I didn’t know about DEDataArray. For some reason, though, it seems that they perform slightly slower than standard arrays:

mutable struct MyDataArray{T,N} <: DEDataArray{T,N}
    x    :: Array{T,N}
    name :: String
end

xmin        =  -5.0
xmax        =   5.0
ymin        = -10.0
ymax        =  10.0
xnodes      =   64
ynodes      =  128

xx  = range(xmin, xmax, length=xnodes)
yy  = range(ymin, ymax, length=ynodes)

f0   = [ exp( -0.5 * xi^2 ) * exp( -0.5 * yi^2 )
         for xi in xx, yi in yy]
f1   = 2 * f0
f2   = similar(f0)

h0   = MyDataArray(f0, "bla")
h1   = MyDataArray(f1, "bla")
h2   = similar(h0)

function rhs_test(df, f, p, t)
    df .= f
    nothing
end

tspan = (0.0, 1.0)

prob  = ODEProblem(rhs_test, f0, tspan)
sol   = solve(prob, RK4())
@btime solve(prob, RK4())

prob2  = ODEProblem(rhs_test, h0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4())

returning:

julia> include("test_ode.jl")
  1.125 ms (182 allocations: 4.01 MiB)
  1.353 ms (362 allocations: 4.03 MiB)

Any idea why?

Also, for curiosity’s sake, is there anywhere I could have a look as to how to properly implement broadcast overloads for the copyto! functions?

Because they hold more information than a plain array, they deepcopy more than an array does when saving.