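Never mind, sorry. A loss should be defined like this:
function loss_fn(model, ps, st, d)
    x, y = d                      # d is an (input, target) tuple
    ŷ, stn = model(x, ps, st)     # forward pass returns the prediction and the updated state
    # Return the loss, the updated state, and an empty NamedTuple of stats
    return Lux.MSELoss()(ŷ, y), stn, (;)
end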
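For context, here is a minimal sketch of how a loss with this signature can be passed to Lux's training helpers. It assumes a recent Lux release (1.x) with Zygote and Optimisers installed; the `Dense(2 => 1)` model, the random data, and the learning rate are made up for illustration and are not from the original question.

```julia
using Lux, Optimisers, Random, Zygote

# Loss in the (model, ps, st, data) form: returns (loss, updated state, stats).
function loss_fn(model, ps, st, d)
    x, y = d
    ŷ, stn = model(x, ps, st)
    return Lux.MSELoss()(ŷ, y), stn, (;)
end

rng = Random.default_rng()
model = Dense(2 => 1)                  # toy model, for illustration only
ps, st = Lux.setup(rng, model)

x = rand(rng, Float32, 2, 16)          # made-up inputs (features × batch)
y = rand(rng, Float32, 1, 16)          # made-up targets

# The train state bundles the model, parameters, states, and optimiser.
train_state = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(0.01f0))

# One gradient step; returns (gradients, loss, stats, updated train state).
_, loss, _, train_state = Lux.Training.single_train_step!(
    AutoZygote(), loss_fn, (x, y), train_state)
```

The third return value `(;)` is just an empty stats NamedTuple; you can put extra metrics there if you want them back from the training step.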