What's wrong with this Flux model definition?

The error seems to be that w is not wrapped in params. So this works:

using Flux
using CuArrays
CuArrays.allowscalar(false)

x = gpu(rand(Float32, 1_000_000, 2))
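# Targets: y = 2*x1 + 2*x2 plus noise, all kept as Float32 to match x
y = x*gpu(Float32[2, 2]) + gpu(rand(Float32, 1_000_000))

# Trainable weights, using the same element type as x
w = gpu(rand(Float32, 2, 1))
# mse already returns a scalar, so it does not need to be moved to the GPU
loss(x, y) = Flux.mse(x*w, y)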

Flux.train!(loss, params(w), ((x,y),), ADAM())