JuliaDiffEq with custom types

Sorry for resurrecting this old thread, but I noticed that my custom array type was performing rather badly (when compared with standard arrays) with DifferentialEquations. I tried to define the broadcasting operations following the docs, but there’s something I’m not doing right… Here’s what I have:

struct MyArr{A,N} <: AbstractArray{A,N}
    data  :: Array{A,N}
    name  :: String
end
MyArr(data::AbstractArray{T,N}, name::String) where {T,N} =
    MyArr{eltype(data),N}(data, name)

@inline Base.size(arr::MyArr) = size(arr.data)
@inline Base.length(arr::MyArr) = length(arr.data)

@inline Base.getindex(arr::MyArr, i::Int) = arr.data[i]
@inline Base.getindex(arr::MyArr, I::Vararg{Int,N}) where {N} = arr.data[I...]
@inline Base.getindex(arr::MyArr, ::Colon) = arr.data[:]
@inline Base.getindex(arr::MyArr, kr::AbstractRange) = arr.data[kr]

@inline Base.setindex!(arr::MyArr, v, i::Int) = (arr.data[i] = v)
@inline Base.setindex!(arr::MyArr, v, I::Vararg{Int,N}) where {N} = (arr.data[I...] = v)
@inline Base.setindex!(arr::MyArr, v, ::Colon) = (arr.data[:] .= v)
@inline Base.setindex!(arr::MyArr, v, kr::AbstractRange) = (arr.data[kr] .= v)

@inline Base.iterate(arr::MyArr) = iterate(arr.data)
@inline Base.iterate(arr::MyArr, state) = iterate(arr.data, state)

Base.show(io::IO, arr::MyArr) = show(io, arr.data)

Base.IndexStyle(arr::MyArr) = Base.IndexStyle(arr.data)

function Base.similar(arr::MyArr)
    data   = similar(arr.data)
    MyArr(data, arr.name)
end
function Base.similar(arr::MyArr, ::Type{T}) where {T}
    data   = similar(arr.data)
    MyArr(similar(data,T),arr.name)
end

Base.:*(x::Number, f::MyArr) = MyArr(x*f.data, f.name)
Base.:*(f::MyArr, x::Number) = MyArr(x*f.data, f.name)
Base.:/(f::MyArr, x::Number) = MyArr(f.data/x, f.name)
Base.:\(x::Number, f::MyArr) = MyArr(f.data/x, f.name)

# let's take the left name. they should match, but i'm not sure if checking
# for that is costly
Base.:+(f::MyArr, g::MyArr) = MyArr(f.data+g.data, f.name)
Base.:-(f::MyArr, g::MyArr) = MyArr(f.data-g.data, f.name)

Base.BroadcastStyle(::Type{<:MyArr}) = Broadcast.ArrayStyle{MyArr}()

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}},
    ::Type{T}) where T
    # Scan the inputs for MyArr:
    arr = find_arr(bc)
    # Use the name field of MyArr to create the output
    MyArr(similar(Array{T}, axes(bc)), arr.name)
end

find_arr(bc::Base.Broadcast.Broadcasted) = find_arr(bc.args)
find_arr(args::Tuple) = find_arr(find_arr(args[1]), Base.tail(args))
find_arr(x) = x
find_arr(a::MyArr, rest) = a
find_arr(::Any, rest) = find_arr(rest)

@inline function Base.copyto!(dest::Array,
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}})
    arr = find_arr(bc)
    copyto!(dest, arr.data)
end

@inline function Base.copyto!(dest::MyArr,
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}})
    arr = find_arr(bc)
    copyto!(dest.data, arr.data)
end

This seems to work fine when I run the following:

function mysum!(A, B, C)
    A .= B .+ C
end

xmin        =  -5.0
xmax        =   5.0
ymin        = -10.0
ymax        =  10.0
xnodes      =   64
ynodes      =  128

xx  = range(xmin, xmax, length=xnodes)
yy  = range(ymin, ymax, length=ynodes)

f0   = [ exp( -0.5 * xi^2 ) * exp( -0.5 * yi^2 )
         for xi in xx, yi in yy]
f1   = 2 * f0
f2   = similar(f0)

g0   = MyArr(f0, "bla")
g1   = 2 * g0
g2   = similar(g0)

@btime mysum!($f2, $f0, $f1);
@btime mysum!($g2, $g0, $g1);
@btime mysum!($f2, $g0, $g1);

which returns

julia> include("test_arrays.jl");
  1.906 μs (0 allocations: 0 bytes)
  1.139 μs (0 allocations: 0 bytes)
  1.112 μs (0 allocations: 0 bytes)

However, I run into trouble with DifferentialEquations:

function rhs_test(df, f, p, t)
    df .= f
    nothing
end

tspan = (0.0, 1.0)

prob  = ODEProblem(rhs_test, f0, tspan)
sol   = solve(prob, RK4())
@btime solve(prob, RK4())

# MyArr version
prob2  = ODEProblem(rhs_test, g0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4())

which returns an error:

julia> include("test_ode.jl")                                                                                                                               
  977.501 μs (186 allocations: 4.01 MiB)                                                                                                                    
ERROR: LoadError: BoundsError: attempt to access ()                                                                                                           at index [1]                                                                                                                                              Stacktrace:                                                                                                                                                 
 [1] getindex(::Tuple, ::Int64) at ./tuple.jl:24                                                                                                             [2] find_arr(::Tuple{}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65                                                                          
 [3] find_arr(::Int64, ::Tuple{}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:68                                                                  
 [4] find_arr(::Tuple{Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65                                                                      
 [5] find_arr(::Float64, ::Tuple{Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:68                                                           
 [6] find_arr(::Tuple{Float64,Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65                                                              
 [7] find_arr(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(/),Tuple{Float64,Int64}}) at /home/mzilhao/dev/julia/test_arra
ys/test_arrays.jl:64     

I guess I haven’t properly implemented the copyto! functions (which is the only piece I didn’t copy from the docs)? What am I missing?