Help to improve performance of gradient calculation on tensor operations

I need help improving the performance of the gradient calculation in a piece of my program, and understanding why it performs as it does.

As part of a larger machine learning system, I need a pairwise similarity measure of the columns of two 3d tensors.

My implementation of this is as follows:

using Zygote
using Flux: softmax
using LinearAlgebra: dot

mynorm(itr) = sqrt(sum(x->x^2, itr))

cosinesim(u, v) = dot(u, v)/(mynorm(u)*mynorm(v))


function _pairwise!(r::Zygote.Buffer{T, Array{T, 3}},
                    a::AbstractArray{T, 3},
                    b::AbstractArray{T, 3}, K::Function) where T
    na = size(a, 2)
    nb = size(b, 2)
    batchsize = size(a, 3)
    size(r) == (nb, na, batchsize) || throw(DimensionMismatch("incorrect size of r. Expected $((nb, na, batchsize)), got $(size(r))"))
    @inbounds for k = 1:batchsize
        for j = 1:na
            aj = view(a, :, j, k)
            for i = 1:nb
                bi = view(b, :, i, k)
                r[i, j, k] = K(bi, aj)
            end
        end
    end
    r
end


function pairwise(a::AbstractArray{T, 3}, b::AbstractArray{T, 3}, K=cosinesim) where {T}
    W, R, batchsize = size(a)
    _, N, _ = size(b)
    out = Zygote.Buffer(a, eltype(a), (N, R, batchsize))
    _pairwise!(out, a, b, K)
    @views for k in 1:batchsize
        out[:, :, k] = softmax(out[:, :, k]; dims=1)
    end
    copy(out)::Array{T, 3}
end

totalsum(a, b) = sum(pairwise(a, b))

I tried using Distances.jl at first, but its methods were not directly differentiable, so I implemented my own pairwise!-style method using a Zygote.Buffer to allow array mutation.
Each column also needs to be softmaxed, which complicates the implementation a bit.
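
For reference, here is a minimal sketch (separate from my actual code) of why the Buffer is there: Zygote cannot differentiate through writes into a plain Array, but writing into a Zygote.Buffer and then copying it out is supported.

using Zygote

function squares_buffered(x)
    buf = Zygote.Buffer(x)      # write-once container that Zygote can track
    for i in eachindex(x)
        buf[i] = x[i]^2         # mutation here is fine
    end
    copy(buf)                   # convert back to an ordinary Array
end

gradient(x -> sum(squares_buffered(x)), rand(Float32, 4))  # works; the same loop writing into a plain Array would error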

The following sets up an example with tensors of a typical use size:

julia> N, W, R, B = 16, 64, 4, 16
(16, 64, 4, 16)

julia> a = rand(Float32, W, R, B);

julia> b = rand(Float32, W, N, B);

julia> @btime pairwise(a, b);
  157.046 μs (1202 allocations: 91.91 KiB)

julia> @btime gradient(totalsum, a, b);
  53.937 ms (517684 allocations: 143.89 MiB)

While the forward pass seems to perform well, I would expect the gradient to be much faster, and since this operation is performed multiple times per training iteration, it is a bottleneck.

Looking at output from Juno’s profiler, it seems the gradient is spending a lot of time doing type inference, and indeed the output from @code_warntype shows that the compiler is unable to infer the gradient type.

julia> @code_warntype gradient(totalsum, a, b)
Variables
  #self#::Core.Compiler.Const(Zygote.gradient, false)
  f::Core.Compiler.Const(totalsum, false)
  args::Tuple{Array{Float32,3},Array{Float32,3}}
  y::Float32
  @_5::Int64
  back::Zygote.var"#36#37"{typeof(∂(totalsum))}

Body::Tuple{Any,Any}
1 ─ %1 = Core.tuple(f)::Core.Compiler.Const((totalsum,), false)
│   %2 = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Float32,Zygote.var"#36#37"{typeof(∂(totalsum))}}
│   %3 = Base.indexed_iterate(%2, 1)::Core.Compiler.PartialStruct(Tuple{Float32,Int64}, Any[Float32, Core.Compiler.Const(2, false)])
│        (y = Core.getfield(%3, 1))
│        (@_5 = Core.getfield(%3, 2))
│   %6 = Base.indexed_iterate(%2, 2, @_5::Core.Compiler.Const(2, false))::Core.Compiler.PartialStruct(Tuple{Zygote.var"#36#37"{typeof(∂(totalsum))},Int64}, Any[Zygote.var"#36#37"{typeof(∂(totalsum))}, Core.Compiler.Const(3, false)])
│        (back = Core.getfield(%6, 1))
│   %8 = Zygote.sensitivity(y)::Core.Compiler.Const(1.0f0, false)
│   %9 = (back)(%8)::Tuple{Any,Any}
└──      return %9

Is this the reason for the performance degradation, or are the numbers to be expected?
Is there a better way to solve this?
Thank you for helping me understand better!

If you must make slices, then I think using SliceMap (+ JuliennedArrays) or TensorCast will usually be quicker than writing into a Buffer.

But can’t you just operate on the whole arrays here? I think you have written this:

r[i, j, k] = sum(s) a[s, j, k] * b[s, i, k] / sqrt(sum(s') a[s', j, k]^2) / sqrt(sum(s'') b[s'', i, k]^2)

in which the major operation is just batched matrix multiplication (with one argument transposed), multiplied by some normalisation factors (done perhaps by broadcasting).
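
Roughly, a sketch of what I mean, assuming NNlib's batched_mul is available (the function name and layout here are just illustrative):

using NNlib: batched_mul

function cosine_whole(a, b)                             # a is (W, R, B), b is (W, N, B)
    num = batched_mul(permutedims(b, (2, 1, 3)), a)     # bᵀa per batch slice, size (N, R, B)
    na  = sqrt.(sum(abs2, a; dims=1))                   # column norms of a, size (1, R, B)
    nb  = sqrt.(sum(abs2, b; dims=1))                   # column norms of b, size (1, N, B)
    num ./ (reshape(nb, size(b, 2), 1, :) .* na)        # broadcast the normalisation factors
end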


Thank you for pointing me in the right direction here.
I used TensorCast in the following implementation:

using TensorCast
function cosinesim(a, b)
    @reduce similarity[i, j, k] := sum(s) a[s, j, k] * b[s, i, k] /
        sqrt( @reduce [_, j, k] := sum(s') a[s', j, k]^2) /
        sqrt( @reduce [_, i, k] := sum(s'') b[s'', i, k]^2)
end

function mysoftmax(a)
    @cast submax[i, j, k] := a[i, j, k] - @reduce [_, j, k] := maximum(i) a[i, j, k]
    @cast r[i, j, k] := exp(submax[i, j, k]) / @reduce [_, j, k] := sum(i) exp(submax[i, j, k])
end

pairwise2(a, b) = mysoftmax(cosinesim(a, b))

The speedup is great:

julia> @btime pairwise2(a, b);
  699.095 μs (124 allocations: 439.31 KiB)

julia> @btime gradient((a, b)-> sum(pairwise2(a, b)), a, b);
  1.348 ms (7818 allocations: 3.70 MiB)

Thank you for your help!

Great! I think it can go even faster though… here’s what I had in mind about batched_mul, although sadly you will hit this issue:

using Zygote, TensorCast, NNlib  # fix2 branch from https://github.com/FluxML/NNlib.jl/pull/191

function cosine_nnlib(a, b)
    @reduce den1[j, k] := sum(s) a[s, j, k]^2 
    @reduce den2[i, k] := sum(s) b[s, i, k]^2 
    bmm = batched_mul(PermutedDimsArray(b, (2,1,3)), a)
    @cast similarity[i, j, k] := bmm[i, j, k] / sqrt(den1[j, k] * den2[i, k])
end

@btime gradient((a,b) -> sum((x->x).(cosinesim(a, b))), $a, $b);  #  1.190 ms (4556 allocations: 3.58 MiB)
@btime gradient((a,b) -> sum((x->x).(cosine_nnlib(a, b))), $a, $b); #  232.991 μs (6375 allocations: 579.78 KiB)

This is actually a neat test case for something else I was working on. I can get the gradient of the batched_mul step from 400 μs (TensorCast) or 100 μs (NNlib) down to 7 μs, about the same as the forward pass. (With LoopVectorization doing the heavy lifting!) It should be possible to fuse that with the bmm / sqrt(...) step, but right now this works slowly or not at all, so I can’t improve much on this NNlib answer.

Also worth knowing: NNlib’s softmax does pretty much the broadcasting you wrote, but it has a hand-written gradient, which is much quicker here:

mysoftmax(a) ≈ softmax(a, dims=1)

@btime mysoftmax($a);       # 76.286 μs
@btime softmax($a, dims=1); # 45.538 μs

@btime gradient(a -> sum(mysoftmax(a)), $a);      # 640.619 μs
@btime gradient(a -> sum(softmax(a, dims=1)), $a); # 95.622 μs
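
For intuition, that hand-written rule is essentially the standard softmax pullback (a sketch of the maths, not necessarily NNlib's exact code): with y = softmax(x; dims=1) and upstream gradient Δ, the input gradient is

# illustrative only; one fused broadcast instead of Zygote tracing every intermediate
softmax_pullback(Δ, y; dims=1) = y .* (Δ .- sum(Δ .* y; dims=dims))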