Efficient GPU dense-sparse matmul differentiable with Zygote

It turns out the issue is related to an optimisation of the gradients of `sum`. If I change `sum((A * x')')` to `sum((A * x')' * CuMatrix(rand(n, m)))` it works fine.

I’m unsure whether that optimisation is helpful on GPU or whether it should be considered a bug, but it’s narrow enough that it only really affected my minimal working example (the motivating problem I had featured a different `IndexError` that I’ve since tracked down).
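For anyone landing here later, a minimal sketch of the two variants side by side. The dimensions, the dense stand-ins for the sparse operand, and the matrix `B` are all illustrative assumptions (the original used `CuMatrix(rand(n, m))`; the shape is adapted so the product is well-defined here), and I haven't re-run this exact snippet:

```julia
using CUDA, Zygote

n, m = 8, 4                       # illustrative dimensions
A = CUDA.rand(Float32, m, n)      # dense stand-in for the sparse operand
x = CUDA.rand(Float32, m, n)      # chosen so that A * x' is m×m

# This form hit the error: Zygote has a specialised gradient for `sum`
# that appears not to play well with this GPU code path.
loss_bad(A, x) = sum((A * x')')

# Workaround from above: the extra matmul means the `sum` gradient
# optimisation no longer applies.
B = CUDA.rand(Float32, m, m)
loss_ok(A, x) = sum((A * x')' * B)

gradient(loss_bad, A, x)  # triggered the issue
gradient(loss_ok, A, x)   # works fine
```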
