A hacky guide to using automatic differentiation in nested optimization problems

Unconstrained Optimization

Automatic differentiation is amazing. But models in my field, economics, often include nested optimization problems. A simple example might look like this, where x is a vector or scalar.

y = \min_x f(x) \text{, where } f(x) \equiv \min_z h(x, z)

Unfortunately, when part of your objective function f is computed as an optimum, naively running forward-mode differentiation through the inner optimizer doesn't work! This is where the envelope theorem comes in. By the envelope theorem,

\frac{d}{dx}(\min_z h(x,z)) = \frac{\partial}{\partial x}h(x,\text{arg}\min_zh(x,z)).

so that

\frac{d}{dx}f(x) = \frac{\partial}{\partial x}h(x,z^*(x))

where z^*(x) = \text{arg}\min_z h(x,z).
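A quick numerical sanity check of the envelope theorem may help here. This is a minimal sketch in plain Julia (no packages), using a hypothetical toy h whose inner minimizer is known in closed form and a simple golden-section search as the inner solver. It compares a finite-difference derivative of f with \frac{\partial}{\partial x}h(x, z^*(x)).

```julia
# Hypothetical toy problem (not from the post), chosen so everything has a closed form:
#   h(x, z) = z^2 - 2xz + x^4,  z*(x) = x,  f(x) = x^4 - x^2,  f'(x) = 4x^3 - 2x
h(x, z) = z^2 - 2x * z + x^4
dh_dx(x, z) = -2z + 4x^3            # partial derivative in x, holding z fixed

# Golden-section search: a simple derivative-free 1-D minimizer on [a, b]
function golden_min(φ, a, b; tol=1e-10)
    r = (sqrt(5) - 1) / 2
    while b - a > tol
        c = b - r * (b - a)
        d = a + r * (b - a)
        if φ(c) < φ(d)
            b = d                    # minimum lies in [a, d]
        else
            a = c                    # minimum lies in [c, b]
        end
    end
    return (a + b) / 2
end

zstar(x) = golden_min(z -> h(x, z), -10.0, 10.0)   # inner argmin, computed numerically
f(x) = h(x, zstar(x))

x = 1.3
fd = (f(x + 1e-5) - f(x - 1e-5)) / 2e-5   # finite-difference derivative of f
env = dh_dx(x, zstar(x))                  # envelope theorem: ∂h/∂x at z*(x)
exact = 4x^3 - 2x                         # derivative of the closed form x^4 - x^2
println((fd, env, exact))
```

The finite-difference derivative, the envelope-theorem value, and the closed-form derivative should agree up to the inner solver's tolerance, even though the envelope value never differentiates z^*(x).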

So how do we connect this back to automatic differentiation? I'll be using Optim.jl and ForwardDiff.jl for this. With each call of the objective function f(x), we first compute z^*(x) “offline” (i.e., outside the ForwardDiff machinery), then call h(x, z^*(x)) using that z^*(x), and ForwardDiff ends up computing the correct derivative. Here's an example of what I mean:

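using Optim
using ForwardDiff

# Inner objective: takes the stacked vector [x; z]
function h(xz)
    return xz[1]^2 + xz[2]^2 + (xz[1]*xz[2])^2
end

# Outer objective: f(x) = min_z h([x; z])
function f(x)
    # Compute the minimizer (outside of the autodiff framework):
    # ForwardDiff.value.(x) strips the Dual parts, so the inner solve sees Float64s only
    h_offline(z) = h(vcat(ForwardDiff.value.(x), z))
    res = optimize(h_offline, [1.0], LBFGS(); autodiff=:forward)
    z_star = Optim.minimizer(res)
    # Re-evaluate h at the (possibly Dual) x, with z_star treated as a constant
    return h(vcat(x, z_star))
end

optimize(f, [1.0], LBFGS(); autodiff=:forward)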


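What's going on here? In the objective function f, we first define an objective function h_offline, which strips all the derivative information from x using ForwardDiff.value.(x) and then calls h. We then use optimize to find the minimizer z_star = z^*(x). Then we re-enter autodiff-land and call the objective function h using x and the minimizer z_star. Because z_star was computed from plain Float64 values, ForwardDiff treats it as a constant, so the derivative it propagates is exactly \frac{\partial}{\partial x}h(x,z^*(x)), which by the envelope theorem equals \frac{d}{dx}f(x).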

Note: ForwardDiff.jl works by replacing the input vector x::Vector{Float64} with an input vector x::Vector{Dual}, where Dual is a data type that keeps track of its value, just like a number, as well as information about partial derivatives. The call ForwardDiff.value.(x) recovers the original vector x::Vector{Float64} of plain Floats.
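To make the Dual description concrete, here is a sketch that replays the whole trick in plain Julia with a toy dual number. The Dual type, value, and inner_argmin below are hypothetical illustrations, not ForwardDiff's actual types or API, and the toy h is chosen so that its inner minimizer is z^*(x) = x. Because z_star is computed from value(x), it enters h as a constant, and the derivative part of the result is \frac{\partial}{\partial x}h(x, z^*(x)).

```julia
# A toy dual number: a hypothetical stand-in for ForwardDiff.Dual (not its API),
# storing a value and a single partial derivative.
struct Dual
    val::Float64
    der::Float64
end

value(d::Dual) = d.val   # plays the role of ForwardDiff.value
value(x::Real) = x       # plain numbers are their own value

# Sum and product rules, including mixed Dual/Real arguments
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)
Base.:+(a::Real, b::Dual) = b + a
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(a::Dual, b::Real) = Dual(a.val * b, a.der * b)
Base.:*(a::Real, b::Dual) = b * a

# Toy objective h(x, z) = z^2 - 2xz + x^4 (written with + and * only),
# whose inner minimizer is z*(x) = x, so f(x) = x^4 - x^2
h(x, z) = z * z + (-2.0) * x * z + x * x * x * x

# "Offline" inner solve on plain Float64s: Newton's method on ∂h/∂z = 2z - 2x
function inner_argmin(x0::Float64; z=0.0, iters=20)
    for _ in 1:iters
        z -= (2z - 2x0) / 2.0       # gradient divided by ∂²h/∂z² = 2
    end
    return z
end

function f(x)
    z_star = inner_argmin(value(x))  # derivative information stripped here
    return h(x, z_star)              # re-enter autodiff-land with z_star constant
end

y = f(Dual(1.3, 1.0))                # seed dx/dx = 1
println((y.val, y.der))              # der should match f'(x) = 4x^3 - 2x
```

ForwardDiff's real Dual also carries a tag and can hold several partials at once; this sketch leaves both out to keep the mechanism visible.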

I also want to mention that this works with any combination of nested optimization loops. Several nested levels? Fine. Multiple z values separately computed as optima in a single f function? Fine. This procedure simply allows the correct derivative information to be passed on from a step computed using an optimization.

Constrained Optimization

So far, everything I've written has covered unconstrained optimization. Luckily, the “full” envelope theorem handles constrained problems as well. Suppose now that

\begin{aligned} f(x) \equiv \min_z\quad & h(x,z)\\ \text{s.t.}\quad & g(x,z) = 0 \end{aligned}

and that you can solve for the minimizer of the above problem, z^*(x), using your favorite constrained optimization library. Getting the derivative of f(x) now takes one more step. First, we rewrite the problem using the Lagrangian function

\mathcal{L}(z,\lambda,x) = h(x,z) - \lambda\, g(x,z).

Now the envelope theorem gives us that

\begin{aligned} \frac{d}{dx} \mathcal{L}(z^*(x),\lambda^*(x),x) &= \frac{\partial}{\partial x}\mathcal{L}(z^*(x), \lambda^*(x), x)\\ \text{so that}\quad\frac{d}{dx}f(x) &= \frac{\partial}{\partial x}h(x, z^*(x)) - \lambda^*(x)\frac{\partial}{\partial x}g(x, z^*(x)). \end{aligned}

where z^*(x) and \lambda^*(x) form a stationary point of \mathcal{L}(z,\lambda,x) in both z and \lambda (a saddle point, not a minimum over \lambda). If you have z^*(x) and z is a scalar, then you can recover \lambda^*(x) from the first-order condition in z:

\begin{aligned} \frac{\partial}{\partial z}\mathcal{L}(z^*(x), \lambda^*(x), x) &= 0\\ \frac{\partial}{\partial z}h(x, z^*(x)) &= \lambda^*(x)\frac{\partial}{\partial z}g(x, z^*(x))\\ \lambda^*(x) &= \frac{\frac{\partial}{\partial z}h(x, z^*(x))}{\frac{\partial}{\partial z}g(x, z^*(x))} \end{aligned}

Now, the problem can be solved just as we did before. Compute z^*(x) and \lambda^*(x) “offline,” then call \mathcal{L}(z^*(x), \lambda^*(x), x) in place of f(x). I’ll omit the implementation of that here, but if there’s interest, I’m happy to write it up.
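For anyone who wants to try the constrained version, here is a minimal sketch under stated assumptions: a hypothetical problem with scalar x, two-dimensional z, and one equality constraint, hand-coded partial derivatives, and an inner solve that eliminates one variable via the constraint. Since z is a vector here, \lambda^*(x) is recovered by least squares (an assumption of this sketch) rather than by the scalar ratio above. It checks \frac{\partial}{\partial x}h - \lambda^*\frac{\partial}{\partial x}g against a finite-difference derivative of f.

```julia
# Hypothetical example (not from the post):
#   h(x, z) = z1^2 + 2 z2^2 + x z1   subject to   g(x, z) = z1 + z2 - x = 0
# Closed form for checking: z* = (x/2, x/2), λ* = 2x, f(x) = 1.25 x^2, f'(x) = 2.5x
h(x, z) = z[1]^2 + 2 * z[2]^2 + x * z[1]
g(x, z) = z[1] + z[2] - x

# Hand-coded partial derivatives
dh_dx(x, z) = z[1]
dh_dz(x, z) = [2 * z[1] + x, 4 * z[2]]
dg_dx(x, z) = -1.0
dg_dz(x, z) = [1.0, 1.0]

# "Offline" constrained solve: eliminate z2 = x - z1 with the constraint, then run
# Newton's method on the reduced objective φ(z1) = h(x, [z1, x - z1]),
# whose derivative is 6 z1 - 3x and whose second derivative is 6
function zstar(x; z1=0.0, iters=20)
    for _ in 1:iters
        z1 -= (6 * z1 - 3 * x) / 6
    end
    return [z1, x - z1]
end

# Recover λ* from ∂h/∂z = λ ∂g/∂z. For vector z this is solved in the least-squares
# sense (an assumption; it reduces to the ratio formula when z is scalar)
lambda_star(x, z) = sum(dh_dz(x, z) .* dg_dz(x, z)) / sum(dg_dz(x, z) .^ 2)

f(x) = h(x, zstar(x))

x = 0.8
z = zstar(x)
λ = lambda_star(x, z)
env = dh_dx(x, z) - λ * dg_dx(x, z)        # envelope theorem with the constraint
fd = (f(x + 1e-5) - f(x - 1e-5)) / 2e-5     # finite-difference check
println((z, λ, env, fd))
```

In a real model the Newton step would be replaced by whatever constrained solver you already use, run on ForwardDiff.value.(x) as in the unconstrained example.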

Where to go from here

As I cop to in the title, the implementation I propose is a bit hacky. Ideally, Optim would support this natively, with something like the following sketch:

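using Optim, ForwardDiff

# Sketch of native support, written as a standalone helper: minimize h(x, z) over z
# with x held fixed. The inner solve sees only Float64 values; the minimum is then
# re-evaluated at the (possibly Dual) x so ForwardDiff picks up ∂h/∂x at z*.
function nested_optimize(h, z0::Vector{Float64}, x::AbstractVector, args...; kwargs...)
    h_offline(z) = h(ForwardDiff.value.(x), z)
    res = optimize(h_offline, z0, args...; kwargs...)  # find the minimizer "offline"
    z_star = Optim.minimizer(res)
    return h(x, z_star), z_star  # minimum carrying derivative information
end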


I would love to hear your thoughts and suggestions on this. If anybody would be interested in helping adapt this for Optim.jl or as a separate package, please let me know! I don’t have much experience contributing to open-source projects but I’d like to start.


Note that we’re refactoring sciml_train from DiffEqFlux and making this wrapper differentiable as a package GalacticOptim.jl which wraps over Optim.jl, Flux.jl, Evolutionary.jl, NLopt.jl, QuadDIRECT.jl, and BlackBoxOptim.jl. The issue on this is here:

After refactoring sciml_train to this the plan is to add pretty much exactly this adjoint rule so that ForwardDiff2 and Zygote automatically do it. It does require that you assume local convexity at the solution of the optimization problem though.


This project looks very interesting! Out of curiosity, what makes this an adjoint rule?

From a quick read, you wrote down the forward rule, which would be used in forward-mode autodiff. The adjoint method over optimization gives the reverse-mode rule. We plan to define both rules with ChainRules, and then autodiff should do this in the background.
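For the optimal value alone, the two rules can be written side by side. This is a sketch in notation assumed here (\dot{x} a tangent, \bar{f} an adjoint), not a statement of what the planned ChainRules definitions will look like:

\dot{f} = \frac{\partial}{\partial x}h(x, z^*(x))\,\dot{x} \qquad \text{(forward rule)}

\bar{x} = \bar{f}\,\frac{\partial}{\partial x}h(x, z^*(x)) \qquad \text{(reverse rule)}

If the minimizer z^*(x) itself is an output, applying the implicit function theorem to the first-order condition \frac{\partial}{\partial z}h(x, z^*(x)) = 0 gives

\frac{dz^*}{dx} = -\left[\frac{\partial^2 h}{\partial z^2}(x, z^*(x))\right]^{-1}\frac{\partial^2 h}{\partial z\,\partial x}(x, z^*(x)),

which requires the Hessian \frac{\partial^2 h}{\partial z^2} to be invertible at the solution; local strong convexity guarantees this.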


First of all, if you are solving \min_x \left[ \min_z h(x,z) \right], then you should simply combine it into

\min_{x,z} h(x,z)

i.e., minimize over x and z simultaneously. Assuming h(x,z) is smooth, you can then apply any multivariate minimization approach that you want. Adding a constraint g(x,z) = 0, or any other differentiable constraint, is also easy.
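A minimal sketch of the joint formulation, assuming a hypothetical toy h(x,z) = z^2 - 2xz + x^4, for which the nested problem has f(x) = x^4 - x^2, minimized at x = 1/\sqrt{2} with value -1/4. Plain fixed-step gradient descent on (x, z) together stands in for L-BFGS here; it is only a check that the joint minimum matches the nested one, not a recommended solver.

```julia
# Hypothetical toy problem: h(x, z) = z^2 - 2xz + x^4
h(x, z) = z^2 - 2x * z + x^4
grad_h(x, z) = (-2z + 4x^3, 2z - 2x)   # (∂h/∂x, ∂h/∂z)

# Fixed-step gradient descent over x and z simultaneously (no inner loop)
function joint_gd(x, z; step=0.1, iters=2000)
    for _ in 1:iters
        gx, gz = grad_h(x, z)
        x -= step * gx
        z -= step * gz
    end
    return x, z
end

x, z = joint_gd(1.0, 0.0)
println((x, z, h(x, z)))   # expect x ≈ z ≈ 1/√2 and h ≈ -1/4
```

Each iteration costs one gradient evaluation of h, whereas the nested version would run a full inner optimization for every outer step.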

In contrast, directly nesting calls to optimize as in your suggestion is probably suboptimal. You will waste a lot of time optimizing the “inner” problem to high precision to get f(x), because you immediately throw away this effort by updating x.

A trickier case to handle is the “minimax” problem

\min_{x\in X} \left[ \max_{z\in Z} h(x,z) \right].

If Z is a finite set (or can be discretized/approximated by a finite set), one approach is an “epigraph” formulation, which transforms it to an exactly equivalent problem with differentiable constraints by adding a “dummy” variable t:

\begin{aligned} &\min_{x\in X,\ t \in \mathbb{R}}\ t\\ &\text{subject to } t \ge h(x,z) \text{ for all } z \in Z \end{aligned}

at which point you can apply various nonlinear programming algorithms. (One of my favorites is the CCSA algorithm, which can be found in NLopt; I keep meaning to put together a pure-Julia version of this algorithm, since it is so simple and flexible.) If Z is a continuous set, however, matters are more complicated — there are a variety of “continuous minimax” algorithms proposed in the literature for this case, but libraries are harder to find.


You’re absolutely right that nested optimization is not the best solution for the example I gave. However, there are problems for which nesting optimizations is probably best. I intended this more as a general method for dealing with nested uses of optimize when you want to use forward autodiff. In my research, for example (and the impetus for trying to get this to work), I’m trying to solve

\begin{aligned} &\min_x \left(\sum_{i=1}^n f_i(x)^\sigma\right)^\frac{1}{\sigma}\\ \text{where}\quad f_i(x) &= \min_zh(x,z,d_i) \end{aligned}

where the x, d, and z are all vectors, and n\approx 100. In this case, solving for all the z in a single optimize is infeasible–I would have an input vector several hundred entries long.
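For reference, combining the envelope theorem with the chain rule gives the outer gradient directly. Writing F(x) for the outer objective (a name introduced here) and assuming every f_i(x) > 0 so that the powers are defined:

F(x) = \left(\sum_{i=1}^n f_i(x)^\sigma\right)^{\frac{1}{\sigma}}

\frac{d}{dx}F(x) = \left(\sum_{j=1}^n f_j(x)^\sigma\right)^{\frac{1}{\sigma}-1}\sum_{i=1}^n f_i(x)^{\sigma-1}\,\frac{\partial}{\partial x}h(x, z_i^*(x), d_i), \quad \text{where } z_i^*(x) = \text{arg}\min_z h(x,z,d_i).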

This sort of thing comes up a lot in economics (esp. macro/IO/finance), where you have a low-dimensional set of parameters \theta and a large set of decision-makers i who make decisions z_i based on \theta.

Why is that a problem? We regularly optimize nonlinear functions of millions of variables — as long as you have gradients and use a scalable optimization algorithm (e.g. L-BFGS or CCSA, not dense BFGS) you can find local optima efficiently. It almost certainly takes far fewer function evaluations than nested optimization.

For example, here is a paper where we solved an optimization much like the one you give — our overall objective is a combination of n functions f_i, each with its own optimization parameters z_i. In our case, n = 6400, f_i involved solving Maxwell’s equations so it was fairly expensive, and z_i was a vector of 100 parameters, for 6.4 \times 10^5 parameters overall. Without simultaneous optimization it would probably not have been tractable.


You may be right! I’m rewriting my code now to see if I can just do it all in one loop using LBFGS. Honestly, I just didn’t think to try it with that many variables.

Edit: Yes, it’s hundreds of times faster. Thank you so much.


(Note that, to compute the gradients, you should either use reverse-mode differentiation or use forward-mode differentiation on each f_i separately (if z is small) and then put all the individual gradients together into one big gradient vector by the chain rule. You probably don’t want to do forward-mode automatic differentiation for so many parameters simultaneously. In general, if you are doing high-dimensional optimization you can’t apply differentiation techniques blindly.)
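A minimal sketch of the second option, assuming scalar x and scalar z_i for brevity: per-term partials (hand-coded here, standing in for a forward-mode pass over each small term) are combined by the chain rule into the full gradient of the joint objective F(x, z_1, \dots, z_n) = \left(\sum_i h(x, z_i, d_i)^\sigma\right)^{1/\sigma}, then checked against finite differences. The toy h, the data, and the function names are all hypothetical.

```julia
# Hypothetical toy term (not from the thread), positive so the powers are defined:
#   h(x, z, d) = (z - d)^2 + x^2 z^2 + 1
h(x, z, d) = (z - d)^2 + x^2 * z^2 + 1
dh_dx(x, z, d) = 2 * x * z^2                 # per-term partials: in practice these could
dh_dz(x, z, d) = 2 * (z - d) + 2 * x^2 * z    # come from forward mode on each small term

# Joint objective F(x, z_1, ..., z_n) = (Σ_i h(x, z_i, d_i)^σ)^(1/σ)
F(x, zs, ds, σ) = sum(h(x, zs[i], ds[i])^σ for i in eachindex(zs))^(1 / σ)

# Chain rule: ∂F/∂h_i = (1/σ) S^(1/σ - 1) * σ h_i^(σ - 1) = S^(1/σ - 1) h_i^(σ - 1)
function grad_F(x, zs, ds, σ)
    hs = [h(x, zs[i], ds[i]) for i in eachindex(zs)]
    w = sum(hs .^ σ)^(1 / σ - 1) .* hs .^ (σ - 1)       # weights ∂F/∂h_i
    gx = sum(w[i] * dh_dx(x, zs[i], ds[i]) for i in eachindex(zs))
    gz = [w[i] * dh_dz(x, zs[i], ds[i]) for i in eachindex(zs)]   # one entry per z_i
    return gx, gz
end

ds = [0.5, -1.0, 2.0]
zs = [0.1, 0.3, -0.2]
x, σ = 0.7, 3.0
gx, gz = grad_F(x, zs, ds, σ)

ε = 1e-6
fd_x = (F(x + ε, zs, ds, σ) - F(x - ε, zs, ds, σ)) / (2ε)
e2 = [0.0, 1.0, 0.0]
fd_z2 = (F(x, zs .+ ε .* e2, ds, σ) - F(x, zs .- ε .* e2, ds, σ)) / (2ε)
println((gx, fd_x, gz[2], fd_z2))
```

The cost of each term's partials does not grow with n, so the whole gradient costs roughly n cheap evaluations instead of one forward-mode pass over every parameter at once.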


In general that's true, but there's an important counterexample, which is when the dependence of f on one set of variables is much more complicated and costly than on the other. In that case it might be worth it to optimize the “cheap” variables in an inner loop, and nested optimization comes in handy. E.g., proximal algorithms can be seen as an instance of this.

Btw, I've often wondered in the past if it's possible to formalize what you just wrote. E.g., is combined gradient descent more efficient than nested gradient descent in terms of total number of iterations, or can a schedule be found for the inner iterations that makes them more or less comparable (e.g., in terms of dependence on the condition number)? If anybody knows of relevant references, I'd be interested.

Re the OP, in an ideal world solvers (optimizers, linear, nonlinear, eigenvalue, you name it) would define their own adjoints (e.g., by depending on ChainRules), and this would all work seamlessly. This is very possible; see, e.g., a working prototype for NLsolve in https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205#issuecomment-526822248.


Yes, I agree. More generally, if there is an extremely efficient solver algorithm for a subset of the optimization variables given the others, then it may be useful to nest.

Can’t you see it as a form of Block Coordinate Descent?

Is it a model you can share by any means? We should put it in the docs under “nested optimization”, because this is not the first time I've seen Steven bring up this point (that just doing the joint minimization is faster).


That looks like a very neat algorithm; it would be great to have a native Julia implementation.

Incidentally, the paper mentions minimax problems, would the CCSA method work for those?

Yes (as long as it is not continuous minimax), since you can transform minimax problems into nonlinear inequality constraints by an epigraph formulation, as I mentioned above.


Not right now unfortunately, but I’ll come back to this once the model is done. Also, my enthusiasm might have been a little premature. This definitely makes the iterations much faster, and the objective does appear to decrease to a sort of local minimum, but the algorithm exits with a failure status:

* Status: failure (objective increased between iterations)


As a sanity test, I’m going to see if I can get it to work with the nested-forwarddiff approach, and make sure that reversediff is working right. Any advice at this (admittedly high) level of generality?

Hard to debug without more info. Could be a bug in your code (e.g. an incorrect gradient if you are doing any part of that by hand), a limitation of floating point precision (e.g. if you are trying to optimize to a very low tolerance), or it could be a bug in the optimizer. Which optimization implementation are you using? You could try some of the ones in NLopt.
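One cheap way to rule out the first possibility is to compare any hand-written gradient with central finite differences. A minimal sketch in plain Julia follows; check_gradient is a hypothetical helper (not from Optim or any other package), and the Rosenbrock function with a deliberately wrong gradient is just a test case.

```julia
# Hypothetical helper: compare a gradient function against central differences.
function check_gradient(f, grad, x::Vector{Float64}; ε=1e-6)
    g = grad(x)
    fd = zeros(length(x))
    for i in 1:length(x)
        e = zeros(length(x))
        e[i] = ε
        fd[i] = (f(x .+ e) - f(x .- e)) / (2ε)   # central difference along axis i
    end
    # Error relative to the gradient's scale (floored at 1 to avoid dividing by ~0)
    return maximum(abs.(g .- fd)) / max(1.0, maximum(abs.(fd)))
end

# Example: Rosenbrock with a correct gradient and one with a deliberate bug
rosen(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
rosen_grad(x) = [-2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2), 200 * (x[2] - x[1]^2)]
buggy_grad(x) = [-2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2), 100 * (x[2] - x[1]^2)]

x0 = [-1.2, 1.0]
err_ok = check_gradient(rosen, rosen_grad, x0)
err_bad = check_gradient(rosen, buggy_grad, x0)
println((err_ok, err_bad))
```

A relative error near the finite-difference noise level suggests the gradient is fine and the failure lies elsewhere (tolerances or the optimizer); an error of order 0.1 or more points at the gradient code.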

I understand. I’ll just keep debugging and see what I find. Also if the nested forwarddiff method, however slow, yields an answer, that might help debug the L-BFGS code.

Are there any key tricks when it can’t be easily approximated by a finite set?