I need help improving the performance of the gradient calculation in one part of my program, and understanding why it performs the way it does.

As part of a larger machine learning system, I need a pairwise similarity measure between the columns of two 3D tensors.
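
To state the quantity precisely, here is the definition as read off the code below. The index names are notation introduced only for this description: $a$ has size $W \times R \times B$, $b$ has size $W \times N \times B$, and $a_{:,j,k}$ denotes column $j$ of batch element $k$.

$$
s_{ijk} = \frac{\langle b_{:,i,k},\, a_{:,j,k} \rangle}{\lVert b_{:,i,k} \rVert_2 \, \lVert a_{:,j,k} \rVert_2},
\qquad
r_{ijk} = \frac{\exp(s_{ijk})}{\sum_{i'=1}^{N} \exp(s_{i'jk})}
$$

So $s$ holds the cosine similarities and $r$, of size $N \times R \times B$, is $s$ softmaxed along its first dimension.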

My implementation of this is as follows:

```
using Zygote
using Flux: softmax
using LinearAlgebra: dot

mynorm(itr) = sqrt(sum(x->x^2, itr))
cosinesim(u, v) = dot(u, v)/(mynorm(u)*mynorm(v))

function _pairwise!(r::Zygote.Buffer{T, Array{T, 3}},
                    a::AbstractArray{T, 3},
                    b::AbstractArray{T, 3}, K::Function) where T
    na = size(a, 2)
    nb = size(b, 2)
    batchsize = size(a, 3)
    size(r) == (nb, na, batchsize) || throw(DimensionMismatch("incorrect size of r. Expected $((nb, na, batchsize)), got $(size(r))"))
    @inbounds for k = 1:batchsize
        @inbounds for j = 1:na
            aj = view(a, :, j, k)
            @inbounds for i = 1:nb
                bi = view(b, :, i, k)
                r[i, j, k] = K(bi, aj)
            end
        end
    end
    r
end

function pairwise(a::AbstractArray{T, 3}, b::AbstractArray{T, 3}, K=cosinesim) where {T}
    W, R, batchsize = size(a)
    _, N, _ = size(b)
    out = Zygote.Buffer(a, eltype(a), (N, R, batchsize))
    _pairwise!(out, a, b, K)
    @views for b in 1:batchsize
        out[:, :, b] = softmax(out[:, :, b]; dims=1)
    end
    copy(out)::Array{T, 3}
end

totalsum(a, b) = sum(pairwise(a, b))
```

I first tried Distances.jl, but its methods were not directly differentiable, so I implemented my own _pairwise! method, using a Zygote.Buffer to allow array mutation.

Each column also needs to be softmaxed, which of course complicates the implementation a bit.
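
To make the intended result concrete, here is a minimal mutation-free sketch of the same quantity using only Base array operations. The name `pairwise_dense` is hypothetical, the softmax is written out by hand rather than taken from Flux, and the cosine similarity is fixed rather than passed in as `K`. It is meant as a reference for checking the output of `pairwise` above; whether Zygote differentiates it any faster is not something this sketch establishes.

```julia
# Hypothetical reference implementation (not benchmarked, not tested under Zygote).
# a is W×R×B, b is W×N×B; the result is N×R×B, softmaxed along dimension 1.
function pairwise_dense(a::AbstractArray{T, 3}, b::AbstractArray{T, 3}) where {T}
    # Scale every column to unit Euclidean norm, so that plain dot products
    # between columns become cosine similarities.
    an = a ./ sqrt.(sum(abs2, a; dims=1))
    bn = b ./ sqrt.(sum(abs2, b; dims=1))
    # One N×R similarity matrix per batch element, stacked along dimension 3.
    sims = cat((bn[:, :, k]' * an[:, :, k] for k in axes(a, 3))...; dims=3)
    # Softmax over dimension 1, shifted by the column maximum for stability.
    e = exp.(sims .- maximum(sims; dims=1))
    return e ./ sum(e; dims=1)
end
```

With the example sizes used below, `pairwise_dense(a, b)` has size `(16, 4, 16)` and each `out[:, j, k]` sums to one.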

The following sets up an example with tensors of a typical size:

```
julia> N, W, R, B = 16, 64, 4, 16
(16, 64, 4, 16)
julia> a = rand(Float32, W, R, B);
julia> b = rand(Float32, W, N, B);
julia> @btime pairwise(a, b);
157.046 μs (1202 allocations: 91.91 KiB)
julia> @btime gradient(totalsum, a, b);
53.937 ms (517684 allocations: 143.89 MiB)
```

While the forward pass seems to perform well, I would expect the gradient to cost far less than roughly 54 ms and 144 MiB of allocations. Since this operation is performed multiple times per training iteration, it is a bottleneck.

Looking at the output from Juno’s profiler, the gradient seems to spend a lot of time in type inference, and indeed the output from @code_warntype shows that the compiler cannot infer the gradient type: the return type is only inferred as Tuple{Any,Any}.

```
julia> @code_warntype gradient(totalsum, a, b)
Variables
#self#::Core.Compiler.Const(Zygote.gradient, false)
f::Core.Compiler.Const(totalsum, false)
args::Tuple{Array{Float32,3},Array{Float32,3}}
y::Float32
@_5::Int64
back::Zygote.var"#36#37"{typeof(∂(totalsum))}
Body::Tuple{Any,Any}
1 ─ %1 = Core.tuple(f)::Core.Compiler.Const((totalsum,), false)
│ %2 = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Float32,Zygote.var"#36#37"{typeof(∂(totalsum))}}
│ %3 = Base.indexed_iterate(%2, 1)::Core.Compiler.PartialStruct(Tuple{Float32,Int64}, Any[Float32, Core.Compiler.Const(2, false)])
│ (y = Core.getfield(%3, 1))
│ (@_5 = Core.getfield(%3, 2))
│ %6 = Base.indexed_iterate(%2, 2, @_5::Core.Compiler.Const(2, false))::Core.Compiler.PartialStruct(Tuple{Zygote.var"#36#37"{typeof(∂(totalsum))},Int64}, Any[Zygote.var"#36#37"{typeof(∂(totalsum))}, Core.Compiler.Const(3, false)])
│ (back = Core.getfield(%6, 1))
│ %8 = Zygote.sensitivity(y)::Core.Compiler.Const(1.0f0, false)
│ %9 = (back)(%8)::Tuple{Any,Any}
└── return %9
```

Is this failed inference the reason for the slow gradient, or are these numbers to be expected?

Is there a better way to solve this?

Thank you for helping me understand better!