I am trying to optimize a function that performs the convolution operation known from neural networks.
I have already tried to optimize the function as much as possible (e.g. by checking it with `@code_warntype`), and now I would like to give multithreading a try. However, I am not sure what the best method for parallelization would be here because of the many nested loops. Do you have any ideas on how to further optimize the function?
A simplified version of my code:
# not really important
"""
    get_input_position(output_position::Tuple, stride::Tuple)

Map a `(row, column)` output index to the `(row, column)` input index it corresponds to.

Along each of the two axes the result is `o + (s - 1) * (o - 1)`, where `o` is the output
index and `s` the stride on that axis. With 1-based integer indices this equals
`(o - 1) * s + 1`, so output index 1 always maps to input index 1. Padding is not taken
into account here.

# Arguments
- `output_position`: tuple whose first two entries are the output row and column.
- `stride`: tuple whose first two entries are the vertical and horizontal stride.

# Returns
A 2-tuple `(m, n)` with the input row and column.
"""
function get_input_position(output_position::Tuple, stride::Tuple)
    # Same formula for both axes; kept in the original form on purpose.
    axis_position(o, s) = o + (s - 1) * (o - 1)
    row_out, col_out = output_position
    row_stride, col_stride = stride
    return axis_position(row_out, row_stride), axis_position(col_out, col_stride)
end
# not really important
"""
    calculate_output_shape(input_height, input_width, kernel_height, kernel_width;
                           stride=(1, 1), padding=(0, 0))

Return the spatial size `(output_height, output_width)` produced by sliding a kernel over
an input.

Along each axis the size is `(input + 2 * padding - kernel) / stride + 1`, truncated
toward zero and returned as an `Int`.

# Keywords
- `stride`: tuple whose first two entries are the vertical and horizontal stride.
- `padding`: tuple whose first two entries are the vertical and horizontal padding; it is
  counted twice per axis (once for each side).

# Throws
- `InexactError` if the intermediate value cannot be represented as an `Int`, e.g. when a
  stride of `0` yields `Inf` or `NaN`.
"""
function calculate_output_shape(
    input_height::Int,
    input_width::Int,
    kernel_height::Int,
    kernel_width::Int;
    stride::Tuple=(1, 1),
    padding::Tuple=(0, 0),
)
    # One axis at a time: real-valued window count, then truncation toward zero.
    axis_length(len, klen, step, pad) = trunc(Int, (len + 2 * pad - klen) / step + 1)
    return (
        axis_length(input_height, kernel_height, stride[1], padding[1]),
        axis_length(input_width, kernel_width, stride[2], padding[2]),
    )
end
# the function to optimize
"""
    multichannel_conv(inputs, kernels; stride=(1, 1), padding=(0, 0))

Apply a bank of 2-D kernels to a batch of multi-channel inputs with plain nested loops
(no kernel flipping, i.e. the cross-correlation commonly called "convolution" in neural
networks).

# Arguments
- `inputs`: array indexed as `[batch, in_channel, y, x]`.
- `kernels`: array indexed as `[out_channel, in_channel, y, x]`.

# Keywords
- `stride`: `(vertical, horizontal)` step between neighbouring kernel positions.
- `padding`: `(vertical, horizontal)` number of implicit zero rows/columns added on every
  side of the input. The input is not copied; kernel taps that fall into the padded
  border are skipped, which is equivalent to multiplying them by zero.

# Returns
A new `Array{Float64, 4}` indexed as `[batch, out_channel, y, x]` whose spatial size is
given by `calculate_output_shape`.

# Throws
- `DimensionMismatch` if `inputs` and `kernels` disagree on the number of input channels.
"""
function multichannel_conv(
    inputs::Array{Float64, 4},
    kernels::Array{Float64, 4};
    stride::Tuple{Int, Int}=(1, 1),
    padding::Tuple{Int, Int}=(0, 0),
)
    batch_size, in_channels, input_height, input_width = size(inputs)
    out_channels, kernel_channels, kernel_height, kernel_width = size(kernels)
    # A mismatch previously went unnoticed (the kernel's channel count silently won).
    if kernel_channels != in_channels
        throw(DimensionMismatch(
            "inputs have $in_channels channels but kernels expect $kernel_channels"
        ))
    end

    output_height, output_width = calculate_output_shape(
        input_height, input_width, kernel_height, kernel_width;
        stride=stride, padding=padding,
    )
    output = Array{Float64, 4}(
        undef, batch_size, out_channels, output_height, output_width
    )

    # Every (batch, out_channel, y_out, x_out) combination writes its own output element
    # and only reads from `inputs`/`kernels`, so the iterations do not depend on each other.
    for index_batch in 1:batch_size
        for out_channel in 1:out_channels
            for y_out in 1:output_height, x_out in 1:output_width
                m, n = get_input_position((y_out, x_out), stride)
                # Offsets such that kernel tap (y_w, x_w) reads the unpadded input at
                # (y_offset + y_w, x_offset + x_w).
                y_offset = m - padding[1] - 1
                x_offset = n - padding[2] - 1
                # Restrict the taps to those landing inside the real input; the rest lie
                # in the zero padding and contribute nothing. Without padding these are
                # always the full 1:kernel_height and 1:kernel_width.
                y_taps = max(1, 1 - y_offset):min(kernel_height, input_height - y_offset)
                x_taps = max(1, 1 - x_offset):min(kernel_width, input_width - x_offset)

                value = 0.0
                for in_channel in 1:in_channels
                    for y_w in y_taps, x_w in x_taps
                        value +=
                            inputs[index_batch, in_channel, y_offset + y_w, x_offset + x_w] *
                            kernels[out_channel, in_channel, y_w, x_w]
                    end
                end
                output[index_batch, out_channel, y_out, x_out] = value
            end
        end
    end
    return output
end
Just for measuring runtime:
using BenchmarkTools
# Timing run: a batch of 64 random inputs (20 channels, 150x150) against 40 random
# 3x3 kernels with 20 channels each, using the default stride and padding.
# Both `rand` calls are part of the expression handed to `@btime`, so the reported time
# and allocations also include creating the two random arrays.
output = @btime multichannel_conv(rand(64, 20, 150, 150), rand(40, 20, 3, 3))
>>> 37.964 s (7 allocations: 647.59 MiB)