using BenchmarkTools
using Flux

# First convolution of ResNet18: 7×7 kernel, 3 => 64 channels, stride 2, padding 3
conv = Conv((7,7), 3 => 64; stride=2, pad=3)
dummy = randn(Float32, 224, 224, 3, 2);  # WHCN layout: width, height, channels, batch
conv(dummy)  # warm-up call so compilation time is not measured
@benchmark conv(dummy)
#=
BenchmarkTools.Trial:
memory estimate: 19.29 MiB
allocs estimate: 60
--------------
minimum time: 22.996 ms (0.00% GC)
median time: 25.009 ms (0.00% GC)
mean time: 30.242 ms (17.79% GC)
maximum time: 66.116 ms (62.37% GC)
--------------
samples: 165
evals/sample: 1
=#
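Note: conv and dummy are non-const globals here, and BenchmarkTools recommends interpolating such values with $ so that global-variable lookups are not part of the timing (the edit further down already does this):

@benchmark $conv($dummy)  # splice in the current values; they act as locals inside the benchmark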
using PyCall
const torch = pyimport_conda("torch", "pytorch")  # the conda package is named "pytorch"
const nn = pyimport("torch.nn")

dummy_py = torch.randn(2, 3, 224, 224);  # NCHW layout, the PyTorch equivalent of the input above
conv1_py = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=true)
conv1_py(dummy_py)  # warm-up
@benchmark conv1_py(dummy_py)
#=
BenchmarkTools.Trial:
memory estimate: 160 bytes
allocs estimate: 4
--------------
minimum time: 8.824 ms (0.00% GC)
median time: 9.436 ms (0.00% GC)
mean time: 10.463 ms (0.00% GC)
maximum time: 19.719 ms (0.00% GC)
--------------
samples: 478
evals/sample: 1
=#
Is this result expected, or is there a benchmark/implementation mistake? This is the first layer of ResNet18 (with an extra bias); the full network implementation gets similar results.
This is probably not responsible for the entire 2x speed difference, but I have noticed that Flux.Conv layers are not type-stable, which I would guess hurts performance. I opened a GitHub issue here and a thread here. Because of the lack of response, though, I am unsure whether this type instability is actually a big deal; I am quite new to Julia.
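One way to see the instability is @code_warntype (a sketch; the exact output depends on the Flux version):

using Flux, InteractiveUtils
conv = Conv((7,7), 3 => 64; stride=2, pad=3)
x = randn(Float32, 224, 224, 3, 2)
# Entries flagged as Any/Union in the output are types the compiler
# could not infer, i.e. the instability mentioned above.
@code_warntype conv(x)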
The bias calculation could make a difference, but since the benchmark uses the identity function as the activation, I would think it gets compiled away with zero overhead, no?
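One way to check would be benchmarking a bias-free layer; assuming a Flux version whose Conv constructor accepts bias=false, a sketch:

using BenchmarkTools, Flux
dummy = randn(Float32, 224, 224, 3, 2)
conv_nobias = Conv((7,7), 3 => 64; stride=2, pad=3, bias=false)  # same geometry, no bias term
conv_nobias(dummy)  # warm-up
@benchmark $conv_nobias($dummy)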
Do you mind trying again with the functional conv routines in NNlib and PyTorch (NNlib.conv and torch.nn.functional.conv2d)? I think it would be good to open an issue in Flux or NNlib, and this would help narrow things down.
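On the Julia side that could look something like this (a sketch; DenseConvDims is how NNlib encodes the stride/padding geometry that Conv((7,7), 3 => 64; stride=2, pad=3) sets up internally):

using BenchmarkTools, NNlib
x = randn(Float32, 224, 224, 3, 2)   # WHCN, as in the Flux benchmark
w = randn(Float32, 7, 7, 3, 64)      # kernel width, kernel height, in channels, out channels
cdims = DenseConvDims(x, w; stride=2, padding=3)
@benchmark NNlib.conv($x, $w, $cdims)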
Edit: using MKL.jl helps narrow the gap on my machine:
In [1]: import torch.nn.functional as F
In [2]: import torch
In [3]: x = torch.randn(2, 3, 224, 224)
In [4]: conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=True)
In [5]: %timeit conv1(x)
2.78 ms ± 70.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
julia> using LinearAlgebra: BLAS
julia> BLAS.vendor()
:mkl
julia> using BenchmarkTools, Flux
julia> const x = randn(Float32, 224, 224, 3, 2);
julia> const conv1 = Conv((7,7), 3 => 64; stride=2, pad=3);
julia> @benchmark conv1($x)
BenchmarkTools.Trial:
memory estimate: 19.29 MiB
allocs estimate: 45
--------------
minimum time: 4.588 ms (0.00% GC)
median time: 5.330 ms (0.00% GC)
mean time: 5.424 ms (4.99% GC)
maximum time: 10.235 ms (7.86% GC)
--------------
samples: 922
evals/sample: 1
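For anyone reproducing this on recent Julia (1.7+, where BLAS.vendor() is deprecated), MKL.jl swaps the BLAS backend via libblastrampoline simply by being loaded; a sketch to verify:

using MKL  # load before doing any heavy linear algebra
using LinearAlgebra
# get_config() lists the active BLAS libraries; libmkl_rt should appear
# once MKL has replaced the default OpenBLAS.
BLAS.get_config()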