The error seems to come from not wrapping `w` in `params`. So this works:
using Flux
using CuArrays
# Turn off scalar indexing of GPU arrays so that an accidental element-by-element
# fallback is reported as an error rather than running slowly.
CuArrays.allowscalar(false)

# Synthetic linear-regression data: 1_000_000 samples with 2 features each.
x = gpu(rand(Float32, 1_000_000, 2))

# Targets are `x` times the coefficients (2, 2) plus uniform noise. The coefficient
# vector is written as Float32 so the product keeps the element type of `x`
# instead of mixing Float32 with Int64.
y = x * gpu(Float32[2, 2]) + gpu(rand(Float32, 1_000_000))

# Trainable 2x1 weight matrix. Float32 matches `x` and `y`, so the matmul in
# `loss` is not a mixed Float32/Float64 operation.
w = gpu(rand(Float32, 2, 1))

# Mean squared error of the linear model `x * w` against the targets `y`.
# The result is a scalar, so it is not piped through `gpu`.
loss(x, y) = Flux.mse(x * w, y)

# Train on a dataset holding the single batch `(x, y)`. `w` has to be wrapped in
# `params` so that `train!` receives it as a parameter collection.
Flux.train!(loss, params(w), ((x, y),), ADAM())