Flux: Locally connected layer performance

I have tried to implement a 2D locally connected layer (convolution with non-shared weights) in Flux restricted to filter size (1,1) only. That is, the linear combinations/convolutions are only over the (3rd) channel/variable dimension. Since this type of layer involves many weight parameters, an efficient implementation with respect to speed and memory management is essential. Is there any scope for improvement of the following code, in particular for GPU?

function (a::Local2D)(x::AbstractArray)  # size(x) = (w,h,cin,n), size(W) = (w,h,cin,cout)
    W, b, σ = a.W, a.b, a.σ
    out = cat([sum(W[:,:,:,i] .* x, dims=3) .+ b[:,:,:,i] for i in 1:size(W,4)]..., dims=3)
    return σ.(out)
end

The complete code including a test example:

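using Flux

struct Local2D{S<:AbstractArray, T<:AbstractArray, F}
    W::S   # weights, size (w,h,cin,cout)
    b::T   # bias, size (w,h,1,cout)
    σ::F   # activation function
end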

Local2D(W, b) = Local2D(W, b, identity)

function Local2D(w::Integer, h::Integer, cin::Integer, cout::Integer, σ = identity;
                 initW = Flux.glorot_uniform, initb = zeros)
    return Local2D(initW(w,h,cin,cout), initb(Float32, w,h,1,cout), σ)
end

Flux.@functor Local2D

function (a::Local2D)(x::AbstractArray)  # size(x) = (w,h,cin,n), size(W) = (w,h,cin,cout)
    W, b, σ = a.W, a.b, a.σ
    out = cat([sum(W[:,:,:,i] .* x, dims=3) .+ b[:,:,:,i] for i in 1:size(W,4)]..., dims=3)
    return σ.(out)
end

#  example
device = gpu
n = 1024; d = 256; cin = 4; cout = 2
x = rand(Float32, d, d, cin, n) |> device;
y = rand(Float32, d, d, cout, n) |> device;
trdata = Flux.Data.DataLoader((x, y), batchsize=32) |> device;
model  = Local2D(d, d, cin, cout, identity) |> device;
loss(x, y) = Flux.mse(model(x), y)
@time Flux.train!(loss, Flux.params(model), trdata, ADAM())

It should be possible to do this without making slices, and this is likely to be quicker. A few ways are:

julia> x = rand(256,256,4,32); W = rand(256,256,4,2);

julia> using TensorCast, OMEinsum, Tullio, LoopVectorization

julia> @reduce mid_1[i,j,k,n] := sum(c) W[i,j,c,k] * x[i,j,c,n];

julia> @ein mid_2[i,j,k,n] := W[i,j,c,k] * x[i,j,c,n];

julia> @tullio mid_3[i,j,k,n] := W[i,j,c,k] * x[i,j,c,n];

julia> size(mid_1)
(256, 256, 2, 32)

julia> mid_1 ≈ mid_2 ≈ mid_3 ≈ cat([sum(W[:,:,:,i] .* x, dims=3) for i in 1:size(W,4)]..., dims=3)

julia> out = σ.(mid_1 .+ b) ^C

These do different things. @reduce is just broadcasting: it first uses reshape(x, size(x)[1:3]...,1,:) to move the 4th dimension of x into 5th place. @ein will (I think) call batched matrix multiplication here, after using permutedims to line things up. (You could write either of those out by hand, too.) @tullio generates loops directly.
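For illustration, here is a rough sketch of what writing the first two out by hand might look like, using only Base and LinearAlgebra. The function names local_broadcast and local_batched are made up for this example, and the per-pixel mul! loop is only a stand-in for a real batched routine (for instance NNlib's batched_mul, which I assume is what you would use in practice); it is not what OMEinsum actually generates.

```julia
using LinearAlgebra: mul!

# Broadcasting by hand: move the batch dimension of x to 5th place,
# multiply against W, then sum over the channel dimension (3).
function local_broadcast(W, x)
    x5 = reshape(x, size(x)[1:3]..., 1, :)          # (w, h, cin, 1, n)
    return dropdims(sum(W .* x5, dims=3), dims=3)   # (w, h, cout, n)
end

# Batched matrix multiplication by hand: one (cout x cin) * (cin x n)
# product per pixel, after permutedims puts the pixel index last.
function local_batched(W, x)
    w, h, cin, cout = size(W)
    n = size(x, 4)
    Wp = reshape(permutedims(W, (4, 3, 1, 2)), cout, cin, w * h)
    xp = reshape(permutedims(x, (3, 4, 1, 2)), cin, n, w * h)
    yp = similar(x, cout, n, w * h)
    for p in 1:w*h   # stand-in for a single batched-multiplication call
        @views mul!(yp[:, :, p], Wp[:, :, p], xp[:, :, p])
    end
    return permutedims(reshape(yp, cout, n, w, h), (3, 4, 1, 2))  # (w, h, cout, n)
end

# Small arrays for a quick check against the sliced version from the question
W = rand(Float32, 8, 8, 4, 2)   # (w, h, cin, cout)
x = rand(Float32, 8, 8, 4, 3)   # (w, h, cin, n)
ref = cat([sum(W[:, :, :, i] .* x, dims=3) for i in 1:size(W, 4)]..., dims=3)

local_broadcast(W, x) ≈ ref && local_batched(W, x) ≈ ref
```

Note that the broadcasting version materialises the full (w, h, cin, cout, n) product before summing, which is where most of its memory goes.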

julia> @pretty @reduce mid_1[i,j,k,n] := sum(c) W[i,j,c,k] * x[i,j,c,n];
    local lyrebird = orient(x, (:, :, :, *, :))
    mid_1 = dropdims(sum(@__dot__(W * lyrebird), dims = 3), dims = 3)

julia> @btime @reduce mid_1[i,j,k,n] := sum(c) $W[i,j,c,k] * $x[i,j,c,n];
  70.918 ms (42 allocations: 160.00 MiB)

julia> @btime @ein mid_2[i,j,k,n] := $W[i,j,c,k] * $x[i,j,c,n];
  36.443 ms (324 allocations: 132.03 MiB)

julia> @btime @tullio mid_3[i,j,k,n] := $W[i,j,c,k] * $x[i,j,c,n];
  10.172 ms (76 allocations: 32.01 MiB)

julia> @btime cat([sum($W[:,:,:,i] .* $x, dims=3) for i in 1:size($W,4)]..., dims=3);
  56.664 ms (79 allocations: 196.00 MiB)

How fast these are will depend a lot on the sizes of the arrays, especially the number of channels. Times on the GPU may also look very different, and computing the gradient will slow each of them down to a varying degree.
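To connect this back to the layer in the question, a slice-free forward pass could be sketched roughly as follows. This assumes the bias is stored with size (w,h,1,cout) as in the constructor in the question, so its singleton dimension has to be dropped before it can be broadcast against the (w,h,cout,n) result. The names local2d_forward, local2d_sliced and myrelu are hypothetical, and the broadcasting variant is used only because it needs no extra packages, not because it is the fastest.

```julia
# Slice-free forward pass.
# Assumed sizes: W (w,h,cin,cout), b (w,h,1,cout), x (w,h,cin,n).
function local2d_forward(W, b, σ, x)
    x5  = reshape(x, size(x)[1:3]..., 1, :)          # (w, h, cin, 1, n)
    mid = dropdims(sum(W .* x5, dims=3), dims=3)     # (w, h, cout, n)
    b3  = dropdims(b, dims=3)                        # (w, h, cout)
    return σ.(mid .+ b3)
end

# The sliced forward pass from the question, for comparison
function local2d_sliced(W, b, σ, x)
    out = cat([sum(W[:, :, :, i] .* x, dims=3) .+ b[:, :, :, i]
               for i in 1:size(W, 4)]..., dims=3)
    return σ.(out)
end

# Hypothetical stand-in activation, so the example needs no packages
myrelu(z) = max(z, zero(z))

W = randn(Float32, 8, 8, 4, 2)
b = randn(Float32, 8, 8, 1, 2)
x = randn(Float32, 8, 8, 4, 5)

local2d_forward(W, b, myrelu, x) ≈ local2d_sliced(W, b, myrelu, x)
```

Swapping the two lines that compute mid for one of the macro calls above should leave the rest unchanged, which makes it easy to time each variant inside the layer.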


Thanks @mcabbott! A very brief test on my example indicates that @tullio seems most promising for CPU and @ein for GPU. More extensive tests are needed though.