Innocent-looking optimization of the forward pass causes a performance cliff in gradient calculation with Zygote.jl v0.4.7

Consider the following code example:

using LinearAlgebra
using Flux
using BenchmarkTools

function diagmatmul(m, xs)
    ys = m.(xs)
    zs = diag.((first(ys)',) .* ys)
    return sum(sum(zs))
end

function colwise(m, xs)
    ys = m.(xs)
    y₁ = view.(Ref(first(ys)), :, 1:batch_size)
    zs = (y -> y₁ .⋅ view.(Ref(y), :, 1:batch_size)).(ys)
    return sum(sum(zs))
end

function ∇diagmatmul(m, xs, θ)
    gradient(θ) do
        diagmatmul(m, xs)
    end
end

function ∇colwise(m, xs, θ)
    gradient(θ) do
        colwise(m, xs)
    end
end

time = 1:256;
batch_size = 128;
in, out = 64, 32;
m = Dense(in, out)
θ = Flux.params(m);
xs = [rand(Float32, in, batch_size) for _ ∈ time];

The functions diagmatmul and colwise do the same thing: they compute the inner product between each column of y₁ and the corresponding column of each matrix in ys, but they are implemented differently. diagmatmul is more wasteful, since it also computes unnecessary inner products between non-corresponding columns (the off-diagonal entries of the matrix product), so colwise should be faster. However, this is not true at all when it comes to computing their gradients:

julia> @benchmark diagmatmul($m, $xs)
BenchmarkTools.Trial: 
  memory estimate:  24.39 MiB
  allocs estimate:  1547
  --------------
  minimum time:     11.103 ms (0.00% GC)
  median time:      12.422 ms (0.00% GC)
  mean time:        15.586 ms (21.91% GC)
  maximum time:     30.666 ms (62.17% GC)
  --------------
  samples:          321
  evals/sample:     1

julia> @benchmark colwise($m, $xs)
BenchmarkTools.Trial: 
  memory estimate:  9.93 MiB
  allocs estimate:  36243
  --------------
  minimum time:     3.370 ms (0.00% GC)
  median time:      3.762 ms (0.00% GC)
  mean time:        5.116 ms (25.37% GC)
  maximum time:     24.802 ms (78.08% GC)
  --------------
  samples:          978
  evals/sample:     1

julia> @benchmark ∇diagmatmul($m, $xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  51.01 MiB
  allocs estimate:  12981
  --------------
  minimum time:     20.382 ms (0.00% GC)
  median time:      21.960 ms (0.00% GC)
  mean time:        29.627 ms (27.53% GC)
  maximum time:     47.773 ms (50.93% GC)
  --------------
  samples:          169
  evals/sample:     1

julia> @benchmark ∇colwise($m, $xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  1.09 GiB
  allocs estimate:  974619
  --------------
  minimum time:     456.830 ms (34.88% GC)
  median time:      477.833 ms (33.95% GC)
  mean time:        511.531 ms (39.05% GC)
  maximum time:     835.988 ms (62.42% GC)
  --------------
  samples:          10
  evals/sample:     1

So colwise is ~3.3 times faster than diagmatmul, yet its gradient is more than 22 times slower than that of diagmatmul. So my questions are:

  • What is causing this problem?
  • How can I make the gradient of colwise as efficient as its forward pass?

For context, the code in diagmatmul was part of the forward pass function in my project. I decided to optimize it by rewriting it as in colwise and saw the expected increase in speed, but then I got huuuge regressions in the optimized version during the backward pass.
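A minimal stdlib-only sketch of the equivalence claimed above, using plain random matrices Y1 and Y (hypothetical names standing in for first(ys) and one ys[t]): both forms give the same per-column inner products, but the diag form computes batch_size² inner products in order to keep batch_size of them.

```julia
using LinearAlgebra

out, batch_size = 32, 128
Y1 = rand(Float32, out, batch_size)   # stands in for first(ys)
Y  = rand(Float32, out, batch_size)   # stands in for one ys[t]

# diagmatmul style: full batch_size × batch_size product, then keep the diagonal
z_diag = diag(Y1' * Y)

# colwise style: one inner product per pair of matching columns
z_col = [dot(view(Y1, :, b), view(Y, :, b)) for b in 1:batch_size]

z_diag ≈ z_col   # equal up to Float32 rounding
```

The diag form does 128× more arithmetic here, but it goes through a single optimized matrix product, which is consistent with colwise being only ~3.3× faster in the forward pass.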

I’m not surprised that the one with many view(...)s is slow. In the backward pass, each of these makes a new dense array of zeros and writes into it. It’s possible to avoid that, but perhaps unnecessary.
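A back-of-envelope sketch of why the memory adds up, assuming (as described above) that the pullback of each column view allocates a dense zero array the size of its parent. view_col_pullback is a hypothetical function written only to mimic that behavior, not Zygote's code. The estimate counts just the views of each ys[t], so it is a rough lower bound, of the same order as the 1.09 GiB measured for ∇colwise.

```julia
# Hypothetical stand-in for the backward pass of view(A, :, j):
# a dense zero array the size of the parent A, with only column j written.
function view_col_pullback(A::AbstractMatrix, j::Integer, Δcol::AbstractVector)
    ∂A = zero(A)          # full out × batch_size allocation for a single column
    ∂A[:, j] .= Δcol      # only this column is ever non-zero
    return ∂A
end

out, batch_size, T = 32, 128, 256
A  = rand(Float32, out, batch_size)
∂A = view_col_pullback(A, 1, ones(Float32, out))

n_nonzero = count(!iszero, ∂A)     # only `out` entries carry information
bytes_per_view = sizeof(∂A)        # 32 * 128 * 4 = 16384 bytes

# One such array per view of ys[t]: batch_size views per time step, T time steps
# (ignoring the views of y₁ and any accumulation of the results).
estimate_MiB = T * batch_size * bytes_per_view / 2^20
```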

Here are a few variants, the fastest of which just keeps everything as giant arrays.

function diagmatmul(m, xs) # just adding comments to understand!
    ys = m.(xs)                      # ys[t][o,b]
    zs = diag.((first(ys)',) .* ys)  # zs[t][b] := sum(o) ys[1][o,b] * ys[t][o,b] 
    return sum(sum(zs))              # sum(t,b)
end

function usedot(m, xs)
    ys = m.(xs)                     # ys[t][o,b]
    z = dot.(Ref(first(ys)), ys)    # z[t] := sum(o,b) ys[1][o,b] * ys[t][o,b] 
    return sum(z)
end

function ∇usedot(m, xs, θ)
    gradient(θ) do
        usedot(m, xs)
    end
end

using TensorCast

@cast xt[i,b\t] := xs[t][i,b]; # glue & reshape

function gluedup(m, xt)
    yt = m(xt)
    @cast ys[o,b,t] := yt[o,b\t]  b:batch_size; # just reshape
    @reduce sum(o,b,t) ys[o,b,1] * ys[o,b,t]    # broadcast & sum
end

function ∇gluedup(m, xt, θ)
    gradient(θ) do
        gluedup(m, xt)
    end
end
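To see why the glued version can replace 256 small matrix products with one large one, here is a stdlib-only sketch of what the glue step and the reshape in gluedup amount to. Assumptions: affine is a hypothetical stand-in for Dense(in, out) with the default identity activation, and reduce(hcat, xs) plays the role of the @cast glue (TensorCast may do this differently internally).

```julia
in_dim, out, batch_size, T = 64, 32, 128, 8     # small T to keep it quick
W, bias = rand(Float32, out, in_dim), rand(Float32, out)
affine(x) = W * x .+ bias      # hypothetical stand-in for Dense(in, out)

xs = [rand(Float32, in_dim, batch_size) for _ in 1:T]

# Per-slice version, as in m.(xs): T separate matrix products
ys_slices = affine.(xs)

# Glued version: hcat puts b fastest and t slowest, i.e. the combined index b\t
xt = reduce(hcat, xs)                     # in_dim × (batch_size*T)
yt = affine(xt)                           # one large matrix product
ys = reshape(yt, out, batch_size, T)      # ys[o,b,t]; a reshape, no copy

all(ys[:, :, t] ≈ ys_slices[t] for t in 1:T)
```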

Benchmarking the gradients gives the following times:

@btime ∇diagmatmul($m, $xs, $θ) #  25.432 ms,     51.01 MiB
@btime ∇colwise($m, $xs, $θ)    # 385.081 ms,  1.09 GiB
@btime ∇usedot($m, $xs, $θ)     #  12.438 ms,     34.65 MiB
@btime ∇gluedup($m, $xt, $θ)    #   9.640 ms,     36.03 MiB

It ought to be possible to do better, since sum(ys[:,:,1] .* ys) broadcasts out a big array and then immediately sums it. But it’s still an improvement!
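One way to avoid the big temporary, sketched with plain arrays and assuming only the scalar is needed: ys[o,b,1] does not depend on t, so the sum over t can be taken first, leaving an out × batch_size temporary instead of an out × batch_size × T one. Whether Zygote differentiates this form faster has not been measured here.

```julia
using LinearAlgebra

out, batch_size, T = 32, 128, 16
ys  = rand(Float32, out, batch_size, T)   # ys[o,b,t], as in gluedup
ys1 = ys[:, :, 1]

# What the @reduce line computes: broadcast to out × batch_size × T, then sum
s_broadcast = sum(ys1 .* ys)

# usedot form: one dot product per time step
s_dot = sum(dot(ys1, view(ys, :, :, t)) for t in 1:T)

# Sum over t first, then a single dot product
s_factored = dot(ys1, dropdims(sum(ys; dims=3); dims=3))
```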


Thank you, @mcabbott, this is really helpful. However, it looks like I oversimplified my minimal working example: the double sum in sum(sum(zs)) is there just to scalarize the output, so that I can take the gradient of that particular piece of the larger forward pass. In the actual code, the zs from above are stacked into a matrix (using your awesome package), which is then passed through softmax, like so:

function diagmatmul(m, xs)
    ys = m.(xs)
    zs = diag.((first(ys)',) .* ys)
    αs = softmax(stack(zs)')
end

function ∇diagmatmul(m, xs, θ)
    gradient(θ) do
        sum(sum(diagmatmul(m, xs)))
    end
end

function colwise(m, xs)
    ys = m.(xs)
    y₁ = view.(Ref(first(ys)), :, 1:batch_size)
    zs = (y -> y₁ .⋅ view.(Ref(y), :, 1:batch_size)).(ys)
    αs = softmax(stack(zs)')
end

function ∇colwise(m, xs, θ)
    gradient(θ) do
        sum(sum(colwise(m, xs)))
    end
end

(There are even more steps after that in the forward pass function, but they are not relevant to the issue described here.) So the usedot and gluedup functions won’t solve the problem.
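For reference, the shape bookkeeping in the αs line, sketched with the standard library only. Assumptions: reduce(hcat, zs) stands in for stack(zs) (both are taken to give a batch_size × T matrix), and softmax_dim1 is a hypothetical stand-in for Flux's softmax, which normalizes along the first dimension by default. After the transpose, each column of αs is a distribution over time steps for one batch element.

```julia
# Hypothetical stand-in for Flux's softmax (normalizes along dims=1)
function softmax_dim1(z::AbstractMatrix)
    e = exp.(z .- maximum(z; dims=1))   # subtract each column's max for stability
    return e ./ sum(e; dims=1)
end

batch_size, T = 4, 3
zs = [rand(Float32, batch_size) for _ in 1:T]   # zs[t][b]

z  = permutedims(reduce(hcat, zs))   # z[t,b]; reduce(hcat, zs) is batch_size × T
αs = softmax_dim1(z)                 # each column sums to 1 over t
```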

OK, then usedot is useless, but one big array still has a small edge…

using LazyStack, OMEinsum  # OMEinsum#master (development branch)

function diagmatmul(m, xs)
    ys = m.(xs);                     # ys[t][o,b]
    zs = diag.((first(ys)',) .* ys); # zs[t][b] := sum(o) ys[1][o,b] * ys[t][o,b] 
    z = stack(zs)';                  # z[t,b] := zs[t][b]
    αs = softmax(z);
end

function gluedup(m, xt)
    yt = m(xt);
    @cast ys[o,t,b] := yt[o,b\t]  b:batch_size, nolazy;
    ys1 = ys[:,1,:];
    @ein z[t,b] := ys1[o,b] * ys[o,t,b]; # batch matmul
    αs_ = softmax(z);
end

function ∇gluedup(m, xt, θ)
    gradient(θ) do
        sum(sum(gluedup(m, xt)))
    end
end

αs  = diagmatmul(m, xs);
αs_ = gluedup(m, xt);
αs[1:2, 1:2]    # spot-check that both versions agree
αs_[1:2, 1:2]

@btime ∇diagmatmul($m, $xs, $θ) #  27.626 ms
@btime ∇colwise($m, $xs, $θ)    # 401.620 ms
@btime ∇gluedup($m, $xt, $θ)    #  14.256 ms

IIRC, the batched matmul that this calls still makes slices internally (along b), at least on the CPU, so it pays to make b the last index.
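To illustrate why the index order matters, here is a hand-written version of the batched contraction in gluedup using only LinearAlgebra. It is a sketch under assumed sizes, not what OMEinsum does internally: looping over the last index b means every slice ys[:, :, b] is a contiguous block of memory.

```julia
using LinearAlgebra

out, T, batch_size = 32, 16, 8
ys  = rand(Float32, out, T, batch_size)   # ys[o,t,b]: b is the last index
ys1 = ys[:, 1, :]                          # ys1[o,b]

# z[t,b] = sum(o) ys1[o,b] * ys[o,t,b]: one matrix-vector product per b,
# each reading a contiguous out × T slice
z = similar(ys, T, batch_size)
for b in 1:batch_size
    mul!(view(z, :, b), transpose(view(ys, :, :, b)), view(ys1, :, b))
end

# Reference: diagmatmul-style result on the same data
z_ref = permutedims(reduce(hcat, [diag(ys1' * ys[:, t, :]) for t in 1:T]))
```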


Thanks a lot, @mcabbott! I ended up using your gluedup version, and it improved the performance quite a bit.
Now that I see all the amazing stuff TensorCast.jl, OMEinsum.jl and the like can do, I have a follow-up question. After computing αs as above, I use them to compute a context matrix C as follows:

using CuArrays

D = 100
T, B = size(αs)
Hs = cu(rand(Float32, D, T, B)) # some 3D input tensor used in computation of context

# C[d,b] := sum(t) αs[t,b] * Hs[d,t,b]
C = dropdims(sum(reshape(αs, 1, :, B) .* Hs; dims=2); dims=2)

Am I leaving any performance on the table with this implementation of C? In particular, is there a way I can speed it up with TensorCast.jl and OMEinsum.jl?

Would adapting the example given in the “Better broadcasting” section of your package’s documentation be of benefit here?

Maybe? The speed of the lazy broadcasting has been inconsistent; I don’t know whose fault that is, but I’m afraid the doc example is quite broken right now. However, OMEinsum does something pretty similar (in cases where it can’t see a matrix multiplication), which is fast:

using TensorCast, OMEinsum, Einsum, BenchmarkTools, Zygote
V = rand(500);
f1(V) = @reduce W[i] := sum(j,k) V[i]*V[j]*V[k]      #   441.332 ms,   953.68 MiB
f2(V) = @reduce W[i] := sum(j,k) V[i]*V[j]*V[k] lazy # 3.559 s,      9.33 GiB
f3(V) = @ein W[i] := V[i]*V[j]*V[k] # OMEinsum,          139.951 ms,     4.25 KiB
f4(V) = @einsum W[i] := V[i]*V[j]*V[k] # Einsum,         140.215 ms,     4.06 KiB
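The 953.68 MiB reported for f1 is essentially the size of the full 500 × 500 × 500 Float64 temporary. Below is a plain-Julia sketch of that arithmetic and of the loop-based alternative, roughly the kind of nested loops Einsum-style macros produce (triple_loop is a hypothetical hand-written equivalent, not Einsum's actual output). The closed form V .* sum(V)^2, which holds because the sums over j and k factor out, is used only as a correctness check.

```julia
V = rand(500)
n = length(V)

# Broadcasting style: the full n × n × n product exists before reducing
bytes_full = n^3 * sizeof(Float64)   # bytes in that temporary
MiB_full   = bytes_full / 2^20

# Loop style: accumulate directly into W, no n^3 temporary
function triple_loop(V)
    W = zeros(eltype(V), length(V))
    for k in eachindex(V), j in eachindex(V), i in eachindex(V)   # i innermost
        W[i] += V[i] * V[j] * V[k]
    end
    return W
end

Vs = V[1:50]                            # small slice for a quick check
triple_loop(Vs) ≈ Vs .* sum(Vs)^2
```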

For gradients, however, it’s not much better than plain broadcasting, and something is apparently broken with my approach…

V2 = rand(100); # for which these are 1ms forward, but:
@btime gradient(sum∘f1, $V2); # 253.699 ms (7000176 allocations: 297.56 MiB)
@btime gradient(sum∘f3, $V2); #   4.477 ms (49 allocations: 6.97 KiB)

f0(V) = dropdims(sum(V .* V' .* reshape(V,1,1,:); dims=(2,3)), dims=(2,3));
@btime gradient(sum∘f0, $V2);  #  4.993 ms (78 allocations: 30.83 MiB)
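Whatever the implementation, all of these should return the same gradient, which makes a useful correctness check when comparing them. Since Σᵢ W[i] = Σᵢⱼₖ V[i]V[j]V[k] = (Σ V)³, every entry of the gradient of sum∘f is 3(Σ V)². A small stdlib sketch checking this against central finite differences (loss and grad_fd are hypothetical helpers):

```julia
# sum∘f0 without the dropdims: sums every entry of the n × n × n product
loss(V) = sum(V .* V' .* reshape(V, 1, 1, :))

# Analytic gradient: d/dV[m] of sum(V)^3 is 3*sum(V)^2 for every m
grad_analytic(V) = fill(3 * sum(V)^2, length(V))

# Central finite differences, one coordinate at a time
function grad_fd(f, V; h=1e-6)
    g = similar(V)
    for m in eachindex(V)
        e = zeros(length(V)); e[m] = h
        g[m] = (f(V .+ e) - f(V .- e)) / (2h)
    end
    return g
end

V2 = rand(10)
isapprox(grad_fd(loss, V2), grad_analytic(V2); rtol=1e-5)
```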

Edit: what’s broken is https://github.com/FluxML/Zygote.jl/issues/502
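Coming back to the question about computing C: a CPU-only sketch (made-up sizes, no CuArrays, so it runs anywhere) showing that C[d,b] = Σₜ Hs[d,t,b]·αs[t,b] is a batched matrix-vector product, so the D × T × B temporary created by the broadcast is not strictly needed. In OMEinsum notation this would presumably be written @ein C[d,b] := Hs[d,t,b] * αs[t,b], in the same style as the @ein line earlier in the thread; whether that is faster on the GPU, or under Zygote, is untested here.

```julia
using LinearAlgebra

D, T, B = 100, 16, 8
αs = rand(Float32, T, B)
Hs = rand(Float32, D, T, B)

# Version from the question: a D × T × B temporary, then a sum over t
C_bcast = dropdims(sum(reshape(αs, 1, :, B) .* Hs; dims=2); dims=2)

# Same contraction as one matrix-vector product per batch element b
C_loop = similar(Hs, D, B)
for b in 1:B
    mul!(view(C_loop, :, b), view(Hs, :, :, b), view(αs, :, b))
end
```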
