Indeed, by caching the convolution ids in the layer struct I got a very nice speed-up that makes the layer usable for small jobs.
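The core of the trick, in a minimal sketch (`convolve_cached` is just an illustrative name; the `x_ids`/`w_ids`/`y_ids` fields are the ones cached in the layer struct of the script below):
# Once the (input, weight, output) index triplets contributing to the convolution
# have been precomputed and cached in the layer, the whole convolution collapses
# to a single loop over them:
function convolve_cached(layer, x, y)
for idx in 1:length(layer.y_ids)
y[layer.y_ids[idx]...] += x[layer.x_ids[idx]...] * layer.weight[layer.w_ids[idx]...]
end
return y
end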
Here is a self-standing script:
"Part of [BetaML](https://github.com/sylvaticus/BetaML.jl). Licence is MIT."
# Experimental
abstract type AbstractLayer end
using Random, LinearAlgebra, StaticArrays, LoopVectorization, Distributions
import Base.size
"""
ConvLayer
Representation of a convolutional layer in the network
"""
mutable struct ConvLayer{ND,NDPLUS1,NDPLUS2} <: AbstractLayer
"Input size (including nchannel_in as last dimension)"
input_size::SVector{NDPLUS1,Int64}
"Weight tensor (aka \"filter\" or \"kernel\") with respect to the input from previous layer or data (kernel_size array augmented by the nchannels_in and nchannels_out dimensions)"
weight::Array{Float64,NDPLUS2}
"Wether to use (and learn) a bias weigth [def: true]"
usebias::Bool
"Bias (nchannels_out array)"
bias::Array{Float64,1}
"Padding (initial)"
padding_start::SVector{ND,Int64}
"Padding (ending)"
padding_end::SVector{ND,Int64}
"Stride"
stride::SVector{ND,Int64}
"Number of dimensions (excluding input and output channels)"
ndims::Int64
"Activation function"
f::Function
"Derivative of the activation function"
df::Union{Function,Nothing}
"x ids of the convolution (computed in `preprocessing`` - itself at the beginning of `train`"
x_ids::Array{NTuple{NDPLUS1,Int32},1}
"y ids of the convolution (computed in `preprocessing`` - itself at the beginning of `train`"
y_ids::Array{NTuple{NDPLUS1,Int32},1}
"w ids of the convolution (computed in `preprocessing`` - itself at the beginning of `train`"
w_ids::Array{NTuple{NDPLUS2,Int32},1}
"""
ConvLayer()
Instantiate a new nD-dimensional, possibly multichannel, convolutional layer
The input data is either a column vector (in which case it is reshaped) or an array of `input_size` augmented by the `n_channels` dimension; the output size depends on `input_size`, `kernel_size`, `padding` and `stride`, but always has `nchannels_out` as its last dimension.
# Positional arguments:
* `input_size`: Shape of the input layer (integer for 1D convolution, tuple otherwise). Do not consider the channels number here.
* `kernel_size`: Size of the kernel (aka filter or learnable weights) (integer for 1D or hypercube kernels, or an nD-sized tuple for asymmetric kernels). Do not consider the channels number here.
* `nchannels_in`: Number of channels in input
* `nchannels_out`: Number of channels in output
# Keyword arguments:
* `kernel_init`: Initial weights with respect to the input [default: Xavier initialisation]. If given, it should be a multidimensional array of `kernel_size` augmented by the `nchannels_in` and `nchannels_out` dimensions
* `bias_init`: Initial weights with respect to the bias [default: Xavier initialisation]. If given, it should be a `nchannels_out` vector of scalars.
* `f`: Activation function [def: `relu`]
* `df`: Derivative of the activation function [default: `nothing` (i.e. use AD)]
* `rng`: Random Number Generator (see [`FIXEDSEED`](@ref)) [default: `Random.GLOBAL_RNG`]
# Notes:
- Xavier initialization is sampled from a `Uniform` distribution between `± sqrt(6/(prod(input_size)*nchannels_in))`
- to retrieve the output size of the layer, use `size(ConvLayer)[2]`. The output size on each dimension _d_ (except the last one, which is given by `nchannels_out`) is given by the following formula (floored): `output_size[d] = 1 + (input_size[d]+padding_start[d]+padding_end[d]-kernel_size[d])/stride[d]`
"""
function ConvLayer(input_size,kernel_size,nchannels_in,nchannels_out;
stride = (ones(Int64,length(input_size))...,),
rng = Random.GLOBAL_RNG,
padding = nothing, # zeros(Int64,length(input_size)),
kernel_init = rand(rng, Uniform(-sqrt(6/(prod(input_size)*nchannels_in)),sqrt(6/(prod(input_size)*nchannels_in))),(kernel_size...,nchannels_in,nchannels_out)...),
usebias = true,
bias_init = usebias ? rand(rng, Uniform(-sqrt(6/(prod(input_size)*nchannels_in)),sqrt(6/(prod(input_size)*nchannels_in))),nchannels_out) : zeros(Float64,nchannels_out),
f = identity,
df = nothing)
# be sure all are tuples of right dimension...
if typeof(input_size) <: Integer
input_size = (input_size,)
end
nD = length(input_size)
if typeof(kernel_size) <: Integer
kernel_size = ([kernel_size for d in 1:nD]...,)
end
length(input_size) == length(kernel_size) || error("The number of dimensions of the kernel must equal the number of dimensions of the input data")
if typeof(stride) <: Integer
stride = ([stride for d in 1:nD]...,)
end
if typeof(padding) <: Integer
padding_start = ([padding for d in 1:nD]...,)
padding_end = ([padding for d in 1:nD]...,)
elseif isnothing(padding) # compute padding to keep same size/stride if not provided
target_out_size = [Int(round(input_size[d]/stride[d])) for d in 1:length(input_size)]
padding_total = [(target_out_size[d]-1)*stride[d] - input_size[d]+kernel_size[d] for d in 1:length(input_size)]
padding_start = Int.(ceil.(padding_total ./ 2))
padding_end = padding_total .- padding_start
else
padding_start = padding[1]
padding_end = padding[2]
end
nD == length(stride) || error("`stride` must be either a scalar or a tuple whose length equals the number of dimensions of the input data")
nD == length(padding_start) == length(padding_end) || error("`padding` must be either (a) `nothing` for automatic computation, (b) a scalar for the same padding on all dimensions, or (c) a 2-element tuple of tuples, each of length equal to the number of dimensions of the input data, indicating the padding to add before and after the data")
new{nD,nD+1,nD+2}((input_size...,nchannels_in),kernel_init,usebias,bias_init,padding_start,padding_end,stride,nD,f,df,[],[],[])
end
end
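"""
preprocess!(layer)
Precompute and cache in the layer struct the index triplets of the convolution: for every output element, the ids of the contributing input entries (`x_ids`), weight entries (`w_ids`) and output entries (`y_ids`), so that `forward`, `backward` and `get_gradient` reduce to a single loop over these cached ids.
"""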
function preprocess!(layer::ConvLayer{ND,NDPLUS1,NDPLUS2}) where {ND,NDPLUS1,NDPLUS2}
if length(layer.x_ids) > 0
empty!(layer.x_ids)
empty!(layer.w_ids)
empty!(layer.y_ids)
end
input_size, output_size = size(layer)
nchannels_out = output_size[end]
nchannels_in = input_size[end]
convsize = input_size[1:end-1]
ndims_conv = ND
wsize = size(layer.weight)
ysize = output_size
# preallocating temp variables
w_idx = Array{Int32,1}(undef,NDPLUS2)
y_idx = Array{Int32,1}(undef,NDPLUS1)
w_idx_conv = Array{Int32,1}(undef,ND)
y_idx_conv = Array{Int32,1}(undef,ND)
idx_x_source_padded = Array{Int32,1}(undef,ND)
checkstart = Array{Bool,1}(undef,ND)
checkend = Array{Bool,1}(undef,ND)
x_idx = Array{Int32,1}(undef,NDPLUS1)
@inbounds for nch_in in 1:nchannels_in
@inbounds for nch_out in 1:nchannels_out
@inbounds for w_idx_conv in CartesianIndices(wsize[1:end-2])
w_idx_conv = Tuple(w_idx_conv)
w_idx = (w_idx_conv...,nch_in,nch_out)
@inbounds for y_idx_conv in CartesianIndices(ysize[1:end-1])
y_idx_conv = Tuple(y_idx_conv)
y_idx = (y_idx_conv...,nch_out)
check = true
@inbounds for d in 1:ndims_conv
idx_x_source_padded[d] = w_idx_conv[d] + (y_idx_conv[d] - 1 ) * layer.stride[d]
checkstart[d] = idx_x_source_padded[d] > layer.padding_start[d]
checkend[d] = idx_x_source_padded[d] <= layer.padding_start[d] .+ convsize[d]
checkstart[d] && checkend[d] || begin check = false; break; end
end
check || continue
@inbounds @simd for d in 1:ndims_conv
x_idx[d] = idx_x_source_padded[d] - layer.padding_start[d]
end
x_idx[ndims_conv+1] = nch_in
push!(layer.x_ids,((x_idx...,)))
push!(layer.w_ids,w_idx)
push!(layer.y_ids,y_idx)
end
end
end
end
end
function _xpadComp(layer::ConvLayer,x)
input_size, output_size = size(layer)
# input padding
padding_start = SVector{layer.ndims+1,Int64}([layer.padding_start...,0])
padding_end = SVector{layer.ndims+1,Int64}([layer.padding_end...,0])
padded_size = input_size .+ padding_start .+ padding_end
xstart = padding_start .+ 1
xends = padding_start .+ input_size
xpadded = zeros(eltype(x), padded_size...)
xpadded[[range(s,e,step=1) for (s,e) in zip(xstart,xends)]...] = x
return xpadded
end
function _zComp_old(layer::ConvLayer,x)
if ndims(x) == 1
x = reshape(x,size(layer)[1])
end
input_size, output_size = size(layer)
# input padding
xpadded = _xpadComp(layer,x)
y = zeros(output_size)
for yi in CartesianIndices(y)
yiarr = Tuple(yi)
nchannel_out = yiarr[end]
starti = yiarr[1:end-1] .* layer.stride .- layer.stride .+ 1
starti = vcat(starti,1)
endi = starti .+ convert(SVector{layer.ndims+1,Int64},size(layer.weight)[1:end-1]) .- 1
weight = selectdim(layer.weight,layer.ndims+2,nchannel_out)
xpadded_in = xpadded[[range(s,e,step=1) for (s,e) in zip(starti,endi)]...]
y_unbias = @turbo dot(weight, xpadded_in)
y[yi] = layer.bias[nchannel_out] .+ y_unbias
end
return y
end
function _zComp(layer::ConvLayer{ND,NDPLUS1,NDPLUS2},x) where {ND,NDPLUS1,NDPLUS2} # slower :-/
input_size, output_size = size(layer)
nchannels_out = output_size[end]
nchannels_in = input_size[end]
convsize = input_size[1:end-1]
ndims_conv = ND
if ndims(x) == 1
x = reshape(x,size(layer)[1])
end
y = zeros(output_size)
lx_ids = layer.x_ids
ly_ids = layer.y_ids
lw_ids = layer.w_ids
lweight = layer.weight
@simd for idx in 1:length(layer.y_ids)
y[ly_ids[idx]...] += x[lx_ids[idx]...] * lweight[lw_ids[idx]...]
end
for ch_out in 1:nchannels_out
y_ch_out = selectdim(y,NDPLUS1,ch_out)
y_ch_out .+= layer.bias[ch_out]
end
return y
end
"""
forward()
Compute forward pass of a ConvLayer
"""
function forward_old(layer::ConvLayer,x)
z = _zComp_old(layer,x)
return layer.f.(z)
end
function forward(layer::ConvLayer,x)
z = _zComp(layer,x)
return layer.f.(z)
end
function backward_old(layer::ConvLayer,x,next_gradient) # with respect to inputs: derror/dx
z = _zComp_old(layer,x)
if layer.df != nothing
dfz = layer.df.(z)
else
dfz = layer.f'.(z) # using AD
end
dϵ_dz = @turbo dfz .* next_gradient
nchannels_out = size(layer.weight)[end]
nchannels_in = size(layer.weight)[end-1]
input_size, output_size = size(layer)
de_dx = zeros(layer.input_size...)
@inbounds for nch_in in 1:nchannels_in
w_ch_in = selectdim(layer.weight,layer.ndims+1,nch_in)
@inbounds for nch_out in 1:nchannels_out
w_ch_in_out = selectdim(w_ch_in,layer.ndims+1,nch_out)
dϵ_dz_ch_out = selectdim(dϵ_dz,layer.ndims+1,nch_out)
de_dx_ch_in = selectdim(de_dx,layer.ndims+1,nch_in)
@inbounds for w_idx in CartesianIndices(w_ch_in_out)
w_idx = Tuple(w_idx)
@inbounds for dey_idx in CartesianIndices(dϵ_dz_ch_out)
dey_idx = Tuple(dey_idx)
idx_x_source_padded = w_idx .+ (dey_idx .- 1 ) .* layer.stride
checkstart = idx_x_source_padded .> layer.padding_start
checkend = idx_x_source_padded .<= layer.padding_start .+ input_size[1:end-1]
if all(checkstart) && all(checkend)
idx_x_source = idx_x_source_padded .- layer.padding_start
de_dx_ch_in[idx_x_source...] += dϵ_dz_ch_out[dey_idx...] * w_ch_in_out[w_idx...]
end
end
end
end
end
return de_dx
end
function backward(layer::ConvLayer{ND,NDPLUS1,NDPLUS2},x, next_gradient) where {ND,NDPLUS1,NDPLUS2}
input_size, output_size = size(layer)
z = _zComp(layer,x)
if layer.df != nothing
dfz = layer.df.(z)
else
dfz = layer.f'.(z) # using AD
end
dϵ_dz = @turbo dfz .* next_gradient
de_dx = zeros(layer.input_size...)
lx_ids = layer.x_ids
ly_ids = layer.y_ids
lw_ids = layer.w_ids
lweight = layer.weight
@simd for idx in 1:length(layer.y_ids)
de_dx[lx_ids[idx]...] += dϵ_dz[ly_ids[idx]...] * lweight[lw_ids[idx]...]
end
return de_dx
end
function get_gradient_old(layer::ConvLayer,x,next_gradient) # derror/dw
z = _zComp_old(layer,x)
if layer.df != nothing
dfz = layer.df.(z)
else
dfz = layer.f'.(z) # using AD
end
dϵ_dz = @turbo dfz .* next_gradient
dw = zeros(size(layer.weight))
xpadded = _xpadComp(layer,x)
# dw computation
for idx in CartesianIndices(dw)
idx = Tuple(idx)
nchannel_out = idx[end]
nchannel_in = idx[end-1]
wdims = idx[1:end-2]
dϵ_dz_nchannelOut = selectdim(dϵ_dz,layer.ndims+1,nchannel_out)
xval = zeros(size(dϵ_dz_nchannelOut))
for yi in CartesianIndices(dϵ_dz_nchannelOut)
idx_y = Tuple(yi)
idx_x_dest = idx_y
idx_x_source = wdims .+ (idx_y .- 1 ) .* layer.stride # xpadded[i] = w[i] + (Y[i] -1 ) * STRIDE
xval[idx_x_dest...] = xpadded[vcat(idx_x_source,nchannel_in)...]
end
dw[idx...] = dot(xval,dϵ_dz_nchannelOut) # slightly more efficient than using += on each individual product
end
if layer.usebias
dbias = zeros(length(layer.bias))
for bias_idx in 1:length(layer.bias)
nchannel_out = bias_idx
dϵ_dz_nchannelOut = selectdim(dϵ_dz,layer.ndims+1,nchannel_out)
dbias[bias_idx] = sum(dϵ_dz_nchannelOut)
end
return (dw,dbias)
else
return (dw,)
end
end
function get_gradient(layer::ConvLayer{ND,NDPLUS1,NDPLUS2},x, next_gradient) where {ND,NDPLUS1,NDPLUS2}
z = _zComp(layer,x)
if layer.df != nothing
dfz = layer.df.(z)
else
dfz = layer.f'.(z) # using AD
end
dϵ_dz = @turbo dfz .* next_gradient
de_dw = zeros(size(layer.weight))
lx_ids = layer.x_ids
ly_ids = layer.y_ids
lw_ids = layer.w_ids
@simd for idx in 1:length(layer.y_ids)
de_dw[lw_ids[idx]...] += dϵ_dz[ly_ids[idx]...] * x[lx_ids[idx]...]
end
if layer.usebias
dbias = zeros(length(layer.bias))
for bias_idx in 1:length(layer.bias)
nchannel_out = bias_idx
dϵ_dz_nchannelOut = selectdim(dϵ_dz,layer.ndims+1,nchannel_out)
dbias[bias_idx] = sum(dϵ_dz_nchannelOut)
end
return (de_dw,dbias)
else
return (de_dw,)
end
end
"""
size()
Get the dimensions of the layer in terms of (dimensions in input, dimensions in output), including the channels as the last dimension
"""
function size(layer::ConvLayer)
nchannels_in = layer.input_size[end]
nchannels_out = size(layer.weight)[end]
in_size = (layer.input_size...,)
out_size = ([1 + Int(floor((layer.input_size[d]+layer.padding_start[d]+layer.padding_end[d]-size(layer.weight,d))/layer.stride[d])) for d in 1:layer.ndims]...,nchannels_out)
return (in_size,out_size)
end
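As a quick check of the output-size formula in the docstring above (a small sketch using only the definitions in the script; `lcheck` is just a throwaway name):
lcheck = ConvLayer((32,32),(4,4),3,5)
# the automatic padding gives padding_start = (2,2) and padding_end = (1,1), so on each
# convolution dimension: output_size[d] = 1 + floor((32+2+1-4)/1) = 32
size(lcheck) # -> ((32, 32, 3), (32, 32, 5))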
And then we can use it with:
using BenchmarkTools
# Multiple in_channel/ out_channel
x = reshape(1:(32*32*3),32,32,3)
l = ConvLayer((32,32),(4,4),3,5,f=identity,df=x->1) # -> from 32x32x3 to ?x?x5 using a 4x4 filter; df is provided to avoid a Zygote dependency in this script
preprocess!(l)
print("preproces btime: ")
@btime preprocess!($l)
y = forward_old(l,x)
print("forward_old btime: ")
@btime forward_old($l,$x)
y2 = forward(l,x)
y2 == y
print("forward btime: ")
@btime forward($l,$x)
de_dy = y ./ 10
de_dw_old = get_gradient_old(l,x,de_dy)
print("get_gradient_old btime: ")
@btime get_gradient_old($l,$x,$de_dy)
de_dw = get_gradient(l,x,de_dy)
de_dw_old[1] == de_dw[1]
print("get_gradient btime: ")
@btime get_gradient($l,$x,$de_dy)
de_dx_old = backward_old(l,x,de_dy)
print("backward_old btime: ")
@btime backward_old($l,$x,$de_dy)
de_dx = backward(l,x,de_dy)
de_dx_old[1] == de_dx[1]
print("backward btime: ")
@btime backward($l,$x,$de_dy)
that produces:
preprocess btime:
6.481 ns (0 allocations: 0 bytes)
forward_old btime:
39.261 ms (302153 allocations: 14.48 MiB)
forward btime:
21.354 ms (691964 allocations: 10.64 MiB)
get_gradient_old btime:
882.561 ms (6585411 allocations: 221.10 MiB)
get_gradient btime:
73.475 ms (1806178 allocations: 27.68 MiB)
backward_old btime:
1.361 s (11706692 allocations: 443.24 MiB)
backward btime:
70.537 ms (1845173 allocations: 28.30 MiB)
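In relative terms that is roughly a 1.8x speed-up for the forward pass (39.3 ms → 21.4 ms), about 12x for `get_gradient` (882.6 ms → 73.5 ms) and about 19x for `backward` (1.361 s → 70.5 ms).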