I have a function that operates on tuples of arrays using broadcasting, but I’m rewriting it because it doesn’t play well with ForwardDiff.jl. However, I’ve noticed that the version that loops through array entries is slower and allocates more.
I’ve included what I hope is a minimal enough working example. I take tuples of arrays `x` and `y` and sum them in `bcast` and `loop`.
```julia
x = ntuple(x->randn(2,4),2)
y = ntuple(x->randn(2,4),2)

function bcast(x,y)
    fsum(x,y) = x + y
    out = fsum.(x,y)    # broadcast over the tuples: one matrix sum per entry
    return out
end

function loop(x,y)
    out = ntuple(a->zeros(size(x[1])),length(x))
    for i = 1:length(x[1])
        xi = (x->x[i]).(x)    # tuple holding the i-th entry of each array in x
        yi = (x->x[i]).(y)
        fsum!(x,y,out) = out[i] = x + y
        fsum!.(xi,yi,out)     # write entry i of every output array
    end
    return out
end
```
The extra allocations in the looped function come from `fsum!.(xi,yi,out)`, but I’m having trouble figuring out why. I tried `@code_warntype`, but I haven’t been able to interpret its output.
I think the problem is that you’re looping in an unusual order. You have an outer container (the tuple) and inner containers (the arrays), but your loop runs over the array index on the outside and the tuple index on the inside, rebuilding small tuples of scalars on every iteration. Julia is sometimes good about fixing these sorts of data-access problems for you, but not always, and it’s hard (for me) to predict when it will and won’t. The best practice is simply not to rely on Julia fixing things at an algorithmic level.
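To make the two traversal orders concrete, here is a small sketch that just records which (tuple index, array index) pairs a loop visits, and in what order. `visit_order` is a hypothetical helper, not from the thread, and it assumes linearly indexed arrays such as `Matrix`.

```julia
# Hypothetical helper: list the (tuple index, array index) pairs in the order
# a loop visits them, for a tuple `x` of linearly indexed arrays.
function visit_order(x::Tuple; tuple_outer::Bool)
    visits = Tuple{Int,Int}[]
    if tuple_outer
        # like loop2: finish one array before moving on to the next
        for k in eachindex(x), j in eachindex(x[k])
            push!(visits, (k, j))
        end
    else
        # like the original loop: array index outer, tuple index inner
        for j in eachindex(x[1]), k in eachindex(x)
            push!(visits, (k, j))
        end
    end
    return visits
end

x = ntuple(_ -> randn(2, 4), 2)
println(first(visit_order(x; tuple_outer = true), 4))   # [(1, 1), (1, 2), (1, 3), (1, 4)]
println(first(visit_order(x; tuple_outer = false), 4))  # [(1, 1), (2, 1), (1, 2), (2, 2)]
```

Both orders touch the same 16 entries; the second one alternates between the two arrays on every step.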
Here’s some looping code that respects the data layout and I think is much easier to understand than the original:
```julia
julia> using BenchmarkTools

julia> function loop2(x,y)
           out = ntuple(i->similar(x[i]), length(x))
           for i ∈ eachindex(x, y), j ∈ eachindex(x[i], y[i])
               out[i][j] = x[i][j] + y[i][j]
           end
           return out
       end
loop2 (generic function with 1 method)

julia> @btime bcast($x,$y);
  134.275 ns (3 allocations: 320 bytes)

julia> @btime loop($x,$y);
  2.327 μs (52 allocations: 1.83 KiB)

julia> @btime loop2($x,$y);
  113.620 ns (4 allocations: 336 bytes)
```
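As you can see, it slightly outperforms the broadcast code (about 114 ns vs. 134 ns), at the cost of one extra allocation (336 vs. 320 bytes).
We can shave off roughly another 15 ns by checking the sizes once up front and then skipping the per-element bounds checks with `@inbounds`: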
```julia
julia> function loop3(x,y)
           # check sizes once, up front, so the loop can skip bounds checks
           @assert size.(x) == size.(y) && eltype(x) === eltype(y)
           out = ntuple(i->similar(x[i]), length(x))
           @inbounds for i ∈ eachindex(x, y), j ∈ eachindex(x[i], y[i])
               out[i][j] = x[i][j] + y[i][j]
           end
           return out
       end
loop3 (generic function with 1 method)

julia> @btime loop3($x,$y);
  98.201 ns (4 allocations: 336 bytes)
```
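One possible variation, sketched under assumptions: the Julia documentation for `@assert` warns that an assert might be disabled at some optimization levels, so a check that guards `@inbounds` can throw explicitly instead. `check_compatible` and `loop3_checked` are hypothetical names, not from the thread, and no timings are claimed for this version.

```julia
# Hypothetical helper: throw DimensionMismatch unless both tuples have the
# same length and each pair of arrays has the same axes.
function check_compatible(x::Tuple, y::Tuple)
    length(x) == length(y) ||
        throw(DimensionMismatch("tuple lengths differ: $(length(x)) vs $(length(y))"))
    for k in eachindex(x)
        axes(x[k]) == axes(y[k]) ||
            throw(DimensionMismatch("entry $k has axes $(axes(x[k])) vs $(axes(y[k]))"))
    end
    return nothing
end

# Variant of loop3 that validates with explicit throws instead of @assert.
function loop3_checked(x::Tuple, y::Tuple)
    check_compatible(x, y)
    out = map(similar, x)        # one output array per tuple entry
    for k in eachindex(x)
        xk, yk, ok = x[k], y[k], out[k]
        # Safe: check_compatible verified the axes, and ok = similar(xk).
        @inbounds for j in eachindex(ok, xk, yk)
            ok[j] = xk[j] + yk[j]
        end
    end
    return out
end
```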
If I understand your `loop2` (and `loop3`) examples, you loop first over the elements of the tuple, then over the elements of each array? If so, I think I made the MWE a little too minimal: my actual code calls a function whose inputs are tuples, which constrains the data-access pattern.
A less minimal working example is:
```julia
x = ntuple(x->randn(2,4),2)
y = ntuple(x->randn(2,4),2)

function foo(x,y)
    f1 = @. x[1]*y[1]
    f2 = @. x[1]*y[2]
    g1 = @. x[2]*y[1]
    g2 = @. x[2]*y[2]
    return (f1,f2),(g1,g2)
end

function bcast(x,y)
    f,g = foo(x,y)
    fsum(f,g) = f + g
    out = fsum.(f,g)
    return out
end

function loop(x,y)
    out = ntuple(i->similar(x[i]),length(x))
    for i = 1:length(x[1])
        xi = (x->x[i]).(x)
        yi = (x->x[i]).(y)
        f,g = foo(xi,yi)
        fsum!(f,g,out) = out[i] = f + g
        fsum!.(f,g,out)
    end
    return out
end
```
The benchmark times are about the same as before: `loop` is still much slower than `bcast` and allocates much more.
My goal is to use the `loop` version so that I can remove the broadcasting from `foo(x,y)`, but I think this forces me into a weird data-access pattern.
If the broadcast avoids the weird data-access pattern, you should be able to avoid it with a loop as well: you can reproduce broadcasting by expanding every dotted expression into its own loop. I think you’re overthinking the translation from broadcast to loops. In particular, `(x -> x[i]).(x)` is very suspicious: it builds a fresh tuple of scalars on every iteration of the loop over array indices, which is exactly the transposed access order.
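Here is a sketch of that expansion applied to `foo`: each `@.` line becomes one loop over array indices, while the tuple structure stays outside the loops. `mul_loop` and `foo_loops` are hypothetical names, and unlike broadcasting, `mul_loop` assumes both arrays have the same shape (it does not expand singleton dimensions).

```julia
# foo as posted in the thread (broadcasting version), kept for comparison.
function foo(x, y)
    f1 = @. x[1] * y[1]
    f2 = @. x[1] * y[2]
    g1 = @. x[2] * y[1]
    g2 = @. x[2] * y[2]
    return (f1, f2), (g1, g2)
end

# Hypothetical helper: the loop expansion of `@. a * b` for same-shaped arrays.
function mul_loop(a::AbstractArray, b::AbstractArray)
    c = similar(a, promote_type(eltype(a), eltype(b)))
    for j in eachindex(c, a, b)   # throws DimensionMismatch if shapes differ
        c[j] = a[j] * b[j]
    end
    return c
end

# foo with every dotted expression expanded into its own loop; each array is
# traversed in order, one array at a time.
function foo_loops(x, y)
    f1 = mul_loop(x[1], y[1])
    f2 = mul_loop(x[1], y[2])
    g1 = mul_loop(x[2], y[1])
    g2 = mul_loop(x[2], y[2])
    return (f1, f2), (g1, g2)
end
```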
```julia
julia> function loop4(x, y)
           @assert size.(x) == size.(y) && eltype(x) === eltype(y)
           out = ntuple(i->similar(x[i]), length(x))
           f, g = foo(x, y)
           @inbounds for i ∈ eachindex(f, g)
               for j ∈ eachindex(out[i], f[i], g[i])
                   out[i][j] = f[i][j] + g[i][j]
               end
           end
           out
       end
loop4 (generic function with 1 method)

julia> @btime loop($x, $y);
  2.831 μs (52 allocations: 1.83 KiB)

julia> @btime loop4($x, $y);
  302.781 ns (11 allocations: 1008 bytes)

julia> @btime bcast($x, $y);
  343.137 ns (10 allocations: 992 bytes)
```
However, I would suggest just using broadcast if you can. It’s more flexible, maintainable, and general than manually written loops, without much overhead. Why do you want to remove the broadcast? You mentioned that ForwardDiff isn’t playing well with broadcast. Frankly, I find that a little hard to believe, or at least to understand.
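As one illustration of broadcast’s flexibility, the summation step can also be written in place: `.=` fuses with the dotted expression on its right and writes the results into preallocated arrays. `bcast_sum!` is a hypothetical name, not from the thread, and this is only a sketch.

```julia
# Hypothetical in-place counterpart of `bcast`: for each tuple entry, the fused
# broadcast `.=` writes f[k] + g[k] directly into the existing array out[k].
function bcast_sum!(out::Tuple, f::Tuple, g::Tuple)
    length(out) == length(f) == length(g) ||
        throw(DimensionMismatch("tuple lengths differ"))
    for k in eachindex(out)
        out[k] .= f[k] .+ g[k]
    end
    return out
end

f = ntuple(_ -> randn(2, 4), 2)
g = ntuple(_ -> randn(2, 4), 2)
out = map(similar, f)      # preallocate once, reuse across calls
bcast_sum!(out, f, g)
```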
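You’re right, this was my mistake. I was seeing errors with `ForwardDiff.derivative(x-> @. x^2,x)`, which I was attributing to the broadcasting in `bcast`. Rewriting it as `ForwardDiff.derivative(x-> (@. x^2),x)` fixes this: without the parentheses, `@.` presumably swallows the trailing `,x` as part of its own argument, so `derivative` receives only one argument.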
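To see why the parentheses matter without installing ForwardDiff, you can ask Julia’s parser how it reads both spellings. In this sketch, `f` is a placeholder standing in for `ForwardDiff.derivative`, and `nargs` is a hypothetical helper. The unparenthesized form should parse as a call with a single argument, because a space-separated macro call takes the whole comma-separated expression after it.

```julia
# Parse both spellings without evaluating them; `f` is a placeholder callee.
unparenthesized = Meta.parse("f(x -> @. x^2, x)")
parenthesized   = Meta.parse("f(x -> (@. x^2), x)")

# For a :call expression, args[1] is the callee and the rest are its arguments.
nargs(ex::Expr) = length(ex.args) - 1

println("without parentheses: ", nargs(unparenthesized), " argument(s)")
println("with parentheses:    ", nargs(parenthesized), " argument(s)")
```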
Thanks! I agree, I much prefer the `bcast` code, but I had convinced myself that I needed to change it because of that simple error.