Custom loss functions in `Lux.jl`

Is calling `MSELoss()` inside the loss function a good idea, or is it better to move it outside `loss_function`?

Either is fine. If you move it to a global scope, remember to annotate it with `const`.
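To illustrate the `const` pattern in pure Julia (the loss object here is a hand-written stand-in for `Lux.MSELoss()`, and the names are illustrative):

```julia
using Statistics  # for mean

# Hoisting the loss object to global scope: annotating with `const`
# lets the compiler infer its type, so `loss_function` stays type-stable.
# A plain (non-const) global would be type-unstable and slow.
const mse = (ŷ, y) -> mean(abs2, ŷ .- y)  # stand-in for Lux.MSELoss()

loss_function(ŷ, y) = mse(ŷ, y)
```

Calling `MSELoss()` inside the function also works, since constructing the loss object is cheap; `const` only matters when the object lives at global scope.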

For an L2 regularization term on the model weights, it is more efficient to add weight decay to your optimizer (see `WeightDecay` in the Optimisers.jl API docs).
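A minimal sketch of the optimizer route, assuming Optimisers.jl is available (the learning rate and decay coefficient are illustrative):

```julia
using Optimisers

# WeightDecay(λ) adds λ .* p to each parameter's gradient, which is
# equivalent to an L2 penalty of (λ/2)‖p‖² in the loss, without ever
# materializing the penalty term in the loss function itself.
opt = OptimiserChain(WeightDecay(1f-4), Adam(1f-3))
# opt_state = Optimisers.setup(opt, ps)  # then train as usual
```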

```julia
loss_reg = zero(T)
for p in ps
    loss_reg += sum(abs2, Base.Flatten(p))
end
```

This won’t work in the general case. If you really want this approach (instead of the optimizer route), use `Functors.fleaves(ps)` and iterate over that. Note that this method is type-unstable.
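A sketch of the `fleaves` version, assuming Functors.jl is available (the parameter NamedTuple here is illustrative):

```julia
using Functors

# ps is the (possibly nested) NamedTuple of parameters, e.g.:
ps = (layer_1 = (weight = randn(Float32, 3, 2), bias = zeros(Float32, 3)),)

# fleaves flattens the nested structure into a Vector of leaf arrays,
# so the penalty covers every parameter array regardless of nesting depth.
loss_reg = sum(p -> sum(abs2, p), Functors.fleaves(ps))
```

Because the leaves can have heterogeneous types, the iteration is type-unstable, which is why the weight-decay optimizer rule is usually preferable.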
