Should NonlinearLeastSquaresProblem be used for deep learning?

From the JuliaCon talk *Fast and Robust Least Squares / Curve Fitting* by @ChrisRackauckas, I gathered that if you have an L2 loss function then you should be using NonlinearLeastSquaresProblem from the NonlinearSolve.jl package.

Indeed, it should be faster and more numerically stable than optimization methods that do not use the full Jacobian information. Is my understanding correct?

Hence my question: can we use NonlinearLeastSquaresProblem to train neural networks on regression tasks where the loss is an MSE? In other words, can NonlinearLeastSquaresProblem replace OptimizationProblem from Optimization.jl for training neural networks with an MSE loss?
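To make the difference concrete, here is a minimal sketch with made-up toy data (the model and values are purely illustrative): where an OptimizationProblem would minimize the scalar MSE, a NonlinearLeastSquaresProblem instead receives the residual vector, so the solver can exploit the least-squares structure.

```julia
using NonlinearSolve

# Toy data (hypothetical): y = 2 * exp(1.5 * x)
xs = 0.0:0.1:1.0
ys = 2.0 .* exp.(1.5 .* xs)

# Return the residual VECTOR, not the scalar MSE
residual(u, p) = u[1] .* exp.(u[2] .* xs) .- ys

prob = NonlinearLeastSquaresProblem(residual, [1.0, 1.0])
sol = solve(prob, LevenbergMarquardt())  # should recover u ≈ [2.0, 1.5]
```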

That is to say: drop Adam, SGD, and LBFGS, and just use Levenberg-Marquardt and Gauss-Newton?

In my use case I have a neural ODE model where the loss is the MSE between the predicted and observed time series.

On this example from the documentation it is indeed about 6× faster and reaches better objective values.

On my real cases with neural ODEs I am not yet seeing such improvements, but I am still testing and trying to work out the best way to use it (Levenberg-Marquardt, Gauss-Newton, TrustRegion?). I have a few coupled ODEs (1 to 5) with small neural networks (~400 parameters).
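For reference, here is roughly how the wiring looks for an ODE fitting case. This is only a sketch under stated assumptions: a stand-in linear right-hand side plays the role of the neural network, and the data is a random placeholder, so it shows the structure rather than a meaningful fit.

```julia
using OrdinaryDiffEq, NonlinearSolve

# Placeholder observations of a 2-state system (use real data here)
tsave = 0.0:0.5:5.0
data = rand(2, length(tsave))
u0 = data[:, 1]

# Stand-in for the neural network: a single linear layer built from p
node_rhs(u, p, t) = reshape(p, 2, 2) * u

# Residuals = predicted time series minus observed, flattened to a vector
function residual(p, _)
    odeprob = ODEProblem(node_rhs, u0, (tsave[begin], tsave[end]), p)
    pred = solve(odeprob, Tsit5(); saveat = tsave)
    return vec(Array(pred) .- data)
end

nlprob = NonlinearLeastSquaresProblem(residual, 0.01 .* randn(4))
sol = solve(nlprob, TrustRegion())
```

In the real use case the right-hand side would be the Lux/Flux model and `p` its ~400 parameters (e.g. as a ComponentArray), but the residual function keeps the same shape: solve the ODE, subtract the observations, flatten.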

I have some questions:

  • Can one use a callback function to compute metrics and save parameters during the optimization? Is that possible with NonlinearSolve? It is not clear from the documentation.

  • Do you have recommendations for neural ODEs specifically? Is it better to use Levenberg-Marquardt, Gauss-Newton, or TrustRegion? And which autodiff method? I am currently using the same settings as the tutorial.

Yes, it is definitely underused for this kind of thing.

Not at this time, but we should add a form of callbacks.

It can depend heavily on many things. But TrustRegion is generally pretty good, and the right autodiff depends on the shape of the Jacobian: if it's wide you want reverse mode, if it's tall you want forward mode. So it comes down to the number of residuals (data points) versus the number of parameters: data points >> parameters means a tall Jacobian, so forward mode; otherwise reverse mode.
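Assuming the ADTypes-style backend selection that NonlinearSolve.jl uses (check the docs for your version), the backend can be passed to the solver. A sketch with a deliberately tall Jacobian (100 residuals, 2 parameters), where forward mode is the natural choice:

```julia
using NonlinearSolve

# Toy linear data: y = 2x + 3
xs = collect(1.0:100.0)
ys = 2.0 .* xs .+ 3.0

# 100 residuals, 2 parameters => tall Jacobian => forward mode
residual(u, p) = u[1] .* xs .+ u[2] .- ys

prob = NonlinearLeastSquaresProblem(residual, [0.0, 0.0])
sol = solve(prob, TrustRegion(; autodiff = AutoForwardDiff()))
# For the opposite shape (parameters >> residuals) you would reach for a
# reverse-mode backend such as AutoZygote() instead.
```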