JuliaDiffEq with custom types

Nice! I didn’t know about the Val variant of ntuple. This is the fastest so far (along with unrolled tuple construction). This gets it to 1.2x of StaticArray and to 1.4x without those special-case _copy methods.

Interesting! I thought it would be just equivalent to @set!.

Yep, ntuple seems a little faster. For larger input vectors the difference increases. I think it is many due to the fact that ntuple unrolls the loop. If I manually unroll the Setfield loop, it’s the same speed.

1 Like

Sorry for resurrecting this old thread, but I noticed that my custom array type was performing rather badly (when compared with standard arrays) with DifferentialEquations. I tried to define the broadcasting operations following the docs, but there’s something I’m not doing right… Here’s what I have:

struct MyArr{A,N} <: AbstractArray{A,N}
    data  :: Array{A,N}
    name  :: String
end
MyArr(data::AbstractArray{T,N}, name::String) where {T,N} =
    MyArr{eltype(data),N}(data, name)

@inline Base.size(arr::MyArr) = size(arr.data)
@inline Base.length(arr::MyArr) = length(arr.data)

@inline Base.getindex(arr::MyArr, i::Int) = arr.data[i]
@inline Base.getindex(arr::MyArr, I::Vararg{Int,N}) where {N} = arr.data[I...]
@inline Base.getindex(arr::MyArr, ::Colon) = arr.data[:]
@inline Base.getindex(arr::MyArr, kr::AbstractRange) = arr.data[kr]

@inline Base.setindex!(arr::MyArr, v, i::Int) = (arr.data[i] = v)
@inline Base.setindex!(arr::MyArr, v, I::Vararg{Int,N}) where {N} = (arr.data[I...] = v)
@inline Base.setindex!(arr::MyArr, v, ::Colon) = (arr.data[:] .= v)
@inline Base.setindex!(arr::MyArr, v, kr::AbstractRange) = (arr.data[kr] .= v)

@inline Base.iterate(arr::MyArr) = iterate(arr.data)
@inline Base.iterate(arr::MyArr, state) = iterate(arr.data, state)

Base.show(io::IO, arr::MyArr) = show(io, arr.data)

Base.IndexStyle(arr::MyArr) = Base.IndexStyle(arr.data)

function Base.similar(arr::MyArr)
    data   = similar(arr.data)
    MyArr(data, arr.name)
end
function Base.similar(arr::MyArr, ::Type{T}) where {T}
    data   = similar(arr.data)
    MyArr(similar(data,T),arr.name)
end

Base.:*(x::Number, f::MyArr) = MyArr(x*f.data, f.name)
Base.:*(f::MyArr, x::Number) = MyArr(x*f.data, f.name)
Base.:/(f::MyArr, x::Number) = MyArr(f.data/x, f.name)
Base.:\(x::Number, f::MyArr) = MyArr(f.data/x, f.name)

# let's take the left name. they should match, but i'm not sure if checking
# for that is costly
Base.:+(f::MyArr, g::MyArr) = MyArr(f.data+g.data, f.name)
Base.:-(f::MyArr, g::MyArr) = MyArr(f.data-g.data, f.name)

Base.BroadcastStyle(::Type{<:MyArr}) = Broadcast.ArrayStyle{MyArr}()

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}},
    ::Type{T}) where T
    # Scan the inputs for MyArr:
    arr = find_arr(bc)
    # Use the name field of MyArr to create the output
    MyArr(similar(Array{T}, axes(bc)), arr.name)
end

find_arr(bc::Base.Broadcast.Broadcasted) = find_arr(bc.args)
find_arr(args::Tuple) = find_arr(find_arr(args[1]), Base.tail(args))
find_arr(x) = x
find_arr(a::MyArr, rest) = a
find_arr(::Any, rest) = find_arr(rest)

@inline function Base.copyto!(dest::Array,
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}})
    arr = find_arr(bc)
    copyto!(dest, arr.data)
end

@inline function Base.copyto!(dest::MyArr,
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArr}})
    arr = find_arr(bc)
    copyto!(dest.data, arr.data)
end

This seems to work fine when I run the following:

function mysum!(A, B, C)
    A .= B .+ C
end

xmin        =  -5.0
xmax        =   5.0
ymin        = -10.0
ymax        =  10.0
xnodes      =   64
ynodes      =  128

xx  = range(xmin, xmax, length=xnodes)
yy  = range(ymin, ymax, length=ynodes)

f0   = [ exp( -0.5 * xi^2 ) * exp( -0.5 * yi^2 )
         for xi in xx, yi in yy]
f1   = 2 * f0
f2   = similar(f0)

g0   = MyArr(f0, "bla")
g1   = 2 * g0
g2   = similar(g0)

@btime mysum!($f2, $f0, $f1);
@btime mysum!($g2, $g0, $g1);
@btime mysum!($f2, $g0, $g1);

which returns

julia> include("test_arrays.jl");
  1.906 μs (0 allocations: 0 bytes)
  1.139 μs (0 allocations: 0 bytes)
  1.112 μs (0 allocations: 0 bytes)

However, I run into trouble with DifferentialEquations:

function rhs_test(df, f, p, t)
    df .= f
    nothing
end

tspan = (0.0, 1.0)

prob  = ODEProblem(rhs_test, f0, tspan)
sol   = solve(prob, RK4())
@btime solve(prob, RK4())

# MyArr version
prob2  = ODEProblem(rhs_test, g0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4())

which returns an error:

julia> include("test_ode.jl")                                                                                                                               
  977.501 μs (186 allocations: 4.01 MiB)                                                                                                                    
ERROR: LoadError: BoundsError: attempt to access ()                                                                                                           at index [1]                                                                                                                                              Stacktrace:                                                                                                                                                 
 [1] getindex(::Tuple, ::Int64) at ./tuple.jl:24                                                                                                             [2] find_arr(::Tuple{}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65                                                                          
 [3] find_arr(::Int64, ::Tuple{}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:68                                                                  
 [4] find_arr(::Tuple{Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65                                                                      
 [5] find_arr(::Float64, ::Tuple{Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:68                                                           
 [6] find_arr(::Tuple{Float64,Int64}) at /home/mzilhao/dev/julia/test_arrays/test_arrays.jl:65                                                              
 [7] find_arr(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(/),Tuple{Float64,Int64}}) at /home/mzilhao/dev/julia/test_arra
ys/test_arrays.jl:64     

I guess I haven’t properly implemented the copyto! functions (which is the only piece I didn’t copy from the docs)? What am I missing?

Those don’t seem to take into account that broadcast expressions can have more than 1 value?

Oh, what would be the proper way of implementing those, then? I couldn’t find an explicit example for the copyto! functions in the docs, and I’m afraid I don’t completely understand how broadcast expressions work… Is there a similar example you could point me to?

For your case, is there a reason you cannot use DEDataArray? That will already have the broadcast overloads.

I guess I could, yes. I didn’t know about DEDataArray. For some reason, though, it seems that they perform slightly slower than standard arrays:

mutable struct MyDataArray{T,N} <: DEDataArray{T,N}
    x    :: Array{T,N}
    name :: String
end

xmin        =  -5.0
xmax        =   5.0
ymin        = -10.0
ymax        =  10.0
xnodes      =   64
ynodes      =  128

xx  = range(xmin, xmax, length=xnodes)
yy  = range(ymin, ymax, length=ynodes)

f0   = [ exp( -0.5 * xi^2 ) * exp( -0.5 * yi^2 )
         for xi in xx, yi in yy]
f1   = 2 * f0
f2   = similar(f0)

h0   = MyDataArray(f0, "bla")
h1   = MyDataArray(f1, "bla")
h2   = similar(h0)

function rhs_test(df, f, p, t)
    df .= f
    nothing
end

tspan = (0.0, 1.0)

prob  = ODEProblem(rhs_test, f0, tspan)
sol   = solve(prob, RK4())
@btime solve(prob, RK4())

prob2  = ODEProblem(rhs_test, h0, tspan)
sol2   = solve(prob2, RK4())
@btime solve(prob2, RK4())

returning:

julia> include("test_ode.jl")
  1.125 ms (182 allocations: 4.01 MiB)
  1.353 ms (362 allocations: 4.03 MiB)

Any idea why?

Also, for curiosity’s sake, is there anywhere I could have a look as to how to properly implement broadcast overloads for the copyto! functions?

Because they deepcopy more information than an array than saving, since they have more information than an array when saving.