I would like to train a sparse neural network using Flux. I can train it on the CPU, but I would like to train it on the GPU, taking advantage of CuSparseMatrixCSC. However, when I try to run the code below I get ERROR: LoadError: This object is not a GPU array.
Is it possible to train sparse networks on the GPU? If so, what is causing the error?
Thanks!
Code:
using CUDA, Flux, SparseArrays, MLDatasets
x_train, y_train = MLDatasets.MNIST.traindata(Float32)
x_train = Flux.flatten(x_train) |> gpu
y_train = Flux.onehotbatch(y_train, 0:9) |> gpu
data = Flux.Data.DataLoader((x_train, y_train), batchsize=256, shuffle=true)
model = Dense(sprand(10, 784, 1.0), zeros(Float32, 10)) |> gpu
opt = ADAM(3e-4)
loss(x, y) = Flux.Losses.logitcrossentropy(model(x), y)
parameters = Flux.params(model)
for (x, y) ∈ data
    gradients = gradient(() -> loss(x, y), parameters)
    Flux.Optimise.update!(opt, parameters, gradients)
end
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] backend(#unused#::Type)
@ GPUArrays ~/.julia/packages/GPUArrays/Zecv7/src/device/execution.jl:15
[3] backend(x::Base.ReshapedArray{Float32, 1, CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}})
@ GPUArrays ~/.julia/packages/GPUArrays/Zecv7/src/device/execution.jl:16
[4] _copyto!
@ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:73 [inlined]
[5] materialize!
@ ~/.julia/packages/GPUArrays/Zecv7/src/host/broadcast.jl:51 [inlined]
[6] materialize!
@ ./broadcast.jl:868 [inlined]
[7] materialize!
@ ./broadcast.jl:864 [inlined]
[8] restructure(x::CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}, y::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ ArrayInterfaceCore ~/.julia/packages/ArrayInterfaceCore/wwYvJ/src/ArrayInterfaceCore.jl:346
[9] update!(opt::ADAM, x::CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32}, x̄::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Flux.Optimise ~/.julia/packages/Flux/js6mP/src/optimise/train.jl:16
[10] update!(opt::ADAM, xs::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, gs::Zygote.Grads)
@ Flux.Optimise ~/.julia/packages/Flux/js6mP/src/optimise/train.jl:24
[11] top-level scope
@ .../test_sparse.jl:16
in expression starting at .../test_sparse.jl:14
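For reference, the in-place broadcast that restructure performs in the failing frame ([8]) works fine for a CPU SparseMatrixCSC, which seems consistent with training succeeding on the CPU. A minimal sketch of the analogous update (the 0.01f0 step size and 0.1 density are arbitrary placeholders, not values from my training code):

```julia
using SparseArrays

# CPU analogue of the step that fails on the GPU: broadcasting a
# dense gradient into a sparse weight matrix in place.
W = sprand(Float32, 10, 784, 0.1)  # sparse weights (0.1 density is arbitrary)
g = rand(Float32, 10, 784)         # dense gradient, as Zygote produces
W .-= 0.01f0 .* g                  # in-place broadcast; succeeds on the CPU
```

On the GPU, frame [3] of the stacktrace shows the same pattern reaching GPUArrays' backend check with a Base.ReshapedArray wrapping the CuSparseMatrixCSC, which is where the "not a GPU array" error is raised.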