SciML Optimization - algorithms and convergence speed

I’m currently working on a minimization problem where the loss function is of dimensions n by p, with both n and p ~100 (not necessarily equal, although in some cases they may be). I started with a naive implementation of the loss function (which made it nice to read) and gradually moved to an optimized version (following https://docs.julialang.org/en/v1/manual/performance-tips/ and https://book.sciml.ai/notes/02-Optimizing_Serial_Code/), leading to a 100x-1000x speedup in my loss function over the naive implementation.

I settled on SciML with NLopt.LN_NEWUOA() or Optim.NelderMead(), roughly tied for fastest in my attempts. I get decent convergence after 50k iterations in ~100 s (on hardware with roughly half the single-thread performance of, say, an i9-13900).
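For reference, the setup is essentially the following (with a trivial toy loss standing in for the real, much costlier one):

```julia
using Optimization, OptimizationNLopt, NLopt

# Toy stand-in for the real (much costlier) loss; u holds the ~100 parameters,
# p the fixed data.
loss(u, p) = sum(abs2, u .- p)

u0 = rand(100)     # in the real problem this is a decent initial guess
p  = rand(100)

prob = OptimizationProblem(OptimizationFunction(loss), u0, p)
sol  = solve(prob, NLopt.LN_NEWUOA(); maxiters = 50_000)
```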

Knowing that:

  • The loss function is somewhat complex to compute, with rather slow AD
  • But it is also rather well behaved and smooth with no local minima, at least no horrible ones
  • Its Jacobian would be close to upper triangular
  • I’m usually able to supply a pretty good initial guess

My question is: is there an algorithm that could be better suited, or are there other areas that could be explored in the hope of speeding up convergence?

What I’ve attempted:

  • Several other algos in Optim and NLopt, but not all of them
  • Growing the problem from (i,j) to (n,p), i.e. solving problems of incrementally increasing size (exploiting the near upper triangular Jacobian); a rough sketch is shown below. In practice the repeated problem setup seems to cost more than this approach saves.
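Roughly, the incremental idea looked like this (hypothetical truncated loss, sizes chosen arbitrarily):

```julia
using Optimization, OptimizationNLopt, NLopt

target = rand(100)
loss(u, p) = sum(abs2, u .- p[1:length(u)])   # hypothetical truncated loss

# Solve progressively larger problems, padding each solution to warm-start
# the next size.
function incremental_solve(target; sizes = (10, 25, 50, 100))
    u = zeros(first(sizes))
    for n in sizes
        u0   = vcat(u, zeros(n - length(u)))
        prob = OptimizationProblem(OptimizationFunction(loss), u0, target)
        u    = solve(prob, NLopt.LN_NEWUOA(); maxiters = 5_000).u
    end
    return u
end

incremental_solve(target)
```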

Any hint appreciated.

EDIT: Something I thought of but haven’t tried is to put it all in matrix form. The issue I have with this approach is that the resulting matrices would be of sizeable dimensions (say 3 digits by 5 digits) with a decent number of zeros, so it’s unclear it would speed things up.

What does that mean? For minimization, the loss function must ultimately be a scalar, and its derivative must therefore be a gradient vector, not a matrix.


You are correct, I take a sum(abs2.()) of the p-dimensional output as the final scalar loss function. However, the intermediate step is a p-dimensional vector holding calibration errors plus a penalty function.
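Schematically (with trivial stand-ins for the real model and penalty):

```julia
# Trivial stand-ins for the real model and penalty:
model(u)   = cumsum(u)                      # p-dimensional model output
penalty(u) = 1e-3 * sum(abs2, u)            # regularization / penalty term

function loss(u, data)
    errors = model(u) .- data               # p-dimensional calibration errors
    return sum(abs2, errors) + penalty(u)   # final scalar loss
end

loss(rand(100), rand(100))
```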

The fact that you ultimately have a scalar output is pretty crucial. It means you only need a gradient, not a whole Jacobian matrix (which you need never compute explicitly), and it means that you can (in theory) efficiently compute that gradient by backpropagation/reverse-mode/adjoint methods (with a cost similar to that of computing your loss function).

If you care about performance and you have \sim 100 parameters, I would try pretty hard to compute the gradient analytically, either by reverse-mode AD or manually by adjoint methods, or using a mixture of the two (i.e., AD + some manual rrules for functions that AD struggles with). An analytical gradient gives you access to much more efficient local-optimization algorithms than derivative-free algorithms like NEWUOA (I would recommend BOBYQA instead) or Nelder–Mead (which I would generally view as an obsolete method these days).
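For example, a custom reverse rule looks roughly like this (with a toy function standing in for whatever piece of your loss AD struggles with):

```julia
using ChainRulesCore, Zygote

# Toy stand-in for a piece of the loss that AD struggles with:
slowpart(x) = sum(abs2, x)

# Hand-written reverse rule: return the primal value plus a pullback mapping
# the output cotangent back to cotangents of the inputs.
function ChainRulesCore.rrule(::typeof(slowpart), x)
    y = slowpart(x)
    slowpart_pullback(ȳ) = (NoTangent(), ȳ .* 2 .* x)
    return y, slowpart_pullback
end

Zygote.gradient(slowpart, rand(5))[1]   # uses the custom rule; equals 2x
```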


Thank you for the pointers. BOBYQA gives a lower loss() for a given number of iterations, but seems a touch slower per iteration. That may be a function of my specific minimization problem.

Regarding gradients, I had experimented with:

  • ForwardDiff. Making my loss function generic enough to accept Real comes at a cost compared to previous iterations that were explicitly Float64 (or Float32). The 15x slowdown negates any improvement I could pick up from computing a gradient. I still have to check whether the loss has too much genericity or just the required amount, though.
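That experiment amounted to something like this (toy loss in place of the real one):

```julia
using ForwardDiff

loss(u) = sum(abs2, u .- 0.5)          # toy loss in place of the real one

u0 = rand(100)
g  = ForwardDiff.gradient(loss, u0)    # cost grows with the ~100 inputs
```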

Following your suggestion, I am in the process of trying:

  • Zygote. Right away I stumbled on forbidden array mutations, which were key to the speed of my non-AD Float64 function. It seems I’m again on the path of forgoing careful optimizations for AD compatibility, which may or may not yield an overall improvement. To be fair, my inability to write a mutation-free function that is as fast as the mutating one doesn’t mean it’s impossible, merely that it’s outside my immediate reach. A grossly simplified illustration is below.
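To show the kind of rewrite involved (grossly simplified relative to the real loss):

```julia
using Zygote

# Mutating version: fast with plain Float64, but Zygote rejects the in-place
# writes into the preallocated buffer.
function loss_mut(u, data)
    out = similar(u)
    for i in eachindex(u)
        out[i] = (u[i] - data[i])^2
    end
    return sum(out)
end

# Mutation-free version: express the same computation without writing into a
# buffer, which Zygote can differentiate.
loss_nomut(u, data) = sum(abs2, u .- data)

u, data = rand(100), rand(100)
# Zygote.gradient(v -> loss_mut(v, data), u)   # fails: Zygote does not support array mutation
Zygote.gradient(v -> loss_nomut(v, data), u)[1]
```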

In any case, I will keep trying. Thank you for the pointers.

If this gives a slowdown, perhaps you are doing it wrong? That is, if you make your functions generic enough to accept Real but then pass Float64 inputs, done correctly it should be exactly the same speed (and, in fact, compile to exactly the same binary code) as when you explicitly declared things as Float64.

(That being said, using ForwardDiff to compute 100 derivatives, i.e. passing dual numbers, will be a lot slower than calling it for Float64 arguments — for forward-mode AD, the cost is linearly proportional to the number of inputs you are differentiating with respect to. That’s why you want to use reverse-mode/adjoint/backpropagation differentiation, which in theory gives only \sim 2\times slowdown to get a scalar loss-function value and all derivatives.)


You are most likely correct in that I must be doing it wrong.

And point taken re fwd vs reverse AD. Thank you Sir.

If you can point us to your code or a typical example, we might be able to help further.

I was able to address the previous issue (“Making my loss function generic enough to accept Real comes at a cost though”) by using parametric types, and now the generic function has identical performance to the Float64 one.
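A stripped-down illustration of the change (the real loss is more involved):

```julia
# Before: buffer hard-coded to Float64, so dual numbers can't flow through.
function loss_f64(u::AbstractVector{Float64}, data)
    buf = Vector{Float64}(undef, length(u))
    for i in eachindex(u)
        buf[i] = (u[i] - data[i])^2
    end
    return sum(buf)
end

# After: the element type follows the input, so the Float64 case compiles to
# the same code as before, and ForwardDiff duals work too.
function loss_generic(u::AbstractVector{T}, data) where {T<:Real}
    buf = Vector{T}(undef, length(u))
    for i in eachindex(u)
        buf[i] = (u[i] - data[i])^2
    end
    return sum(buf)
end
```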

Using Optimization.AutoReverseDiff() I got results comparable to my previous best. With AD it converges to roughly the same loss in a similar amount of time; the only difference is that it takes about 25 iterations with AD vs 20k iterations without.
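Concretely, the setup looks like this (toy loss again; the gradient-based local method shown, Optim.LBFGS(), is just one possible choice):

```julia
using Optimization, OptimizationOptimJL, ReverseDiff

loss(u, p) = sum(abs2, u .- p)      # toy stand-in for the real calibration loss

u0, p = rand(100), rand(100)
f     = OptimizationFunction(loss, Optimization.AutoReverseDiff())
prob  = OptimizationProblem(f, u0, p)
sol   = solve(prob, Optim.LBFGS(); maxiters = 100)
```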


Hi @sob,

Do you have any constraints in your problem?
From my experience, it is very often worth using a solver that exploits the structure of the problem. Your loss function leads to nonlinear least squares problems.

If you have an unconstrained problem, it’s worth trying TRUNK from JSOSolvers.jl (see Introduction to JSOSolvers).
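A minimal sketch, assuming your calibration errors are available as a residual function (the residual here is made up):

```julia
using ADNLPModels, JSOSolvers

# Made-up residual standing in for the p calibration errors
F(u) = u .^ 2 .- range(0, 1; length = length(u))

u0    = fill(0.5, 100)
nls   = ADNLSModel(F, u0, 100)    # nonlinear least-squares model with AD derivatives
stats = trunk(nls)                # TRUNK: trust-region truncated-Newton solver
stats.solution, stats.objective
```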

My problem is unconstrained. I’ll take a look, thanks for the pointer.