Implicit differentiation of rootfinding problem (w/ numerical issues)

I have written an algorithm that solves

x = \log\left( \sum_i \exp(a_i + z_i y) \right)

for y given x, a_i, z_i. It uses Newton’s method, which is the trivial part; the difficulty was getting a cheap but reasonable initial guess (I use an envelope of linear tangents, which works nicely). Assume that it is implemented as

solver(a::SVector, z::SVector, x::Real, y::Real = initial_guess(a, z, x))

where initial_guess is differentiable with AD.
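A minimal sketch of the solver part (simplified, not the actual code; the iteration limit and tolerance are placeholders):

# Sketch only: Newton's method on the residual logsumexp(a .+ z .* y) - x,
# with the exponentials shifted by their maximum so nothing overflows.
using StaticArrays

function solver(a::SVector, z::SVector, x::Real, y::Real = initial_guess(a, z, x))
    for _ in 1:100                               # placeholder iteration limit
        t = a .+ z .* y
        m = maximum(t)
        e = exp.(t .- m)                         # shifted exponentials
        f = m + log(sum(e)) - x                  # residual on the log scale
        f′ = sum(z .* e) / sum(e)                # d(residual)/dy
        step = f / f′
        y -= step
        abs(step) ≤ 1e-12 * (1 + abs(y)) && break   # placeholder tolerance
    end
    return y
end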

I would like solver to work nicely with AD, especially ForwardDiff.jl and, if feasible, Enzyme.jl. “Nicely” in this context means that the partials do not need to propagate through the Newton solver and that overflow is avoided (see below).

An added wrinkle is that derivatives have the form

\frac{dy}{da_i} = - \frac{\exp(a_i + z_i y)}{\sum_j z_j \exp(a_j + z_j y)}, \qquad \frac{dy}{dz_i} = - \frac{y \exp(a_i + z_i y)}{\sum_j z_j \exp(a_j + z_j y)}

and

\frac{dy}{dx} = \frac{\sum_j \exp(a_j + z_j y)}{\sum_j z_j \exp(a_j + z_j y)}

The problem is that a naive implementation of these formulas overflows whenever some a_i + z_i y is even moderately large, which can easily happen.

So, what’s my best bet for hooking this into a generic AD framework, supporting differentiation with respect to a, z, and x, while letting me handle the overflow issue? (The fix itself is trivial, you just subtract the largest exponent, but it needs to be handled nevertheless.)
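For concreteness, this is roughly what the trick looks like applied to the closed-form partials above (solver_partials is just an illustrative name, not part of the actual code):

# The common factor exp(-maximum(t)) cancels in every ratio below,
# so nothing overflows even when a_i + z_i*y is large.
using StaticArrays

function solver_partials(a::SVector, z::SVector, y::Real)
    t = a .+ z .* y
    e = exp.(t .- maximum(t))
    S = sum(e)                      # proportional to Σ_j exp(a_j + z_j y)
    D = sum(z .* e)                 # proportional to Σ_j z_j exp(a_j + z_j y)
    return (-e ./ D, -y .* e ./ D, S / D)   # ∂y/∂a, ∂y/∂z, ∂y/∂x
end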

1 Like

Can you maybe express the optimality conditions on a log scale and use ImplicitDifferentiation.jl? I’m currently revamping the package, happy to help
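Roughly the pattern I have in mind (a sketch only; exact signatures may differ between versions, and packing a, z, x into a single vector is just illustrative):

# Sketch: ImplicitDifferentiation.jl works with array-to-array maps, so the
# parameters are packed into one vector p = [a; z; x] and the optimality
# condition is stated on the log scale via an overflow-safe logsumexp.
using ImplicitDifferentiation, LogExpFunctions, StaticArrays

function forward(p)
    N = (length(p) - 1) ÷ 2
    a, z, x = SVector{N}(p[1:N]), SVector{N}(p[N+1:2N]), p[end]
    return [solver(a, z, x)]                    # y as a 1-element vector
end

function conditions(p, y)
    N = (length(p) - 1) ÷ 2
    a, z, x = p[1:N], p[N+1:2N], p[end]
    return [logsumexp(a .+ z .* y[1]) - x]
end

implicit_solver = ImplicitFunction(forward, conditions)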

1 Like

I looked at ImplicitDifferentiation.jl, which only supports an x \mapsto y interface and suggests stacking the parameters in a ComponentArray. It allows custom rules for the partial derivatives of the residual, but I could not see a way to hook in the numerical trick.

1 Like

The residual is on a log scale, but the calculation of the derivatives involves the exponentials. I could get rid of that by subtracting something and ADing a kernel function, but it still requires stacking up components via ComponentArrays, and I am not sure that mixes well with SVectors (currently everything is non-allocating).

I can make an MWE of the solver part if that helps.

1 Like

You have a closed-form expression for the derivative. Why not just write a custom rule?

3 Likes

That is definitely a viable approach, but I am unsure about the details of how to implement it.

  1. Do I use a custom method for all combinations involving ForwardDiff.Dual and then use ChainRules.jl for the rest of the ecosystem?

  2. Is there a way I can avoid handling all the method combinations, given that there are 4 arguments? (One can assume that a and z are conformable SVectors).

1 Like

I would start with just one AD, e.g. just Enzyme. (If you have lots of parameters a_i, z_i, you won’t want to use forward-mode AD.)

It’s a good thing to learn how to do, since if you can’t write your own rules when needed, an AD system is very limiting.

4 Likes

In pedagogical terms, if you’ve never done it before, I would start by writing a rule using ChainRules.jl. It won’t work out of the box with Enzyme.jl, nor with ForwardDiff.jl, but it is a very well-documented and rather simple interface, so it’s great to start with.
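To make that concrete, here is a rough (untested) sketch of a reverse rule for solver with the closed-form derivatives and the overflow guard, assuming the signature from the first post:

# Sketch of an rrule for the 4-argument method. The converged y does not
# depend on the initial guess, so its tangent is zero.
using ChainRulesCore, StaticArrays

function ChainRulesCore.rrule(::typeof(solver), a::SVector, z::SVector, x::Real, y0::Real)
    y = solver(a, z, x, y0)              # primal Newton solve, no derivatives inside
    t = a .+ z .* y
    e = exp.(t .- maximum(t))            # overflow-safe shifted exponentials
    S, D = sum(e), sum(z .* e)
    function solver_pullback(ȳ)
        ā = (-ȳ / D) .* e                # ȳ * ∂y/∂a
        z̄ = (-ȳ * y / D) .* e            # ȳ * ∂y/∂z
        x̄ = ȳ * S / D                    # ȳ * ∂y/∂x
        return NoTangent(), ā, z̄, x̄, ZeroTangent()
    end
    return y, solver_pullback
end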

3 Likes

(But you can import a ChainRule to Enzyme with a one-liner.)
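Something along these lines, if memory serves (check the Enzyme docs for the exact form):

# Reuse an existing ChainRules reverse rule inside Enzyme.
using Enzyme, ChainRulesCore
Enzyme.@import_rrule(typeof(solver), SVector, SVector, Real, Real)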

3 Likes

I have found that the way ForwardDiff.jl works, it still recomputes the value for each chunk (using Newton’s method). This is redundant since only the partials change.

Since this is the most costly step, I am wondering if there is a workaround (other than memoization, which is hard to make thread-safe, or using the whole input length (around 200–300) as chunk size).
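For reference, the whole-input-chunk workaround would look roughly like this (f and x standing in for the larger function and its input):

# Illustrative only: force a single chunk covering the whole input, so the
# value is recomputed only once; memory use grows with the chunk size.
using ForwardDiff
cfg = ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk{length(x)}())
J = ForwardDiff.jacobian(f, x, cfg)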

1 Like

Recomputing things seems like the least of your worries with ForwardDiff. It has two problems:

  1. Like any direct application of AD, it propagates derivatives through all of the Newton steps, whereas implicit differentiation costs the equivalent of only a single Newton step.

  2. The cost of forward-mode differentiation scales with the number of parameters you are differentiating with respect to, whereas reverse mode scales with the number of outputs, and you have the ideal case of only a single output.

3 Likes

I have worked around that by defining methods for ForwardDiff.Dual arguments, which call the Newton solver for the value and then compute the partials from the closed-form expressions.
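Roughly along these lines (a sketch, not the actual code; it only covers the case where a, z, and x all carry the same Dual tag, which is where the method-combination question above comes from):

# Sketch: solve once with plain values, then attach partials computed from
# the closed-form derivatives, with the exponentials shifted to avoid overflow.
using ForwardDiff, StaticArrays
using ForwardDiff: Dual, value, partials

function solver(a::SVector{N,Dual{T,V,K}}, z::SVector{N,Dual{T,V,K}},
                x::Dual{T,V,K}) where {N,T,V,K}
    av, zv, xv = value.(a), value.(z), value(x)
    y = solver(av, zv, xv)                   # plain Newton solve, no Duals inside
    t = av .+ zv .* y
    e = exp.(t .- maximum(t))
    S, D = sum(e), sum(zv .* e)
    # chain rule: dy = Σ_i (∂y/∂a_i) da_i + Σ_i (∂y/∂z_i) dz_i + (∂y/∂x) dx
    dy = sum((-e[i] / D) * partials(a[i]) + (-y * e[i] / D) * partials(z[i]) for i in 1:N) +
         (S / D) * partials(x)
    return Dual{T}(y, dy)
end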

I did not say so above, but of course the problem is part of a larger calculation (an approximate solution to a multisector heterogeneous agent macroeconomic model using spectral methods). The actual mapping is \mathbb{R}^{17} \to \mathbb{R}^{378} and similar. The MWE is just about optimizing a costly part.

So you are still doing a derivative computation for every Newton step, instead of only at the end?

So you don’t have a scalar loss function at the end that you are minimizing?

  1. Do I use a custom method for all combinations involving ForwardDiff.Dual and then use ChainRules.jl for the rest of the ecosystem?

Yes, unfortunately. There’s a wrapper for ForwardDiff to use ChainRules rules, but it’s very limited in what it can do (for good reasons, it’s a very tricky problem in general). It’s very annoying.

I think if you just define rules for solver (either using ImplicitDifferentiation or rolling your own) and for logsumexp (which very likely exist somewhere in the ecosystem already btw) you should be fine.
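For instance, LogExpFunctions.jl has an overflow-safe logsumexp (it subtracts the maximum internally), so the residual can be stated directly on the log scale, along the lines of:

# Hypothetical helper: residual of the rootfinding problem on the log scale.
using LogExpFunctions

residual(a, z, x, y) = logsumexp(a .+ z .* y) - x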

1 Like