# Custom structs in JuliaDiffEq: In-place broadcasting results in allocation and slowdown

I am trying to write a program that solves a set of partial differential equations. The system has one space dimension and one time dimension, so I would like to pass a spatial vector for each of my variables to the ODE solver. For this I defined a custom struct to hold all of the information. I would also like my program to be fast, so I studied up on the general methods to improve performance in DiffEq (Optimizing DiffEq Code). One of them is to reduce allocation by using an in-place form for the function that calculates your derivatives. With the in-place form, the solver (presumably) uses broadcasting behind the scenes to reduce allocation. I did this, but my performance got much worse and allocations increased. I tore my hair out for a long time until I realized that the problem stems from how broadcasting is implemented.
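For reference, here is a minimal, Base-only sketch of the two function forms DiffEq accepts, using a hypothetical toy right-hand side (`decay`/`decay!`, the wrapper functions, and the parameter value are illustrative, not from the original program). The out-of-place form returns a fresh derivative vector on every call, while the in-place form writes into a preallocated `du`. `@allocated` is used inside small wrapper functions so the arguments have concrete types.

```julia
# Out-of-place form f(u, p, t): allocates and returns a new derivative vector.
decay(u, p, t) = -p .* u

# In-place form f!(du, u, p, t): overwrites the preallocated `du`.
function decay!(du, u, p, t)
    @. du = -p * u
    return nothing
end

# Wrapping @allocated in functions gives the arguments concrete types.
# The out-of-place result is returned so its allocation cannot be optimized away.
function allocs_oop(u, p)
    local du_new
    bytes = @allocated (du_new = decay(u, p, 0.0))
    return bytes, du_new
end
allocs_ip(du, u, p) = @allocated decay!(du, u, p, 0.0)

u = rand(4000)
du = similar(u)
p = 0.5

allocs_oop(u, p); allocs_ip(du, u, p)   # warm-up calls trigger compilation

a_oop, du_new = allocs_oop(u, p)  # at least one 4000-element Vector{Float64}
a_ip = allocs_ip(du, u, p)        # typically zero once compiled
```

The in-place form is what the solver can call repeatedly without creating garbage, which is why the performance advice recommends it.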

Here is a minimal working example that demonstrates this. I define my struct and then implement the interfaces described in the Julia manual (Interfaces · The Julia Language) that DiffEq needs. I then benchmark in-place broadcasting of my custom struct:

``````julia
module temp

using BenchmarkTools

struct vars{T} <: AbstractArray{T,2}
vec1::Vector{T}
vec2::Vector{T}
vec3::Vector{T}
vec4::Vector{T}
vec5::Vector{T}
vec6::Vector{T}
vec7::Vector{T}
vec8::Vector{T}
vec9::Vector{T}
vec10::Vector{T}
vec11::Vector{T}
vec12::Vector{T}
end

cont(x::vars) = (x.vec1,x.vec2,x.vec3,x.vec4,x.vec5,x.vec6,x.vec7,x.vec8,x.vec9,x.vec10,x.vec11,x.vec12)
numvar=12

# Iteration
Base.IteratorSize(::Type{<:vars}) = Iterators.HasShape{2}()
Base.eltype(::Type{vars{T}}) where T = T
Base.isempty(x::vars) = isempty(x.vec1)
function Base.iterate(x::vars, state...)
return iterate(Iterators.flatten(cont(x)), state...)
end
Base.size(x::vars) = (length(x.vec1), numvar)
Base.size(x::vars, d) = size(x)[d]

# Indexing
function lin2cart(x::vars, i::Number)
n = length(x.vec1)
return (i - 1) % n + 1, (i - 1) ÷ n + 1
end

Base.getindex(x::vars, i) = getindex(x, i.I...)
Base.getindex(x::vars, i::Number) = getindex(x, lin2cart(x, i)...)
Base.getindex(x::vars, i, j) = getindex(cont(x)[j], i)
Base.setindex!(x::vars, v, i) = setindex!(x, v, i.I...)
Base.setindex!(x::vars, v, i::Number) = setindex!(x, v, lin2cart(x, i)...)
Base.setindex!(x::vars, v, i, j) = setindex!(cont(x)[j], v, i)

# Abstract Array
Base.IndexStyle(::vars) = IndexCartesian()
Base.similar(x::vars) = vars(map(similar, cont(x))...)
function Base.similar(x::vars, ::Type{T}) where {T}
return vars(map(y -> similar(y,T), cont(x))...)
end
Base.similar(x::vars, ::Dims) = similar(x)
Base.similar(x::vars, ::Dims, ::Type{T}) where {T} = similar(x, T)

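# Broadcasting
Base.BroadcastStyle(::Type{<:vars}) = Broadcast.ArrayStyle{vars}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{vars}}, ::Type{T}) where {T}
return similar(find_vars(bc), T)
end
find_vars(bc::Broadcast.Broadcasted) = find_vars(bc.args)
find_vars(args::Tuple) = find_vars(find_vars(args[1]), Base.tail(args))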
find_vars(x) = x
find_vars(::Tuple{}) = nothing
find_vars(a::vars, rest) = a
find_vars(::Any, rest) = find_vars(rest)

function Base.similar(::Type{vars{T}}, size::Int) where T
# One uninitialized Vector{T} of length `size` for each of the 12 fields
return vars{T}(ntuple(_ -> Vector{T}(undef, size), Val(12))...)
end

function main()

T = Float64

n = 4000

vars1 = similar(vars{T},n)
vars2 = similar(vars{T},n)

@btime $vars1.vec1 .= $vars2.vec2

@btime $vars1 .= $vars2

return

end

end

``````

This results in

``````
julia> temp.main()
2.189 μs (0 allocations: 0 bytes)
40.746 ms (288040 allocations: 16.11 MiB)
``````

The broadcasting here should not allocate at all, and (I would think) certainly shouldn’t take milliseconds. If I reduce the number of variables in my struct to 3 and set `n = 1000`, I instead get

``````
julia> temp.main()
1.300 μs (0 allocations: 0 bytes)
3.975 μs (0 allocations: 0 bytes)
``````

which is the intended outcome, so somehow the number of variables, or the number of bytes to move around, plays a role. The end goal is to pass this struct to an ODE solver, so I need some way of defining broadcasting that gives the desired outcome. However, I do not understand the custom broadcasting outlined in the docs: all I have done here is essentially copy and paste the docs' example for defining the broadcasting interface of a custom struct.
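One possible workaround, sketched under assumptions: in the Julia versions I am aware of, a plain `dest .= src` with a single array on the right and matching axes is forwarded by Base's broadcast code to `copyto!(dest, src)`. Overloading `copyto!` to copy field by field then sidesteps the per-element `getindex`/`setindex!` path. `TwoCols` is a hypothetical two-field stand-in for `vars`, and relying on that forwarding is an assumption about Base internals, not a documented guarantee. It also does nothing for fused expressions such as `x .= y .+ 1`.

```julia
# Hypothetical two-column container standing in for the 12-field `vars`.
struct TwoCols{T} <: AbstractMatrix{T}
    col1::Vector{T}
    col2::Vector{T}
end

cols(x::TwoCols) = (x.col1, x.col2)

Base.size(x::TwoCols) = (length(x.col1), 2)
Base.getindex(x::TwoCols, i::Int, j::Int) = cols(x)[j][i]
function Base.setindex!(x::TwoCols, v, i::Int, j::Int)
    cols(x)[j][i] = v
    return x
end

# Field-wise copy: two contiguous Vector copies instead of 2n scalar
# getindex/setindex! calls through the generic AbstractArray path.
function Base.copyto!(dest::TwoCols, src::TwoCols)
    copyto!(dest.col1, src.col1)
    copyto!(dest.col2, src.col2)
    return dest
end

x = TwoCols(zeros(5), zeros(5))
y = TwoCols(collect(1.0:5.0), collect(6.0:10.0))

# With matching axes, Base may route this through copyto!(x, y)
# (an assumption about Base internals); the result is the same either way.
x .= y
```

If the forwarding does not happen in a given Julia version, the broadcast still produces the correct values, just through the slower element-wise path.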


The indexing is slow and causes this. Does `VectorOfArray` do this better? `ArrayPartition` from RecursiveArrayTools.jl most likely does.


Thanks Chris, this works much better. The DiffEq documentation should really encourage this: it spends a lot of time explaining how to handle custom structures, whereas using a package-defined structure like `ArrayPartition` has vastly simplified my code.

Another thing that helped me: even though the `ArrayPartition` documentation claims that if `A` and `B` are both `ArrayPartition`s with the same properties,

``````julia
A .= B
``````

should be efficient, this is still very slow, and I got a huge speedup by just looping over the partitions:

``````julia
for i in 1:length(A.x)
    A.x[i] .= B.x[i]
end
``````

The slowdown is more than a factor of 100, and as far as I can tell my code is properly type-stable. I have no idea whether this issue is caused by `DifferentialEquations` or `RecursiveArrayTools`, but I am happy with the speed of my code now.

Share an MWE? Indexing should be slower. `copyto!(A,B)` should be fast unless it’s missing an overload.
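To illustrate what a missing overload costs: in a hypothetical tuple-backed vector `Parts` (loosely inspired by the idea behind `ArrayPartition`, not RecursiveArrayTools' actual code), every scalar index has to search for the partition that owns it, so a generic element-by-element fallback pays that search once per element. A dedicated `copyto!` method instead does one contiguous `Vector` copy per partition. A minimal sketch, assuming all partitions share one element type:

```julia
# Hypothetical tuple-backed vector; every partition shares element type T.
struct Parts{T,N} <: AbstractVector{T}
    x::NTuple{N,Vector{T}}
end

Base.size(p::Parts) = (sum(length, p.x),)
Base.IndexStyle(::Type{<:Parts}) = IndexLinear()

# Scalar indexing walks the partitions to find the one that owns index i.
function Base.getindex(p::Parts, i::Int)
    k = i
    for part in p.x
        k <= length(part) && return part[k]
        k -= length(part)
    end
    throw(BoundsError(p, i))
end

function Base.setindex!(p::Parts, v, i::Int)
    k = i
    for part in p.x
        if k <= length(part)
            part[k] = v
            return p
        end
        k -= length(part)
    end
    throw(BoundsError(p, i))
end

# Partition-wise copy: one contiguous Vector copy per partition.
function Base.copyto!(dest::Parts{T,N}, src::Parts{T,N}) where {T,N}
    map(length, dest.x) == map(length, src.x) ||
        throw(DimensionMismatch("partition lengths differ"))
    for i in 1:N
        copyto!(dest.x[i], src.x[i])
    end
    return dest
end

a = Parts((zeros(3), zeros(2)))
b = Parts(([1.0, 2.0, 3.0], [4.0, 5.0]))
copyto!(a, b)
```

Without the `copyto!` method, the same call would fall back to Base's generic `AbstractArray` copy, which goes through the searching `getindex`/`setindex!` for every element.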


Alright, I cut my program down to the following:

``````julia
module temp

using DifferentialEquations
using BoundaryValueDiffEq
using OrdinaryDiffEq
using RecursiveArrayTools

using BenchmarkTools
using InteractiveUtils

struct Param{S}
state::S
dtstate2::S
end

@inline function Base.zeros(::Type{ArrayPartition},::Type{T},n::Int) where T
return ArrayPartition(
zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),
zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n),zeros(T,n)
)
end

function rhs!(dtstate::S, regstate::S, param::Param{S}, t) where S

#Unpack the parameters

dtstate2 = param.dtstate2
state = param.state

## Store the current state in the container
## so it can be modified during the calculations
# for i in 1:12
#     state.x[i] .= regstate.x[i]
# end
state .= regstate

# Unpack the states so I can refer to individual things
a,b,c,d,e,f,g,h,i,j,k,l = state.x

A,B,C,D,E,F,G,H,I,J,K,L = dtstate.x

## Do all your calculations here....
## here is an example

@. A = a + b + c + d + e + f + g + h + i + j + k + l
@. B = a + b + c + d + e + f + g + h + i + j + k + l
@. C = a + b + c + d + e + f + g + h + i + j + k + l
@. D = a + b + c + d + e + f + g + h + i + j + k + l
@. E = a + b + c + d + e + f + g + h + i + j + k + l
@. F = a + b + c + d + e + f + g + h + i + j + k + l
@. G = a + b + c + d + e + f + g + h + i + j + k + l
@. H = a + b + c + d + e + f + g + h + i + j + k + l
@. I = a + b + c + d + e + f + g + h + i + j + k + l
@. J = a + b + c + d + e + f + g + h + i + j + k + l
@. K = a + b + c + d + e + f + g + h + i + j + k + l
@. L = a + b + c + d + e + f + g + h + i + j + k + l

## Store the current time derivatives in the container
# for use elsewhere
# for i in 1:12
#     dtstate2.x[i] .= dtstate.x[i]
# end
dtstate2 .= dtstate

return nothing

end

function main()

n = 4000

numvar = 12

T = Float64

tspan = (zero(T), one(T))

cont = ArrayPartition{T, NTuple{numvar, Vector{T}}}

regstate = zeros(ArrayPartition,T,n)::cont
state = zeros(ArrayPartition,T,n)::cont
dtstate = zeros(ArrayPartition,T,n)::cont

param = Param(state,dtstate)

prob = ODEProblem(rhs!, regstate, tspan, param)

atol = eps(T)^(T(3) / 4)

alg = RK4()

# @btime sol = solve(
#     $prob, $alg,
#     abstol = $atol,
#     dt = 0.1,
#     saveat = 0.1,
#     alias_u0 = true
#     # progress = true,
#     # progress_steps = custom_progress_step,
#     # progress_message = custom_progress_message
# )

@code_warntype rhs!(dtstate,regstate,param,0.)

return

end

end
``````

The difference in speed between using

``````julia
for i in 1:12
    state.x[i] .= regstate.x[i]
end
``````

and

``````julia
state .= regstate
``````

is not as dramatic in this example, but I have found that the problem has to do with some sort of typing issue. If I use `state .= regstate` and run `@code_warntype` on `rhs!`, I get the following:

``````
Body::Any
1 ─        (dtstate2 = Base.getproperty(param, :dtstate2))
│          (state = Base.getproperty(param, :state))
│   %3   = state::RecursiveArrayTools.ArrayPartition{Float64, NTuple{12, Vector{Float64}}}
``````

with the line at `%4` highlighted in red. It seems to accept that `state` and `regstate` are the same type, though, so I don’t understand this. Switching to the looped version leads to blue skies in the `@code_warntype` output.

I’m fairly new to Julia, so this could well be some trivial problem.
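One way to narrow down red entries like this, sketched with hypothetical function names: `@inferred` (from the `Test` standard library) throws if a call's return type is not inferred as the concrete type it actually returns, and `Base.return_types` shows what inference expects. The sketch contrasts indexing a heterogeneous tuple with a literal index versus a runtime index. For a homogeneous tuple such as the `NTuple{12, Vector{Float64}}` inside `rhs!`, even a runtime index infers a concrete element type.

```julia
using Test  # standard library; provides @inferred

# Heterogeneous tuple: the two elements have different concrete types.
t = (zeros(3), zeros(Int, 3))

first_part(t) = t[1]   # literal index: the element type is known at compile time
ith_part(t, i) = t[i]  # runtime index: result could be either Vector type

v = @inferred first_part(t)   # passes: inferred as Vector{Float64}

# What inference expects for the runtime-index version (not a concrete type).
rt = only(Base.return_types(ith_part, (typeof(t), Int)))

# Homogeneous case: a runtime index still infers a concrete element type.
t12 = ntuple(_ -> zeros(2), Val(12))
rt12 = only(Base.return_types(ith_part, (typeof(t12), Int)))
```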


Interesting, open an issue in RecursiveArrayTools.jl. The broadcast overload should be equivalent to splitting and doing:

``````julia
for i in 1:12
    copyto!(state.x[i],regstate.x[i])
end
``````

which should be fully inferred in your case. Ideally, it should unroll this so that it avoids potential instabilities from indexing into the tuple, i.e.:

``````julia
copyto!(state.x[1],regstate.x[1])
copyto!(state.x[2],regstate.x[2])
...
``````

(since tuples can be heterogeneously typed). If it’s not doing that, that’s the issue.
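The unrolling described here can be written generically by recursing over the tuples with `first` and `Base.tail`, so each step handles one concretely typed element even when the tuple is heterogeneous. A minimal sketch; `unrolled_copyto!` is a hypothetical helper name, not RecursiveArrayTools' implementation:

```julia
# Base case: nothing left to copy.
unrolled_copyto!(::Tuple{}, ::Tuple{}) = nothing

# Copy the first element, then recurse on the remaining elements. Each call
# sees the concrete type of its first element, even in heterogeneous tuples.
function unrolled_copyto!(dest::Tuple, src::Tuple)
    length(dest) == length(src) ||
        throw(DimensionMismatch("tuples have different lengths"))
    copyto!(first(dest), first(src))
    return unrolled_copyto!(Base.tail(dest), Base.tail(src))
end

dest = (zeros(3), zeros(Int, 2))   # heterogeneous element types
src = ([1.0, 2.0, 3.0], [4, 5])
unrolled_copyto!(dest, src)
```

Calling `unrolled_copyto!(state.x, regstate.x)` would be a tuple-level equivalent of the per-partition `copyto!` loop shown earlier in the thread.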


Do you think this issue is related? I ask since you replied to it.

That issue isn’t related. The solvers were changed so the reply there no longer applies anyways.