Flux: implementing efficient RBF layer

Hi,
I am new to Flux and could use some feedback on how to optimize some code.
Using the source code for the Dense layer as a reference, I made a radial basis function (RBF) layer. I am able to train networks using the new RBF layer, but it is significantly slower than the corresponding Dense layer. In the example included below, the forward pass takes around 75 times as long for the RBF network as for the MLP, and the backward pass takes around 50 times as long.

I was expecting the RBF layer to be a bit slower, but not this much slower, so I hope there is a way to speed it up. I have included a minimal working example below.

using Flux
using Flux.Losses: logitcrossentropy
using Flux: onehotbatch
using BenchmarkTools
# ---Radial basis function layer---
"""Radial basis function layer"""
struct rbf{S<:AbstractArray, T<:AbstractArray}
    V::S
    β::T
end

function rbf(in::Integer, out::Integer;
             initV = Flux.glorot_uniform, β0 = 1.0f0)
    V = initV(in, out)
    β = β0*ones(Float32, out)
    return rbf(V, β)
end

Flux.@functor rbf
function (a::rbf)(x::AbstractArray)

    batchsize = size(x, 2)
    numIn = size(x, 1)
    numOut = size(a.V, 2) # number of units in the RBF layer

    #= a.V and x are matrices, with the same number of rows, but different numbers of columns.
    Each column of x represents a different datapoint, and each column of V is a template/centroid.
    For each datapoint we compute the squared Euclidean distance to each of the columns of V. =#
    d = a.V .- reshape(x, (numIn, 1, batchsize))
    d = sum(abs2, d, dims=1)
    # Here size(d) = (1, numOut, batchsize), so the singleton dimension is dropped
    d = reshape(d, (numOut, batchsize))
    return exp.(-a.β .* d)
end


# Generate a batch of 128 dummy datapoints
x = rand(Float32, 784, 128)
y = onehotbatch(rand(0:9, 128), 0:9)

#Initialize network
rbfNet = Chain(rbf(784, 100), Dense(100, 10))
mlpNet = Chain(Dense(784, 100), Dense(100, 10))

println("Timing of forward pass in RBF network")
@btime rbfNet(x)
println("Timing of forward pass in MLP network")
@btime mlpNet(x)

θrbf = Flux.params(rbfNet)
println("Timing of backward pass in RBF network")
@btime ∇θrbf = gradient(() -> logitcrossentropy(rbfNet(x), y), θrbf) # compute gradient

θmlp = Flux.params(mlpNet)
println("Timing of backward pass in MLP network")
@btime ∇θmlp = gradient(() -> logitcrossentropy(mlpNet(x), y), θmlp) # compute gradient

Running this gives me the following output (training on CPU):

Timing of forward pass in RBF network
  9.407 ms (13 allocations: 38.39 MiB)
Timing of forward pass in MLP network
  128.350 μs (6 allocations: 110.41 KiB)
Timing of backward pass in RBF network
  22.873 ms (39184 allocations: 78.88 MiB)
Timing of backward pass in MLP network
  464.771 μs (669 allocations: 926.17 KiB)

You can try converting y to Float32.

Thanks for the suggestion dhairyagandhi96.
y is a matrix of one-hot columns (one per datapoint). Each matrix element is of type Bool.
Converting from Bool to Float32 did not affect performance.

The biggest bottleneck appears to be this line:

d = a.V.-reshape(x, (numIn, 1, batchsize))

which allocates a 784×100×128 array only to sum over it immediately afterwards. Avoiding this materialization requires some way to “push down” the sum, e.g.:

using Tullio

...

function (a::rbf)(x)
    # a.V and x are matrices, with the same number of rows, but different numbers of columns.
    # Each column of x represents a different datapoint, and each column of V is a template/centroid.
    # For each datapoint we compute the squared Euclidean distance to each of the columns of V.
    V, β = a.V, a.β
    @tullio d[num_out, batch_size] := abs2(V[num_in, num_out] - x[num_in, batch_size])
    return exp.(-β .* d)
end

This cuts the backward pass time to about a third of the original:

Timing of forward pass in RBF network
  442.935 μs (57 allocations: 113.27 KiB)
Timing of forward pass in MLP network
  103.125 μs (6 allocations: 110.41 KiB)
Timing of backward pass in RBF network
  6.539 ms (39270 allocations: 1.89 MiB)
Timing of backward pass in MLP network
  438.145 μs (726 allocations: 924.75 KiB)

There may be a way to cut this down further by not materializing d as well, but my Tullio-fu is not good enough to do it :)

If you’re still looking for more performance or control, another avenue would be to write a custom kernel function and adjoint. This would allow you to use loops, in-place accumulation and all manner of tricks, but is obviously more complex.
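For anyone curious, here is a rough sketch of what such a custom kernel and adjoint could look like, using `ChainRulesCore.rrule`. The `sqdist` name and the loop structure are my own inventions, not from this thread; treat it as a starting point, not a tuned implementation:

```julia
using ChainRulesCore

# Hypothetical `sqdist` kernel: pairwise squared Euclidean distances between
# the columns of V (templates) and the columns of x (datapoints).
function sqdist(V::AbstractMatrix, x::AbstractMatrix)
    numIn, numOut = size(V)
    batchsize = size(x, 2)
    D = zeros(eltype(x), numOut, batchsize)
    @inbounds for b in 1:batchsize, j in 1:numOut
        s = zero(eltype(x))
        for i in 1:numIn
            s += abs2(V[i, j] - x[i, b])
        end
        D[j, b] = s
    end
    return D
end

# Hand-written adjoint: dD[j,b]/dV[i,j] = 2(V[i,j] - x[i,b]),
# and the gradient w.r.t. x picks up the opposite sign.
function ChainRulesCore.rrule(::typeof(sqdist), V, x)
    D = sqdist(V, x)
    function sqdist_pullback(Δ)
        dV, dx = zero(V), zero(x)
        @inbounds for b in axes(x, 2), j in axes(V, 2)
            for i in axes(V, 1)
                t = 2 * Δ[j, b] * (V[i, j] - x[i, b])
                dV[i, j] += t
                dx[i, b] -= t
            end
        end
        return NoTangent(), dV, dx
    end
    return D, sqdist_pullback
end
```

The layer's forward pass would then call `exp.(-β .* sqdist(V, x))`, and a ChainRules-aware AD such as recent Zygote versions should pick up the custom rule automatically. Whether this beats Tullio depends on threading and SIMD, so benchmark before committing to it.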


Thanks ToucheSir! That is a huge improvement!
Using your fix on a 16-core CPU, the forward pass is actually a bit faster for the RBF network in Julia 1.5.2, and the gradient computation only takes about 5x longer for the RBF network than for the MLP (as opposed to 50x longer in my original implementation). Tullio seems to really take advantage of all cores! I will definitely look more into this package.

Here is the output I get when using your fix in Julia 1.5.2:

Timing of forward pass in RBF network
  125.876 μs (250 allocations: 126.89 KiB)
Timing of forward pass in MLP network
  142.898 μs (6 allocations: 110.41 KiB)
Timing of backward pass in RBF network
  2.630 ms (39632 allocations: 1.91 MiB)
Timing of backward pass in MLP network
  516.106 μs (654 allocations: 918.89 KiB)

Curiously, in Julia 1.6.0 the RBF network performs a bit worse and the MLP a bit better:

Timing of forward pass in RBF network
  202.048 μs (249 allocations: 122.64 KiB)
Timing of forward pass in MLP network
  129.502 μs (6 allocations: 110.41 KiB)
Timing of backward pass in RBF network
  3.297 ms (39659 allocations: 1.91 MiB)
Timing of backward pass in MLP network
  471.975 μs (669 allocations: 926.17 KiB)

I am not sure how Tullio achieves parallelization, but perhaps the difference in performance is related to this issue: I/O (with parallelism) slower in Julia 1.6.0 and 1.7 · Issue #39598 · JuliaLang/julia · GitHub.


I found an even faster and simpler approach in this paper: [1812.01214] Prototype-based Neural Network Layers: Incorporating Vector Quantization
The solution is annoyingly simple :)

For a datapoint x and a single template v_k, the squared Euclidean distance expands as:
d = ||x - v_k||^2 = -2v_k^T x + ||x||^2 + ||v_k||^2 = -2v_k^T x + b(x, v_k)

This is analogous to a dense layer, except that the bias b(x, v_k) = ||x||^2 + ||v_k||^2 is not a free parameter but a function of the template v_k and the input x.
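The expansion is easy to sanity-check numerically (a quick sketch I added, with made-up random vectors):

```julia
using LinearAlgebra

# Check ||x - v||^2 == ||x||^2 - 2*v'x + ||v||^2 on random vectors.
x = rand(Float32, 784)
v = rand(Float32, 784)

direct   = sum(abs2, x .- v)
expanded = sum(abs2, x) - 2 * dot(v, x) + sum(abs2, v)

@assert isapprox(direct, expanded; rtol=1f-4)
```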
In Flux this can be implemented (for an entire batch of datapoints and an entire layer of prototype units) as:

struct rbf{S<:AbstractArray, T<:AbstractArray}
    V::S
    β::T
end

function rbf(in::Integer, out::Integer;
             initV = Flux.glorot_uniform, β0 = 1.0f0)
    V = initV(out, in)
    β = β0*ones(Float32, out)
    return rbf(V, β)
end

Flux.@functor rbf
function (a::rbf)(x::AbstractArray)
    V, β = a.V, a.β  # V is now (numOut, numIn): one template per row

    x2 = sum(abs2, x, dims=1)  # (1, batchsize): squared norms of the datapoints
    V2 = sum(abs2, V, dims=2)  # (numOut, 1): squared norms of the templates
    d = -2*V*x .+ V2 .+ x2     # (numOut, batchsize): squared distances

    return exp.(-β .* d)
end
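To double-check the batched version (again a sketch of mine, not part of the original post), the dense-layer form can be compared against direct distance computations on random data:

```julia
# Check that the "dense-layer" form matches direct squared distances.
# Shapes follow the post above: V is (numOut, numIn), x is (numIn, batchsize).
V = rand(Float32, 100, 784)
x = rand(Float32, 784, 128)

x2 = sum(abs2, x, dims=1)            # (1, 128)
V2 = sum(abs2, V, dims=2)            # (100, 1)
d_fast = -2 .* (V * x) .+ V2 .+ x2   # (100, 128)

d_direct = [sum(abs2, view(V, j, :) .- view(x, :, b))
            for j in axes(V, 1), b in axes(x, 2)]
@assert isapprox(d_fast, d_direct; rtol=1f-3)
```

Note the looser tolerance: the expanded form subtracts large, nearly equal quantities, so it loses a little Float32 precision relative to the direct computation.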

The forward pass is a little slower than the Tullio solution, but the backward pass is significantly faster:

Timing of forward pass in RBF network
  249.558 μs (17 allocations: 468.47 KiB)
Timing of forward pass in MLP network
  142.527 μs (6 allocations: 110.41 KiB)
Timing of backward pass in RBF network
  1.281 ms (39183 allocations: 3.94 MiB)
Timing of backward pass in MLP network
  476.413 μs (656 allocations: 918.95 KiB)