I need `mapslices`-like behavior that Zygote can differentiate, so I'm trying `SliceMap.slicemap`. I can't use `SliceMap.mapcols` because my function has parameters inside it (captured from the enclosing scope). `SliceMap.slicemap` returns a `JuliennedArrays.Align`. `sum` works on a `CuArray`, but on a `JuliennedArrays.Align` it throws a scalar indexing error. Is there a way around this?
Example (this throws a scalar indexing error):
```julia
using CUDA
import SliceMap, Zygote, Flux

CUDA.allowscalar(false)

function func(m, k, data)
    # Scale each column of `m` by the parameter vector `k`
    y = SliceMap.slicemap(m; dims=1) do col
        col .* k
    end
    return sum(y) # I think this line gives the scalar indexing error
end

function test_gpu()
    m = fill(1.0, 3, 10) |> Flux.gpu
    data = fill(2.0, 3, 10) |> Flux.gpu
    k = [0.7, 0.5, 0.5] |> Flux.gpu
    Zygote.gradient((m, k) -> func(m, k, data), m, k)
end

test_gpu()
```
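If the real per-column function is only the elementwise scaling shown above, one possible workaround is to skip slicing entirely: broadcasting the length-3 vector `k` against the 3×10 matrix `m` already scales every column by `k`, so no `JuliennedArrays.Align` is ever created. This is a minimal sketch under that assumption; the name `func_broadcast` is hypothetical, and it uses CPU arrays so it runs without a GPU. Since it is a single broadcast followed by `sum`, it should also work on `CuArray`s piped through `Flux.gpu`, but that part is an assumption.

```julia
using Zygote

# Hypothetical CPU stand-ins for the arrays in the question
m = fill(1.0, 3, 10)
k = [0.7, 0.5, 0.5]

# Assumes the per-column function is just `col .* k`.
# `m .* k` broadcasts k down each column, matching slicemap(col -> col .* k, m; dims=1),
# and Zygote differentiates broadcasting directly (no scalar indexing involved).
func_broadcast(m, k) = sum(m .* k)

dm, dk = Zygote.gradient(func_broadcast, m, k)
# dm has k repeated across all 10 columns; dk sums m over columns (10.0 per entry here)
```

For a more complex per-column function this shortcut may not apply, and the question of summing an `Align` on the GPU remains open.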