Consider the following code example:
```julia
using LinearAlgebra
using Flux
using BenchmarkTools

function diagmatmul(m, xs)
    ys = m.(xs)
    # first(ys)' * y is a batch_size × batch_size matrix per time step;
    # only its diagonal (the per-column inner products) is actually used.
    zs = diag.((first(ys)',) .* ys)
    return sum(sum(zs))
end

function colwise(m, xs)
    ys = m.(xs)
    # Column views into y₁, dotted with the matching columns of each y
    # (`batch_size` is the global defined below).
    y₁ = view.(Ref(first(ys)), :, 1:batch_size)
    zs = (y -> y₁ .⋅ view.(Ref(y), :, 1:batch_size)).(ys)
    return sum(sum(zs))
end

function ∇diagmatmul(m, xs, θ)
    gradient(θ) do
        diagmatmul(m, xs)
    end
end

function ∇colwise(m, xs, θ)
    gradient(θ) do
        colwise(m, xs)
    end
end

time = 1:256;
batch_size = 128;
in, out = 64, 32;
m = Dense(in, out)
θ = Flux.params(m);
xs = [rand(Float32, in, batch_size) for _ ∈ time];
```
The functions `diagmatmul` and `colwise` do the same thing: they compute the inner products between the columns of `y₁` and the corresponding columns of each matrix in `ys`, just implemented differently. `diagmatmul` is more wasteful, since it also computes the unneeded inner products between non-corresponding columns (it forms the full `batch_size × batch_size` product `y₁' * y` and keeps only its diagonal), so `colwise` should be faster.
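As a minimal check, assuming the setup above, that the two formulations really do agree:

```julia
# Minimal check (assuming the setup above): both functions compute
# Σₜ Σⱼ ⟨y₁[:, j], yₜ[:, j]⟩, so they should agree up to floating-point
# roundoff.
@assert diagmatmul(m, xs) ≈ colwise(m, xs)
```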
However, this is not at all true when it comes to computing their gradients:
```julia
julia> @benchmark diagmatmul($m, $xs)
BenchmarkTools.Trial:
  memory estimate:  24.39 MiB
  allocs estimate:  1547
  --------------
  minimum time:     11.103 ms (0.00% GC)
  median time:      12.422 ms (0.00% GC)
  mean time:        15.586 ms (21.91% GC)
  maximum time:     30.666 ms (62.17% GC)
  --------------
  samples:          321
  evals/sample:     1

julia> @benchmark colwise($m, $xs)
BenchmarkTools.Trial:
  memory estimate:  9.93 MiB
  allocs estimate:  36243
  --------------
  minimum time:     3.370 ms (0.00% GC)
  median time:      3.762 ms (0.00% GC)
  mean time:        5.116 ms (25.37% GC)
  maximum time:     24.802 ms (78.08% GC)
  --------------
  samples:          978
  evals/sample:     1

julia> @benchmark ∇diagmatmul($m, $xs, $θ)
BenchmarkTools.Trial:
  memory estimate:  51.01 MiB
  allocs estimate:  12981
  --------------
  minimum time:     20.382 ms (0.00% GC)
  median time:      21.960 ms (0.00% GC)
  mean time:        29.627 ms (27.53% GC)
  maximum time:     47.773 ms (50.93% GC)
  --------------
  samples:          169
  evals/sample:     1

julia> @benchmark ∇colwise($m, $xs, $θ)
BenchmarkTools.Trial:
  memory estimate:  1.09 GiB
  allocs estimate:  974619
  --------------
  minimum time:     456.830 ms (34.88% GC)
  median time:      477.833 ms (33.95% GC)
  mean time:        511.531 ms (39.05% GC)
  maximum time:     835.988 ms (62.42% GC)
  --------------
  samples:          10
  evals/sample:     1
```
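To check whether the overhead comes from the view-broadcast pattern itself rather than from `Dense`, one can isolate it on a single pair of matrices. This is a sketch, not part of the measurements above, and `colkernel` is just an illustrative name:

```julia
# Hypothetical isolation test (not part of the measurements above):
# reproduce the view-broadcast + dot pattern from `colwise` on a single
# pair of matrices and benchmark its gradient alone.
colkernel(a, b) = sum(view.(Ref(a), :, axes(a, 2)) .⋅ view.(Ref(b), :, axes(b, 2)))

a = rand(Float32, out, batch_size)
b = rand(Float32, out, batch_size)
@benchmark gradient(colkernel, $a, $b)
```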
So `colwise` is ~3.3 times faster than `diagmatmul`, yet its gradient is more than 22 times slower than that of `diagmatmul`. So my questions are:
- What is causing this problem?
- How can I make the gradient of `colwise` as efficient as its forward pass?
For context, the code in `diagmatmul` was part of the forward pass in my project. I rewrote it as `colwise` to optimize it and saw the expected speedup, but then got a huge regression in the backward pass of the optimized version.