Multithreading in many nested loops

I am trying to optimize a function that performs the convolution operation known from neural networks.
I have already optimized the function as much as I could (e.g. checking it with `@code_warntype`), and now I would like to give multithreading a try. However, because of the many nested loops, I am not sure which loop is the best one to parallelize. Do you have any ideas on how to optimize the function further?

A simplified version of my code:

```julia
# not really important
function get_input_position(output_position::Tuple, stride::Tuple)
    m = output_position[1] + (stride[1] - 1) * (output_position[1] - 1)
    n = output_position[2] + (stride[2] - 1) * (output_position[2] - 1)

    return m, n
end

# not really important
function calculate_output_shape(input_height::Int, input_width::Int, kernel_height::Int, kernel_width::Int; stride::Tuple=(1, 1), padding::Tuple=(0, 0))
    output_height = (input_height + 2 * padding[1] - kernel_height) / stride[1] + 1
    output_width = (input_width + 2 * padding[2] - kernel_width) / stride[2] + 1
    output_height = convert(Int, trunc(output_height))
    output_width = convert(Int, trunc(output_width))

    return output_height, output_width
end

# the function to optimize
function multichannel_conv(inputs::Array{Float64, 4}, kernels::Array{Float64, 4}; stride::Tuple{Int, Int}=(1, 1), padding::Tuple{Int, Int}=(0, 0))
    # inputs = copy(inputs)
    # storing all the necessary shapes (layout: batch, channel, height, width)
    current_batch_size::Int, in_channels::Int, input_height::Int, input_width::Int = size(inputs)
    out_channels::Int, in_channels, kernel_height::Int, kernel_width::Int = size(kernels)
    # calculating shape of output (in this simplified version, padding only
    # affects the output shape; the input itself is not padded)
    output_height::Int, output_width::Int = calculate_output_shape(input_height, input_width, kernel_height, kernel_width, stride=stride, padding=padding)

    output::Array{Float64, 4} = Array{Float64, 4}(undef, current_batch_size, out_channels, output_height, output_width)
    # actual computation
    for index_batch in 1:current_batch_size
        for out_channel in 1:out_channels # Threads.@threads
            for y_out in 1:output_height, x_out in 1:output_width
            # Threads.@threads for y_out in 1:output_height
            # Threads.@threads for x_out in 1:output_width
                m, n = get_input_position((y_out, x_out), stride)
                value::Float64 = 0.00
                for in_channel in 1:in_channels
                    for y_w in 1:kernel_height, x_w in 1:kernel_width
                    # Threads.@threads for y_w in 1:kernel_height
                    # Threads.@threads for x_w in 1:kernel_width
                        y_in::Int = m + y_w - 1
                        x_in::Int = n + x_w - 1
                        value += inputs[index_batch, in_channel, y_in, x_in] * kernels[out_channel, in_channel, y_w, x_w]
                    end
                end
                output[index_batch, out_channel, y_out, x_out] = value
            end
        end
    end

    return output
end
```
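For reference, the two helpers reduce to the usual closed forms for the convolution output size and the top-left corner of each input window. The integer-only restatement below uses hypothetical names (`output_size`, `window_start`) that are not from the question; it should agree with `calculate_output_shape` and `get_input_position` whenever the kernel fits inside the padded input, i.e. `in_len + 2pad ≥ k_len`:

```julia
# Closed form of calculate_output_shape for one dimension (hypothetical name).
# `div` truncates like `trunc(x / y)` but stays in integers.
output_size(in_len::Int, k_len::Int, stride::Int, pad::Int) =
    div(in_len + 2pad - k_len, stride) + 1

# Closed form of get_input_position for one dimension (hypothetical name):
# out_idx + (stride - 1) * (out_idx - 1) == (out_idx - 1) * stride + 1
window_start(out_idx::Int, stride::Int) = (out_idx - 1) * stride + 1

output_size(150, 3, 1, 0)   # 148, as in the benchmark below
output_size(150, 3, 2, 0)   # 74
window_start(3, 2)          # 5
```

Writing it this way also avoids the round trip through `Float64` and `trunc` in `calculate_output_shape`.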

Just for measuring runtime:

```julia
using BenchmarkTools
output = @btime multichannel_conv(rand(64, 20, 150, 150), rand(40, 20, 3, 3))
# 37.964 s (7 allocations: 647.59 MiB)
```

For multithreading, I would group the work into the largest chunks possible while still making sure there are enough chunks to keep all your threads busy. In other words, parallelise over as few loops as possible, preferably the outer ones, so that each thread works through its own region of memory and benefits from cache locality instead of jumping all over the place. Keeping each thread's memory accesses in order (in Julia's column-major arrays, with the first index varying fastest) should help performance.
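A minimal sketch of that advice applied to the function from the question: only the outermost (batch) loop is threaded, so each task writes its own `output[b, :, :, :]` slice and no locking is needed. `multichannel_conv_threaded` is a hypothetical name; the sketch assumes `stride ≥ 1` and, like the question's simplified code, does not pad the input.

```julia
# Thread only the outermost (batch) loop. Iterations write disjoint slices
# output[b, :, :, :], so there are no data races.
# Assumes stride >= 1 and no padding (as in the question's simplified code).
function multichannel_conv_threaded(inputs::Array{Float64,4}, kernels::Array{Float64,4};
                                    stride::Tuple{Int,Int}=(1, 1))
    batch, in_ch, in_h, in_w = size(inputs)
    out_ch, in_ch_k, k_h, k_w = size(kernels)
    @assert in_ch == in_ch_k "inputs and kernels disagree on the number of input channels"
    out_h = div(in_h - k_h, stride[1]) + 1
    out_w = div(in_w - k_w, stride[2]) + 1
    output = Array{Float64,4}(undef, batch, out_ch, out_h, out_w)

    Threads.@threads for b in 1:batch
        for oc in 1:out_ch, y_out in 1:out_h, x_out in 1:out_w
            m = (y_out - 1) * stride[1] + 1   # top-left input row of this window
            n = (x_out - 1) * stride[2] + 1   # top-left input column of this window
            value = 0.0
            for ic in 1:in_ch, y_w in 1:k_h, x_w in 1:k_w
                value += inputs[b, ic, m + y_w - 1, n + x_w - 1] * kernels[oc, ic, y_w, x_w]
            end
            output[b, oc, y_out, x_out] = value
        end
    end
    return output
end

# Usage: a 2-sample batch of 3x3 ones convolved with a 2x2 kernel of ones
out = multichannel_conv_threaded(ones(2, 1, 3, 3), ones(1, 1, 2, 2))
size(out)   # (2, 1, 2, 2), every entry 4.0
```

With a batch size of 64, as in the benchmark, this gives each thread several whole samples to work on for typical thread counts.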

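You can also combine your nested `for` loops into a single one using `Iterators.product`:

```julia
using Base.Iterators: product

for (a, b, c) in product(iterA, iterB, iterC)
    # ...
end

# visits the same (a, b, c) combinations as the nested loops below;
# product varies its FIRST iterator fastest, so iterA is the innermost loop
for c in iterC
    for b in iterB
        for a in iterA
            # ...
        end
    end
end
```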

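You can then put `Threads.@threads` on that single combined loop (or, without `product`, on the outermost loop) and parallelise over it. Note that `@threads` splits the work by indexing into the collection, so `collect` the product first (or use `CartesianIndices`). I often find this very helpful in these situations.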
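A small sketch of both points, using a toy multiplication table rather than the convolution. `threaded_table` is a hypothetical example name. The `collect` is there because, as far as I know, the default `@threads` schedule indexes into the collection it is given, which a lazy `Iterators.product` does not support.

```julia
using Base.Iterators: product

# `product` yields the same (i, j) pairs as a nested loop, but the FIRST
# iterator varies fastest (like the first index of a column-major array).
ij = vec(collect(product(1:2, 1:3)))
# ij == [(1, 1), (2, 1), (1, 2), (2, 2), (1, 3), (2, 3)]

# Thread over the merged loop: collect the pairs into an array, then let
# @threads split the array's indices among the threads.
function threaded_table(rows::Int, cols::Int)
    work = collect(product(1:rows, 1:cols))
    table = zeros(Int, rows, cols)
    Threads.@threads for k in eachindex(work)
        i, j = work[k]
        table[i, j] = i * j   # each cell is written by exactly one iteration
    end
    return table
end

threaded_table(2, 3)   # [1 2 3; 2 4 6]
```

For the convolution, the natural pairs to merge this way would be `(index_batch, out_channel)`, which gives 64 × 40 work items for the benchmark sizes.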

The most important thing is to benchmark the threaded version against the serial one at the problem sizes you actually care about. Ideally, you want to get as close as possible to an n-fold speedup, where n is the number of threads.
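A dependency-free sketch of such a comparison, using a toy column-sum workload as a stand-in for the convolution (`colsums_serial!`, `colsums_threaded!` and `best_time` are hypothetical names). Julia must be started with more than one thread, e.g. `julia --threads=4` or by setting `JULIA_NUM_THREADS`. A memory-bound toy like this may scale noticeably worse than the compute-heavy convolution, so treat its numbers only as an illustration; the same `best_time` helper (or BenchmarkTools' `@belapsed`) can be pointed at the real functions.

```julia
# Serial column sums; the inner loop runs over the first (contiguous) index.
function colsums_serial!(out::Vector{Float64}, A::Matrix{Float64})
    for j in axes(A, 2)
        s = 0.0
        for i in axes(A, 1)
            s += A[i, j]
        end
        out[j] = s
    end
    return out
end

# Same computation with only the OUTER loop threaded; each iteration writes a distinct out[j].
function colsums_threaded!(out::Vector{Float64}, A::Matrix{Float64})
    Threads.@threads for j in axes(A, 2)
        s = 0.0
        for i in axes(A, 1)
            s += A[i, j]
        end
        out[j] = s
    end
    return out
end

# Best-of-N wall-clock time after one warm-up call, so compilation is not measured.
function best_time(f, args...; samples::Int=5)
    f(args...)
    return minimum(@elapsed(f(args...)) for _ in 1:samples)
end

A = rand(2_000, 2_000)
out_serial, out_threaded = zeros(2_000), zeros(2_000)
t_serial = best_time(colsums_serial!, out_serial, A)
t_threaded = best_time(colsums_threaded!, out_threaded, A)
println("threads: ", Threads.nthreads(),
        "  speedup: ", round(t_serial / t_threaded; digits=2),
        "  (ideal: ", Threads.nthreads(), ")")
```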