Hi,
I am new to Flux and could use some feedback on how to optimize some code.
Using the source code for the Dense layer as a reference, I made a radial basis function (RBF) layer. I am able to train networks with the new RBF layer, but it is significantly slower than the corresponding Dense layer. In the example below, the forward pass takes around 75 times as long for the RBF network as for the MLP, and the backward pass takes around 50 times longer.
I was expecting the RBF layer to be somewhat slower, but not by this much, so I hope there is a way to speed it up. Here is a minimal working example:
using Flux
using Flux.Losses: logitcrossentropy
using Flux: onehotbatch
using BenchmarkTools
# ---Radial basis function layer---
"""Radial basis function layer"""
struct rbf{S<:AbstractArray, T<:AbstractArray}
    V::S
    β::T
end

function rbf(in::Integer, out::Integer;
             initV = Flux.glorot_uniform, β0 = 1.0f0)
    V = initV(in, out)
    β = β0 * ones(Float32, out)
    return rbf(V, β)
end
Flux.@functor rbf  # register V and β as trainable parameters
function (a::rbf)(x::AbstractArray)
    batchsize = size(x, 2)
    numIn = size(x, 1)
    numOut = size(a.V, 2)  # number of units in the RBF layer
    #= a.V and x are matrices with the same number of rows but different numbers of columns.
    Each column of x represents a different datapoint, and each column of V is a template/centroid.
    For each datapoint we compute the squared Euclidean distance to each of the columns of V. =#
    d = a.V .- reshape(x, (numIn, 1, batchsize))
    d = sum(abs2, d, dims=1)
    # Here size(d) == (1, numOut, batchsize), so next the singleton dimension is dropped
    d = reshape(d, (numOut, batchsize))
    return exp.(-a.β .* d)
end
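One idea I had, but have not benchmarked, is to avoid the (numIn, numOut, batchsize) intermediate array created by the broadcast by expanding the squared distance as ‖v − x‖² = ‖v‖² − 2vᵀx + ‖x‖², so that most of the work becomes a single matrix multiplication. A rough sketch of what I mean (the name rbf_alt and the exact broadcasts are just my guess, and I have not checked whether it is actually faster):
# Alternative forward pass (sketch): same squared distances via
# ‖v − x‖² = ‖v‖² − 2vᵀx + ‖x‖², without the 3D intermediate array
function rbf_alt(a::rbf, x::AbstractMatrix)
    sqV = sum(abs2, a.V, dims=1)        # (1, numOut): squared norms of the centroids
    sqx = sum(abs2, x, dims=1)          # (1, batchsize): squared norms of the datapoints
    d = sqV' .- 2 .* (a.V' * x) .+ sqx  # (numOut, batchsize) matrix of squared distances
    return exp.(-a.β .* d)
end
Before trusting something like this I would of course check that it returns (approximately) the same output as the broadcast version above.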
# Generate a batch of 128 dummy datapoints
x = rand(Float32, 784, 128)
y = onehotbatch(rand(0:9, 128), 0:9)
# Initialize the two networks
rbfNet = Chain(rbf(784, 100), Dense(100, 10))
mlpNet = Chain(Dense(784, 100), Dense(100, 10))
println("Timing of forward pass in RBF network")
@btime rbfNet(x)
println("Timing of forward pass in MLP network")
@btime mlpNet(x)
θrbf = Flux.params(rbfNet)
println("Timing of backward pass in RBF network")
@btime ∇θrbf = gradient(() -> logitcrossentropy(rbfNet(x), y), θrbf) # compute gradient
θmlp = Flux.params(mlpNet)
println("Timing of backward pass in MLP network")
@btime ∇θmlp = gradient(() -> logitcrossentropy(mlpNet(x), y), θmlp) # compute gradient
Running this gives me the following output (everything runs on the CPU):
Timing of forward pass in RBF network
9.407 ms (13 allocations: 38.39 MiB)
Timing of forward pass in MLP network
128.350 μs (6 allocations: 110.41 KiB)
Timing of backward pass in RBF network
22.873 ms (39184 allocations: 78.88 MiB)
Timing of backward pass in MLP network
464.771 μs (669 allocations: 926.17 KiB)
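In case it is relevant, training itself works fine with the implicit-parameter API, roughly like this (just a sketch; the optimiser and learning rate are arbitrary placeholders):
loss(x, y) = logitcrossentropy(rbfNet(x), y)
opt = Descent(0.01)                     # placeholder optimiser and learning rate
Flux.train!(loss, θrbf, [(x, y)], opt)  # one update step on the dummy batch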