Hello, I would really appreciate some help optimizing my optimization. What I calculate makes the most sense (to me) as an `einsum`, but I cannot seem to take its gradient properly. What I managed to do instead is iterate over `eachcol`s (and when applied to the `weights`, it is an `eachcol` of a row matrix), but naturally it is slow.
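
Written out, and assuming I have the indices right, the quantity I am after is the following (index names as in the `@tullio` version in the code, with $w$ for `weights`, $x$ for `xs`, and $r$ for `rates`):

$$
L_{ij} = \sum_{k} \lvert w_k \rvert \, \ell(x_{ik}, r_{jk}), \qquad \ell(x, r) = x \log r - r - \log\Gamma(x + 1)
$$

Here $\ell$ is the `ll` function from the code, and $w$ is the only parameter I need a gradient for.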

```
using Flux, CUDA, Tullio, KernelAbstractions, CUDAKernels
using SpecialFunctions: loggamma
using BenchmarkTools
CUDA.allowscalar(false)
struct mymodel
    weights
end
Flux.@functor mymodel
ll(x::Real, rate::Real) = x*log(rate) - rate - loggamma(x+1)
function loss(loglikelihood)
    y = Flux.onehotbatch(axes(loglikelihood)...) |> gpu
    Flux.logitcrossentropy(loglikelihood, y')
end
# generate some data
n_i = 300
n_j = 250
n_k = 200
rates = CUDA.rand(n_i, n_k)
xs = CUDA.rand(n_j, n_k)
weights = CUDA.rand(1, n_k)
model = mymodel(weights)
# several implementations of loglikelihood computation
"slowest"
function (model::mymodel)(rates, xs)
    loglikelihood = zeros(size(xs, 1), size(rates, 1)) |> gpu # using `CUDA.zeros` fails for some reason when taking the gradient. Any idea why?
    for (w, xs_col, rates_col) in zip(eachcol(model.weights), eachcol(xs), eachcol(rates))
        loglikelihood += abs.(w) .* ll.(xs_col, rates_col')
    end
    loglikelihood
end
@btime model($rates, $xs); # 1.575 ms (19422 allocations: 2.01 MiB)
@btime Flux.withgradient(model -> loss(model($rates, $xs)), model); # 43.186 ms (451031 allocations: 20.93 MiB)
"somewhat faster"
function (model::mymodel)(rates, xs)
    mapreduce(
        (xs_col, rates_col, w) -> abs.(w) .* ll.(xs_col, rates_col'), +,
        eachcol(xs),
        eachcol(rates),
        eachcol(model.weights)
    )
end
@btime model($rates, $xs); # 1.191 ms (16988 allocations: 1.06 MiB)
@btime Flux.withgradient(model -> loss(model($rates, $xs)), model); # 29.738 ms (172123 allocations: 8.56 MiB)
"fastest, taking gradient doesn't work"
function (model::mymodel)(rates, xs)
    @tullio loglikelihood[i, j] := ll(xs[i, k], rates[j, k]) * abs(model.weights[k]) grad=Dual
end
@btime model($rates, $xs); # 18.067 μs (107 allocations: 4.75 KiB)
@btime Flux.withgradient(model -> loss(model($rates, $xs)), model) # fails. The (shortened) error message is:
# ERROR: InvalidIRError: compiling kernel (...) resulted in invalid LLVM IR
# Reason: unsupported dynamic function invocation (call to +)
```

I also tried some other einsum macros (`@einsum`, `@tensor`).

Thank you