I need the equivalent of `mapslices`, but with gradient tracking for Zygote, so I'm trying to use `SliceMap.slicemap`. I can't use `SliceMap.mapcols` because the function I map over the columns has parameters inside it. `SliceMap.slicemap` returns a `JuliennedArrays.Align`. `sum` works on a `CuArray`, but on a `JuliennedArrays.Align` it throws a scalar indexing error. Is there a way around this?
Example (this throws the scalar indexing error):
```julia
using CUDA
import SliceMap, Zygote, Flux
CUDA.allowscalar(false)

function func(m, k, data)
    y = SliceMap.slicemap(m; dims=1) do col
        col .* k
    end
    return sum(y) # I think this line gives the scalar indexing error
end

function test_gpu()
    m = fill(1.0, 3, 10) |> Flux.gpu
    data = fill(2.0, 3, 10) |> Flux.gpu
    k = [0.7, 0.5, 0.5] |> Flux.gpu
    Zygote.gradient((m, k) -> func(m, k, data), m, k)
end
```
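A minimal sketch of two possible workarounds, assuming the real per-column function is as simple as the one in the example. Since `col .* k` only scales each column elementwise by `k`, the whole `slicemap` call should be equivalent to the broadcast `m .* k`, which stays a dense array (a `CuArray` on the GPU) that `sum` handles without slicing. For a general per-column function, one option is to materialize the mapped columns into a plain matrix with `reduce(hcat, ...)` instead of summing a lazy `Align`. The names `scale_col`, `func_broadcast`, and `func_materialized` are hypothetical, and it is an assumption that your Zygote/CUDA versions have working rules for `eachcol` and `reduce(hcat, ...)` on `CuArray`s.

```julia
# Sketch: two ways to avoid summing a lazy JuliennedArrays.Align.
# Uses only Base so it runs on the CPU; hypothetical helper names throughout.

# The per-column function from the question: scale one column by k.
scale_col(col, k) = col .* k

# Approach 1: the per-column operation is an elementwise product with a
# column vector, so broadcasting over the whole matrix gives the same
# result with no slicing (k is repeated along dimension 2).
func_broadcast(m, k) = sum(m .* k)

# Approach 2 (more general mapslices replacement): map over the columns,
# then concatenate the results back into a dense matrix before reducing.
function func_materialized(m, k)
    cols = map(col -> scale_col(col, k), eachcol(m))  # one vector per column
    y = reduce(hcat, cols)                             # dense 3×10 matrix
    return sum(y)
end

m = fill(1.0, 3, 10)
k = [0.7, 0.5, 0.5]

# Reference value from Base.mapslices (the operation being replaced).
ref = sum(mapslices(col -> scale_col(col, k), m; dims=1))
```

In the question's setup, either function would replace `func` inside `Zygote.gradient`, with `m` and `k` moved to the GPU via `Flux.gpu`; that GPU and Zygote combination is assumed here rather than exercised by this CPU-only sketch.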