# Innocent-looking optimization of the forward pass causes performance cliff in gradient calculation with Zygote.jl v0.4.7

Consider the following code example:

```julia
using LinearAlgebra
using Flux
using BenchmarkTools

function diagmatmul(m, xs)
    ys = m.(xs)
    zs = diag.((first(ys)',) .* ys)
    return sum(sum(zs))
end

function colwise(m, xs)
    ys = m.(xs)
    y₁ = view.(Ref(first(ys)), :, 1:batch_size)
    zs = (y -> y₁ .⋅ view.(Ref(y), :, 1:batch_size)).(ys)
    return sum(sum(zs))
end

function ∇diagmatmul(m, xs, θ)
    gradient(() -> diagmatmul(m, xs), θ)
end

function ∇colwise(m, xs, θ)
    gradient(() -> colwise(m, xs), θ)
end

time = 1:256;
batch_size = 128;
in, out = 64, 32;
m = Dense(in, out)
θ = Flux.params(m);
xs = [rand(Float32, in, batch_size) for _ ∈ time];
```

Functions `diagmatmul` and `colwise` do the same thing: both compute the inner product between each column of `y₁` and the corresponding column of every matrix in `ys`, just implemented differently. `diagmatmul` is more wasteful, since it also computes the unnecessary inner products between non-corresponding columns (the off-diagonal entries of the matrix product), so `colwise` should be faster. And indeed it is, but the picture reverses completely when it comes to computing their gradients:
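To make that equivalence concrete, here is a toy-size sketch (small shapes chosen purely for illustration) showing that `diagmatmul` computes the full pairwise product but keeps only its diagonal, which equals the per-column dot products:

```julia
using LinearAlgebra

# Toy shapes: 3-element columns, 4 columns per matrix.
y₁ = rand(3, 4)
y  = rand(3, 4)

full   = y₁' * y                              # 4×4 matrix of ALL pairwise column dots
d_diag = diag(full)                           # keep only the matching-column dots
d_cols = [dot(y₁[:, b], y[:, b]) for b in 1:4]  # compute only those dots directly

d_diag ≈ d_cols                               # → true
```

The 12 off-diagonal entries of `full` are computed and thrown away, which is the waste `colwise` avoids.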

```julia
julia> @benchmark diagmatmul($m, $xs)
BenchmarkTools.Trial:
  memory estimate:  24.39 MiB
  allocs estimate:  1547
  --------------
  minimum time:     11.103 ms (0.00% GC)
  median time:      12.422 ms (0.00% GC)
  mean time:        15.586 ms (21.91% GC)
  maximum time:     30.666 ms (62.17% GC)
  --------------
  samples:          321
  evals/sample:     1

julia> @benchmark colwise($m, $xs)
BenchmarkTools.Trial:
  memory estimate:  9.93 MiB
  allocs estimate:  36243
  --------------
  minimum time:     3.370 ms (0.00% GC)
  median time:      3.762 ms (0.00% GC)
  mean time:        5.116 ms (25.37% GC)
  maximum time:     24.802 ms (78.08% GC)
  --------------
  samples:          978
  evals/sample:     1

julia> @benchmark ∇diagmatmul($m, $xs, $θ)
BenchmarkTools.Trial:
  memory estimate:  51.01 MiB
  allocs estimate:  12981
  --------------
  minimum time:     20.382 ms (0.00% GC)
  median time:      21.960 ms (0.00% GC)
  mean time:        29.627 ms (27.53% GC)
  maximum time:     47.773 ms (50.93% GC)
  --------------
  samples:          169
  evals/sample:     1

julia> @benchmark ∇colwise($m, $xs, $θ)
BenchmarkTools.Trial:
  memory estimate:  1.09 GiB
  allocs estimate:  974619
  --------------
  minimum time:     456.830 ms (34.88% GC)
  median time:      477.833 ms (33.95% GC)
  mean time:        511.531 ms (39.05% GC)
  maximum time:     835.988 ms (62.42% GC)
  --------------
  samples:          10
  evals/sample:     1
```

So `colwise` is ~3.3× faster than `diagmatmul`, yet its gradient is more than 22× slower than that of `diagmatmul`. So my questions are:

• What is causing this problem?
• How can I make the gradient of `colwise` as efficient as its forward pass?

For context, the code in `diagmatmul` was part of the forward pass function in my project. I decided to optimize it by rewriting it as in `colwise` and saw the expected increase in speed, but then got huge regressions in the optimized version during the backward pass.

I’m not surprised that the one with many `view(...)`s is slow. In the backward pass, each of these makes a new dense array of zeros, and writes into it. It’s possible to avoid that, but perhaps unnecessary.
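To illustrate that cost, here is a hand-written sketch of what a view's reverse-mode adjoint has to do (assumed behavior, written out by hand; `view_pullback_sketch` is a hypothetical helper, not Zygote's actual implementation):

```julia
# Sketch: the cotangent ȳ of y = view(A, :, j) must be scattered back into a
# gradient the size of A, so a dense zero array of size(A) is allocated.
function view_pullback_sketch(ȳ, sizeA, j)
    Ā = zeros(eltype(ȳ), sizeA)   # one full-size allocation per view!
    Ā[:, j] .= ȳ                  # write the cotangent into column j
    return Ā
end

ȳ = ones(Float32, 64)
Ā = view_pullback_sketch(ȳ, (64, 128), 5)
# Ā is 64×128 even though only one column is nonzero; with thousands of
# views per call, these allocations can dominate the backward pass.
```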

Here are a few variants, the fastest of which just keeps everything as giant arrays.

```julia
function diagmatmul(m, xs)           # just adding comments to understand!
    ys = m.(xs)                      # ys[t][o,b]
    zs = diag.((first(ys)',) .* ys)  # zs[t][b] := sum(o) ys[1][o,b] * ys[t][o,b]
    return sum(sum(zs))              # sum over t, b
end

function usedot(m, xs)
    ys = m.(xs)                      # ys[t][o,b]
    z = dot.(Ref(first(ys)), ys)     # z[t] := sum(o,b) ys[1][o,b] * ys[t][o,b]
    return sum(z)
end

using TensorCast

@cast xt[i,b\t] := xs[t][i,b];  # glue & reshape

function gluedup(m, xt)
    yt = m(xt)
    @cast ys[o,b,t] := yt[o,b\t]  b:batch_size  # just reshape
    @reduce sum(o,b,t) ys[o,b,1] * ys[o,b,t]    # broadcast & sum
end
```

From which I get the following times:

```julia
@btime ∇diagmatmul($m, $xs, $θ) #  25.432 ms,   51.01 MiB
@btime ∇colwise($m, $xs, $θ)    # 385.081 ms,    1.09 GiB
@btime ∇usedot($m, $xs, $θ)     #  12.438 ms,   34.65 MiB
@btime ∇gluedup($m, $xt, $θ)    #   9.640 ms,   36.03 MiB
```

It ought to be possible to do better still, as `sum(ys[:,:,1] .* ys)` is broadcasting out a big array and then immediately summing it. But it's an improvement!
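A stdlib-only illustration of that point (toy shapes, forward pass only; how each form differentiates under Zygote is a separate question): the broadcast form materializes a full `o×b×t` temporary, while a `dot`-per-slice reduction computes the same number without it.

```julia
using LinearAlgebra

o, b, t = 32, 16, 8
ys = rand(Float32, o, b, t)

# Broadcast form: allocates an o×b×t temporary, then sums it away.
s1 = sum(ys[:, :, 1] .* ys)

# Fused form: same value, one dot product per time slice, no large temporary.
s2 = sum(dot(view(ys, :, :, 1), view(ys, :, :, k)) for k in 1:t)

s1 ≈ s2    # → true
```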


Thank you @mcabbott, this is really helpful. However, it looks like I oversimplified my minimal working example: the double sum in `sum(sum(zs))` is there just to scalarize the output, so that I can take the gradient of that particular piece of the larger forward pass. In the actual code, `zs` from above are stacked into a matrix (using your awesome package), which is then passed through `softmax`, like so:

```julia
using LazyStack  # provides stack

function diagmatmul(m, xs)
    ys = m.(xs)
    zs = diag.((first(ys)',) .* ys)
    αs = softmax(stack(zs)')
end

function ∇diagmatmul(m, xs, θ)
    gradient(() -> sum(sum(diagmatmul(m, xs))), θ)
end
```
```julia
function colwise(m, xs)
    ys = m.(xs)
    y₁ = view.(Ref(first(ys)), :, 1:batch_size)
    zs = (y -> y₁ .⋅ view.(Ref(y), :, 1:batch_size)).(ys)
    αs = softmax(stack(zs)')
end

function ∇colwise(m, xs, θ)
    gradient(() -> sum(sum(colwise(m, xs))), θ)
end
```

(there are even more steps after that in the forward pass function, but they are not relevant to the issue described here). So `usedot` and `gluedup` functions won’t solve the problem.

OK, then `usedot` is useless, but one big array still has a small edge…

```julia
using LazyStack, OMEinsum  # OMEinsum on #master

function diagmatmul(m, xs)
    ys = m.(xs);                     # ys[t][o,b]
    zs = diag.((first(ys)',) .* ys); # zs[t][b] := sum(o) ys[1][o,b] * ys[t][o,b]
    z = stack(zs)';                  # z[t,b] := zs[t][b]
    αs = softmax(z);
end

function gluedup(m, xt)
    yt = m(xt);
    @cast ys[o,t,b] := yt[o,b\t]  b:batch_size, nolazy;
    ys1 = ys[:,1,:];
    @ein z[t,b] := ys1[o,b] * ys[o,t,b]; # batch matmul
    αs_ = softmax(z);
end

αs[1:2, 1:2]   # check the two versions agree
αs_[1:2, 1:2]

@btime ∇diagmatmul($m, $xs, $θ) #  27.626 ms
@btime ∇colwise($m, $xs, $θ)    # 401.620 ms
@btime ∇gluedup($m, $xt, $θ)    #  14.256 ms
```

IIRC the batch matmul which this is calling still makes slices internally, but along `b`, at least on the CPU. It pays to make this the last index.
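That layout point can be checked directly: Julia arrays are column-major, so a slice along the *last* index is one contiguous block of memory (a small stdlib-only sketch):

```julia
A = rand(4, 5, 6)
strides(A)              # (1, 4, 20): elements step by 1 along dim 1

# A[:, :, k] occupies linear indices (k-1)*20+1 : k*20, one contiguous run,
# so slicing along the last index needs no copies or strided access.
v = view(A, :, :, 2)
Base.iscontiguous(v)    # → true for a last-index slice
```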


Thanks a lot @mcabbott, I ended up using your `gluedup` version and it improved the performance by quite a bit.
Now that I see all the amazing stuff `TensorCast.jl` and `OMEinsum.jl` and the likes can do, I have a follow-up question. After computing `αs` as above I use them to compute context matrix `C` as follows:

```julia
using CuArrays

D = 100
T, B = size(αs)
Hs = cu(rand(Float32, D, T, B))  # some 3D input tensor used in computing the context

C = dropdims(sum(reshape(αs, 1, :, B) .* Hs; dims=2); dims=2)
```

Am I leaving any performance on the table with this implementation of `C`? In particular, is there a way to speed it up with `TensorCast.jl` or `OMEinsum.jl`?

Would adapting the example given in the Better broadcasting section of your package's docs be of benefit here?

Maybe? The speed of the lazy broadcasting has been inconsistent; I don't know whose fault that is, but I'm afraid the doc example is quite broken right now. However, OMEinsum does something pretty similar (in cases where it can't see a matrix multiplication), which is fast:

```julia
using TensorCast, OMEinsum, Einsum, BenchmarkTools, Zygote

V = rand(500);
f1(V) = @reduce W[i] := sum(j,k) V[i]*V[j]*V[k]      # 441.332 ms, 953.68 MiB
f2(V) = @reduce W[i] := sum(j,k) V[i]*V[j]*V[k] lazy # 3.559 s,      9.33 GiB
f3(V) = @ein W[i] := V[i]*V[j]*V[k]    # OMEinsum,     139.951 ms,   4.25 KiB
f4(V) = @einsum W[i] := V[i]*V[j]*V[k] # Einsum,       140.215 ms,   4.06 KiB
```

For gradients, however, it's not much better than just broadcasting, and something is apparently broken with my approach…

```julia
V2 = rand(100); # for which these are ~1 ms forward, but:
@btime gradient(sum∘f1, $V2); # 253.699 ms (7000176 allocations: 297.56 MiB)
@btime gradient(sum∘f3, $V2); #   4.477 ms (49 allocations: 6.97 KiB)

f0(V) = dropdims(sum(V .* V' .* reshape(V,1,1,:); dims=(2,3)), dims=(2,3));
@btime gradient(sum∘f0, $V2); #   4.993 ms (78 allocations: 30.83 MiB)
```

Edit: what’s broken is https://github.com/FluxML/Zygote.jl/issues/502
