Taking the gradient of an einsum

Hello, I would really appreciate some help optimizing my code. The quantity I compute makes the most sense (to me) as an einsum, but I cannot seem to take its gradient properly. What I managed to do instead is iterate with `eachcol` (and, for the weights, an `eachcol` of a row matrix), but naturally it is slow.

```julia
using Flux, CUDA, Tullio, KernelAbstractions, CUDAKernels
using SpecialFunctions: loggamma
using BenchmarkTools

struct mymodel
    weights
end
Flux.@functor mymodel

# Poisson log-likelihood
ll(x::Real, rate::Real) = x*log(rate) - rate - loggamma(x+1)

function loss(loglikelihood)
    y = Flux.onehotbatch(axes(loglikelihood)...) |> gpu
    Flux.logitcrossentropy(loglikelihood, y')
end
```

```julia
# generate some data
n_i = 300
n_j = 250
n_k = 200

rates = CUDA.rand(n_i, n_k)
xs = CUDA.rand(n_j, n_k)
weights = CUDA.rand(1, n_k)

model = mymodel(weights)

# several implementations of the loglikelihood computation

"naive version"
function (model::mymodel)(rates, xs)
    loglikelihood = zeros(size(xs, 1), size(rates, 1)) |> gpu # using `CUDA.zeros` fails for some reason when taking the gradient. Any idea why?
    for (w, xs_col, rates_col) in zip(eachcol(model.weights), eachcol(xs), eachcol(rates))
        loglikelihood += abs.(w) .* ll.(xs_col, rates_col')
    end
    loglikelihood
end

@btime model($rates, $xs); # 1.575 ms (19422 allocations: 2.01 MiB)
@btime Flux.withgradient(model -> loss(model($rates, $xs)), model); # 43.186 ms (451031 allocations: 20.93 MiB)
```

```julia
"somewhat faster"
function (model::mymodel)(rates, xs)
    mapreduce(
        (xs_col, rates_col, w) -> abs.(w) .* ll.(xs_col, rates_col'), +,
        eachcol(xs), eachcol(rates), eachcol(model.weights))
end

@btime model($rates, $xs); # 1.191 ms (16988 allocations: 1.06 MiB)
@btime Flux.withgradient(model -> loss(model($rates, $xs)), model); # 29.738 ms (172123 allocations: 8.56 MiB)
```

```julia
"fastest, but taking the gradient doesn't work"
function (model::mymodel)(rates, xs)
    @tullio loglikelihood[i, j] := ll(xs[i, k], rates[j, k]) * abs(model.weights[k]) grad=Dual
end

@btime model($rates, $xs); # 18.067 μs (107 allocations: 4.75 KiB)
@btime Flux.withgradient(model -> loss(model($rates, $xs)), model) # fails; the (shortened) error message is:
# ERROR: InvalidIRError: compiling kernel (...) resulted in invalid LLVM IR
#   Reason: unsupported dynamic function invocation (call to +)
```

I also tried some other einsum macros (`@einsum`, `@tensor`).
Thank you

Are you using CUDA v4? If so, I suspect the problem is changes there and in KernelAbstractions. It is perhaps fixed in this PR; sorry, I haven't had the bandwidth to look properly.

(This won't solve your problem, but I believe this would be a nice use for `nograd=ll` if that worked, as it would allow a symbolic gradient instead of Dual numbers.)

Hi, I am using CUDA v4. Is there an efficient way to implement this without Tullio?

Regarding `nograd=ll`, I didn't quite understand what it would do here. I tried both with `grad=Dual` and without, but naively, since this is just a Poisson log-likelihood, I would guess that taking the symbolic gradient is possible, isn't it?
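For what it's worth, the derivative of `ll` with respect to `rate` is just `x/rate - 1`, since the `loggamma` term doesn't depend on `rate`, so a symbolic gradient certainly exists. A quick CPU-only sanity check (a sketch, not part of the code above; `dll_drate` is a name I made up):

```julia
using SpecialFunctions: loggamma

ll(x, rate) = x*log(rate) - rate - loggamma(x+1)
dll_drate(x, rate) = x/rate - 1  # symbolic ∂ll/∂rate

# central finite-difference check at an arbitrary point
h = 1e-6
fd = (ll(3.0, 2.0 + h) - ll(3.0, 2.0 - h)) / (2h)
@assert isapprox(fd, dll_drate(3.0, 2.0); atol=1e-6)
```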

You could try the linked PR for CUDA v4.

An alternative is to use broadcasting, e.g. `using TensorCast; @reduce loglikelihood[i, j] := sum(k) ll(xs[i, k], rates[j, k]) * abs(model.weights[k])`. This is a horrible algorithm, since it allocates the whole N^3-sized intermediate, but for expressions like this one with matmul-like access patterns, Tullio also isn't very smart about generating efficient GPU code.
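Concretely, the `@reduce` expression lowers to something like the following hand-written broadcast (a sketch; `loglik_broadcast` is a hypothetical name, and you can see the full `(i, j, k)` cube being materialized before the sum):

```julia
# Hand-written equivalent of the @reduce expression above (sketch).
# Allocates the full (n_i, n_j, n_k) intermediate, as noted.
function loglik_broadcast(weights, rates, xs)
    xs3    = reshape(xs, size(xs, 1), 1, :)          # (i, 1, k)
    rates3 = reshape(rates, 1, size(rates, 1), :)    # (1, j, k)
    w3     = reshape(abs.(weights), 1, 1, :)         # (1, 1, k)
    dropdims(sum(ll.(xs3, rates3) .* w3; dims=3); dims=3)
end
```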

Re `nograd`, perhaps this is deep in the weeds, but my point is that it's not necessary to know the derivative of `ll` to compute the gradient with respect to `model.weights`; you only need `*`.
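In other words, because the result is linear in `abs.(weights)`, the pullback for `model.weights` only needs the already-computed values of `ll`, never its derivative. A hypothetical hand-written gradient might look like this (a sketch; `Δ` stands for the upstream gradient with respect to `loglikelihood[i, j]`):

```julia
# Sketch: gradient of sum(Δ .* loglikelihood) w.r.t. weights,
# using only ll's *values*, not its derivative.
function weights_grad(Δ, weights, rates, xs)
    L = ll.(reshape(xs, size(xs, 1), 1, :),        # (i, 1, k)
            reshape(rates, 1, size(rates, 1), :))  # broadcasts to (i, j, k)
    g = sum(Δ .* L; dims=(1, 2))                   # contract over i and j
    reshape(g, 1, :) .* sign.(weights)             # chain rule through abs
end
```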

The PR didn't help; I got this error:

```
ERROR: Compiling Tuple{CUDA.var"#new_unique#2", Ptr{CUDA.CUctx_st}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
```

Using TensorCast was helpful, and even the forward computation (without taking the gradient) was actually faster than Tullio.

```julia
function (model::mymodel)(rates, xs)
    @reduce loglikelihood[i, j] := sum(k) ll(xs[i, k], rates[j, k]) * abs(model.weights[1, k])
end

@btime model($rates, $xs); # 6.600 μs (84 allocations: 7.88 KiB)
@btime Flux.withgradient(model -> loss(model($rates, $xs)), $model) # 12.276 ms (1189 allocations: 72.70 KiB)
```

I was surprised that although the forward computation is roughly 200x faster than my naive implementation, taking the gradient is only 2-3 times faster.
Thanks for the help.