I am trying to follow the Lux tutorial on fitting an ML model using Optimization.jl, but for a hybrid model: the Lux model predicts some of the parameters of a process-based model, the Lux model call is just one step inside a larger model, and the Lux parameters are only a subset of the overall parameter vector.
Early in training, while the parameters are still far from the solution, the gradient contains NaNs for some of the minibatches.
When I write the training loop for a pure Lux model myself (as e.g. in this tutorial), I can simply skip the update for the minibatches whose gradient contains any NaN:
if any(isnan.(grads))
    # skip the parameter update for this minibatch
    println("Skipped NaN : Batch $i")
else
    # apply the update; update! returns the new optimiser state and parameters
    opt_st_new, ps = Optimisers.update!(opt_st_new, ps, grads)
end
How do I tell Optimization.jl to skip the parameter update for these minibatches when using the solve method instead of a hand-written training loop?
The following MWE demonstrates the problem without any Lux model: after the minibatch containing NaNs is encountered, all subsequent updates yield a loss of NaN and there is no convergence to the optimum.
using Optimization
using OptimizationOptimisers
using MLUtils
import Zygote

# synthetic data: entries 42:43 put NaNs into one of the minibatches
d = fill(1.0, 100)
d[42:43] .= NaN
dl = DataLoader(d, batchsize=10)

# callback that prints the loss every `moditer`-th iteration
callback_loss = (moditer) -> let iter = 1, moditer = moditer
    function (state, l)
        if iter % moditer == 1
            println("$iter, $l")
        end
        iter = iter + 1
        return false
    end
end

# simple quadratic objective; the DataLoader is passed as the data argument
optf = Optimization.OptimizationFunction((x, d) -> sum(d .* abs2.(x)),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optf, [2.0], dl)
alg = AdaMax(0.9)
#alg = Adam(0.9)
res = Optimization.solve(optprob, alg, epochs=6,
    callback = callback_loss(2),
)
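For comparison, a hand-written Optimisers.jl loop on the same DataLoader dl behaves the way I would like: the NaN minibatch is simply skipped and the parameters stay finite. This is just a sketch of the workaround I would like to reproduce via solve (the function name manual_skip_nan is made up):

using Optimisers
import Zygote

function manual_skip_nan(x0, dl; epochs = 6)
    x = copy(x0)
    opt_st = Optimisers.setup(Optimisers.AdaMax(0.9), x)
    for epoch in 1:epochs, (i, batch) in enumerate(dl)
        # same quadratic objective as in the MWE above
        l, gs = Zygote.withgradient(p -> sum(batch .* abs2.(p)), x)
        g = gs[1]
        if any(isnan, g)
            println("Skipped NaN : Batch $i (epoch $epoch)")
        else
            opt_st, x = Optimisers.update!(opt_st, x, g)
        end
    end
    return x
end

manual_skip_nan([2.0], dl)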