Custom structs in JuliaDiffEq: In-place broadcasting results in allocation and slowdown

I am trying to write a program that solves a set of partial differential equations in one space dimension and one time dimension, so I would like to pass a spatial vector for each of my variables to the ODE solver. For this I defined a custom struct to hold all of the information. I would also like my program to be fast, so I studied the general methods for improving performance in DiffEq (Optimizing DiffEq Code), one of which is to reduce allocations by using the in-place form for the function that calculates your derivatives. With the in-place form, the solver (presumably) uses broadcasting behind the scenes to reduce allocation. I did this, but my performance got much worse and allocations increased. I tore my hair out for a long time until I realized that the problem stems from how broadcasting is carried out.
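To make the in-place/out-of-place distinction concrete, here is a minimal sketch with plain Vectors and no DiffEq dependency. The toy right-hand side and the names rhs/rhs! are illustrative assumptions; they only mirror the f(u, p, t) and f!(du, u, p, t) calling conventions DiffEq uses. The out-of-place version builds a new result array on every call, while the in-place version overwrites a preallocated du with a fused @. broadcast.

```julia
# Out-of-place form: returns a freshly allocated array on every call.
rhs(u, p, t) = @. -p * u

# In-place form: overwrites the preallocated `du`, so no new array is created.
function rhs!(du, u, p, t)
    @. du = -p * u
    return nothing
end

u = [1.0, 2.0, 3.0]
du = similar(u)          # allocated once, reused on every call
rhs!(du, u, 0.5, 0.0)

du_oop = rhs(u, 0.5, 0.0)  # a new array each time
```

Both forms compute the same values; the difference is only whether the result lands in a new array or in one the caller already owns.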

Here is a minimal working example that demonstrates this. I define my struct, implement the interfaces described in (Interfaces · The Julia Language) so that it can be used with DiffEq, and then benchmark in-place broadcasting of my custom struct:

module temp

using BenchmarkTools

struct vars{T} <: AbstractArray{T,2}
    vec1::Vector{T}
    vec2::Vector{T}
    vec3::Vector{T}
    vec4::Vector{T}
    vec5::Vector{T}
    vec6::Vector{T}
    vec7::Vector{T}
    vec8::Vector{T}
    vec9::Vector{T}
    vec10::Vector{T}
    vec11::Vector{T}
    vec12::Vector{T}
end

cont(x::vars) = (x.vec1,x.vec2,x.vec3,x.vec4,x.vec5,x.vec6,x.vec7,x.vec8,x.vec9,x.vec10,x.vec11,x.vec12)
numvar=12

# Iteration
Base.IteratorSize(::Type{<:vars}) = Iterators.HasShape{2}()
Base.eltype(::Type{vars{T}}) where T = T
Base.isempty(x::vars) = isempty(x.vec1)
function Base.iterate(x::vars, state...)
    return iterate(Iterators.flatten(cont(x)), state...)
end
Base.size(x::vars) = (length(x.vec1), numvar)
Base.size(x::vars, d) = size(x)[d]

# Indexing
function lin2cart(x::vars, i::Number)
    n = length(x.vec1)
    return (i - 1) % n + 1, (i - 1) ÷ n + 1
end

Base.getindex(x::vars, i) = getindex(x, i.I...)
Base.getindex(x::vars, i::Number) = getindex(x, lin2cart(x, i)...)
Base.getindex(x::vars, i, j) = getindex(cont(x)[j], i)
Base.setindex!(x::vars, v, i) = setindex!(x, v, i.I...)
Base.setindex!(x::vars, v, i::Number) = setindex!(x, v, lin2cart(x, i)...)  # splat (row, col) into the two-index method
Base.setindex!(x::vars, v, i, j) = setindex!(cont(x)[j], v, i)

# Abstract Array
Base.IndexStyle(::vars) = IndexCartesian()
Base.similar(x::vars) = vars(map(similar, cont(x))...)
function Base.similar(x::vars, ::Type{T}) where {T}
    return vars(map(y -> similar(y,T), cont(x))...)
end
Base.similar(x::vars, ::Dims) = similar(x)
Base.similar(x::vars, ::Dims, ::Type{T}) where {T} = similar(x, T)

#Broadcasting
Base.BroadcastStyle(::Type{<:vars}) = Broadcast.ArrayStyle{vars}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{vars}},
                      ::Type{T}) where {T}
    return similar(find_vars(bc), T)
end
find_vars(bc::Base.Broadcast.Broadcasted) = find_vars(bc.args)
find_vars(args::Tuple) = find_vars(find_vars(args[1]), Base.tail(args))
find_vars(x) = x
find_vars(::Tuple{}) = nothing
find_vars(a::vars, rest) = a
find_vars(::Any, rest) = find_vars(rest)

function Base.similar(::Type{vars{T}},size::Int) where T
    return vars{T}(
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size)),
        similar(Vector{T}(undef,size))
    )
end

function main()

    T = Float64

    n = 4000

    vars1 = similar(vars{T},n)
    vars2 = similar(vars{T},n)

    @btime $vars1.vec1 .= $vars2.vec2

    @btime $vars1 .= $vars2

    return

end

end

This results in

julia> temp.main()
  2.189 μs (0 allocations: 0 bytes)
  40.746 ms (288040 allocations: 16.11 MiB)

The broadcasting here should not allocate at all, and I would not expect it to take milliseconds. If I decrease the number of variables in my struct to 3 and set n=1000, I instead get

julia> temp.main()
  1.300 μs (0 allocations: 0 bytes)
  3.975 μs (0 allocations: 0 bytes)

which is the intended outcome, so somehow the number of variables or the number of bytes to move around plays a role. The end goal is to pass this struct to an ODE solver, so I need some way of defining broadcasting that gives the desired outcome. However, I do not understand the custom broadcasting outlined in the docs: all I have done here is essentially copy and paste the docs' example of defining the broadcasting interface for a custom struct.

Thanks in advance.


The indexing is slow and causes this. Does VectorOfArray do this better? ArrayPartition from RecursiveArrayTools.jl most likely does.
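To illustrate the idea behind a tuple-backed container like ArrayPartition without depending on RecursiveArrayTools, here is a hypothetical, heavily simplified type. Parts is an invented name and is not the RecursiveArrayTools API; it also assumes every partition has the same length. The number of partitions N is a type parameter, so size is known from the type alone, and the linear-to-Cartesian conversion uses the same arithmetic as lin2cart in the question, checked here against Base's CartesianIndices.

```julia
# Hypothetical, simplified stand-in for a partitioned array; NOT ArrayPartition itself.
struct Parts{T,N} <: AbstractMatrix{T}
    x::NTuple{N,Vector{T}}   # assumes all partitions have the same length
end

# The number of columns N comes from the type, not from a global variable.
Base.size(p::Parts{T,N}) where {T,N} = (length(p.x[1]), N)

# Two-index access goes straight to one partition.
Base.getindex(p::Parts, i::Int, j::Int) = p.x[j][i]
Base.setindex!(p::Parts, v, i::Int, j::Int) = (p.x[j][i] = v; p)

# Column-major linear index -> (row, column), same arithmetic as lin2cart.
lin2cart(n, k) = ((k - 1) % n + 1, (k - 1) ÷ n + 1)

p = Parts((collect(1.0:4.0), collect(5.0:8.0), collect(9.0:12.0)))
n = size(p, 1)

# Every linear index maps to the same (row, column) pair as CartesianIndices.
agrees = all(k -> Tuple(CartesianIndices(p)[k]) == lin2cart(n, k), 1:length(p))
```

The real ArrayPartition does considerably more (its own broadcast overloads, partitions of different lengths), so treat this only as an illustration of why carrying N in the type keeps size and indexing inferable.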


Thanks Chris, this works way better. The DiffEq documentation should really encourage this: it spends a lot of time explaining how to handle custom structures, whereas using package-defined structures like ArrayPartition has vastly simplified my code.

Another thing that helped me: even though the ArrayPartition documentation claims that if A and B are both ArrayPartitions with the same properties,

A .= B

should be efficient, it is still very slow in my code, and I got a huge speedup by simply looping over the partitions:

for i in 1:length(A.x)
    A.x[i] .= B.x[i]
end

The slowdown is more than a factor of 100, and as far as I can tell my code is properly type stable. I have no idea whether this issue is caused by DifferentialEquations or RecursiveArrayTools, but I am happy with the speed of my code now.

Can you share an MWE? Indexing should be slower :sweat_smile:. copyto!(A,B) should be fast unless it’s missing an overload.


Alright, I cut down my program to the following

module temp

using DifferentialEquations
using BoundaryValueDiffEq
using OrdinaryDiffEq
using RecursiveArrayTools

using BenchmarkTools
using InteractiveUtils

struct Param{S}
    state::S
    dtstate2::S
end

@inline function Base.zeros(::Type{ArrayPartition},::Type{T},n::Int) where T
    return ArrayPartition(
        # use the requested element type T for every partition
        zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),
        zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n)
    )
end

function rhs!(dtstate::S, regstate::S, param::Param{S}, t) where S

    #Unpack the parameters

    dtstate2 = param.dtstate2
    state = param.state

    ## Store the current state in the container
    ## so it can be modified during the calculations
    # for i in 1:12
    #     state.x[i] .= regstate.x[i]
    # end
    state .= regstate

    # Unpack the states so I can refer to individual things
    a,b,c,d,e,f,g,h,i,j,k,l = state.x

    A,B,C,D,E,F,G,H,I,J,K,L = dtstate.x

    ## Do all your calculations here....
    ## here is an example

    @. A = a + b + c + d + e + f + g + h + i + j + k + l
    @. B = a + b + c + d + e + f + g + h + i + j + k + l
    @. C = a + b + c + d + e + f + g + h + i + j + k + l
    @. D = a + b + c + d + e + f + g + h + i + j + k + l
    @. E = a + b + c + d + e + f + g + h + i + j + k + l
    @. F = a + b + c + d + e + f + g + h + i + j + k + l
    @. G = a + b + c + d + e + f + g + h + i + j + k + l
    @. H = a + b + c + d + e + f + g + h + i + j + k + l
    @. I = a + b + c + d + e + f + g + h + i + j + k + l
    @. J = a + b + c + d + e + f + g + h + i + j + k + l
    @. K = a + b + c + d + e + f + g + h + i + j + k + l
    @. L = a + b + c + d + e + f + g + h + i + j + k + l


    ## Store the current time derivatives in the container
    # for use elsewhere
    # for i in 1:12
    #     dtstate2.x[i] .= dtstate.x[i]
    # end
    dtstate2 .= dtstate
    return nothing
end

function main()

    n = 4000

    numvar = 12

    T = Float64

    tspan = T[0.,1.]

    cont = ArrayPartition{T, NTuple{numvar, Vector{T}}}

    regstate = zeros(ArrayPartition,T,n)::cont
    state = zeros(ArrayPartition,T,n)::cont
    dtstate = zeros(ArrayPartition,T,n)::cont

    param = Param(state,dtstate)

    prob = ODEProblem(rhs!, regstate, tspan, param)

    atol = eps(T)^(T(3) / 4)

    alg = RK4()

    # @btime sol = solve(
    #     $prob, $alg,
    #     abstol = $atol,
    #     dt = 0.1,
    #     adaptive = false,
    #     saveat = 0.1,
    #     alias_u0 = true
    #     # progress = true,
    #     # progress_steps = custom_progress_step,
    #     # progress_message = custom_progress_message
    # )

    @code_warntype rhs!(dtstate,regstate,param,0.)

    return

end

end

The slowdown here between using

for i in 1:12
    state.x[i] .= regstate.x[i]
end

and

state .= regstate

is not as dramatic here, but I have found that the problem is some sort of typing issue. If I use state .= regstate and run @code_warntype on rhs!, I get the following:

Body::Any
1 ─        (dtstate2 = Base.getproperty(param, :dtstate2))
│          (state = Base.getproperty(param, :state))
│   %3   = state::RecursiveArrayTools.ArrayPartition{Float64, NTuple{12, Vector{Float64}}}
│   %4   = Base.broadcasted(Base.identity, regstate)::Base.Broadcast.Broadcasted{_A, Nothing, typeof(identity), Tuple{RecursiveArrayTools.ArrayPartition{Float64, NTuple{12, Vector{Float64}}}}} where _A<:Union{Nothing, Base.Broadcast.BroadcastStyle}

with the line at %4 highlighted in red. It seems to recognize that state and regstate have the same type, though, so I don’t understand this. Switching to the looped version gives clean, all-blue @code_warntype output.

I’m fairly new to Julia, so this could well be some trivial problem.
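For readers new to @code_warntype: it highlights inferred types that are not concrete. A quick programmatic way to ask the same question is Base.return_types, an unexported reflection utility. The sketch below uses a hypothetical helper getpart to show that indexing a tuple with an index known only at run time infers a concrete type when the tuple is homogeneous, but not when it is heterogeneous, which is the concern raised in the reply that follows.

```julia
# Index a tuple with an index that is only known at run time.
getpart(t::Tuple, i::Int) = t[i]

homog  = Tuple{Vector{Float64}, Vector{Float64}}
hetero = Tuple{Vector{Float64}, Vector{Int}}

# Inferred return type for each argument signature (one method, so one result).
rt_homog  = only(Base.return_types(getpart, (homog, Int)))
rt_hetero = only(Base.return_types(getpart, (hetero, Int)))

println(rt_homog)    # a concrete type
println(rt_hetero)   # not concrete (a Union of the element types, or Any)
```

The same check can be run on any function with the argument types you care about, which is handy when the colored @code_warntype output is hard to read.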


Interesting; please open an issue in RecursiveArrayTools.jl. The broadcast overload should be equivalent to splitting and doing:

for i in 1:12
    copyto!(state.x[i],regstate.x[i])
end

which should be fully inferred in your case. Ideally, it should unroll the loop so that it avoids potential type instabilities from indexing into the tuple, i.e.:

copyto!(state.x[1],regstate.x[1])
copyto!(state.x[2],regstate.x[2])
...

(since tuples can be heterogeneously typed). If it’s not doing that, :sweat_smile: that’s the issue.
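A minimal sketch of that unrolled, component-wise copy in plain Base Julia follows. It relies on map over tuples, which the compiler unrolls for tuples of modest length, so each copyto! call sees a concrete element type even when the partitions have different types. The helper name copy_parts! is hypothetical and is not the RecursiveArrayTools implementation.

```julia
# Copy each partition of `src` into the matching partition of `dest`.
# `map` over tuples is unrolled for modest tuple lengths, so every copyto!
# call is compiled for a concrete element type, even for heterogeneous tuples.
function copy_parts!(dest::Tuple, src::Tuple)
    length(dest) == length(src) || throw(DimensionMismatch("partition counts differ"))
    map(copyto!, dest, src)
    return dest
end

dest = (zeros(3), zeros(Int, 2))      # heterogeneous: Vector{Float64}, Vector{Int}
src  = ([1.0, 2.0, 3.0], [4, 5])
copy_parts!(dest, src)
```

For an ArrayPartition this would be applied to the .x fields, which is what the explicit loop over state.x[i] earlier in the thread does by hand.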


Do you think this issue is related? I ask since you replied to it.

That issue isn’t related. The solvers were changed, so the reply there no longer applies anyway.