AD step for iterative solution

Sometimes it would be useful to AD through an iterative algorithm, e.g. when solving a parametric problem

f(x, a) = 0

which implicitly defines x(a), but in practice uses an iterative rootfinding method.

The math is clear (implicit differentiation), and I am aware of the nice tricks for the multivariate case (the adjoint method); this question is about the practical Julia implementation.

For a simple presentation, let’s work with scalars and use

f(x, a) = x^3 - a

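which we treat as a black box. We assume the black box is differentiable; writing f_1 and f_2 for its partial derivatives with respect to x and a, differentiating f(x(a), a) = 0 gives

\frac{\partial x}{\partial a} f_1 + f_2 = 0,

which allows us to obtain x'(a) = -f_2 / f_1; then we just apply the chain rule.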
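As a quick hand check on the running example: with f(x, a) = x^3 - a we have f_1 = 3x^2 and f_2 = -1, so

x'(a) = -\frac{f_2}{f_1} = \frac{1}{3x^2} = \frac{1}{3} a^{-2/3} \quad (a > 0),

which agrees with differentiating the closed form x(a) = a^{1/3} directly.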

This is how I implemented it:

using ForwardDiff, Roots

f(x, a) = x^3 - a # treat as a black box

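# the bracket (-(1 + |a|), 1 + |a|) always has a sign change; (-a, a) does not for 0 < a < 1 or a < 0
solve_for_x(a) = find_zero(x -> f(x, a), (-(1 + abs(a)), 1 + abs(a)), Bisection())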

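# Dual method: solve at the primal value, then attach the implicit derivative
function solve_for_x(a::ForwardDiff.Dual{T}) where {T}
    a0 = ForwardDiff.value(a)                       # strip the perturbation
    x0 = solve_for_x(a0)                            # plain Float64 solve
    ∂x = ForwardDiff.derivative(x -> f(x, a0), x0)  # f_1 at (x0, a0)
    ∂a = ForwardDiff.derivative(α -> f(x0, α), a0)  # f_2 at (x0, a0)
    # x'(a) = -f_2 / f_1, chained with the incoming partials of a
    ForwardDiff.Dual{T}(x0, (-∂a/∂x) * ForwardDiff.partials(a))
end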

It seems to work OK:

using Test

function test1(z)
    a = exp(z + 1)
    x = solve_for_x(a)
    x^2
end

@test ForwardDiff.derivative(test1, 0.7) ≈
    ForwardDiff.derivative(z -> abs2(cbrt(exp(z + 1))), 0.7)

Questions

  1. I am assuming that the inner ForwardDiff.derivative takes care of perturbation confusion; is this correct?

  2. Anything else I could improve? It seems to be type stable.

  3. How would one generalize this to other AD libraries?

Yes, this is similar to how I do it (except I need the Jacobian). There is no risk of perturbation confusion, because no duals are actually passed to another derivative call.
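For the nested case, where duals do reach an inner derivative call, ForwardDiff separates perturbations with tags. A small sketch of the classic perturbation-confusion check (the function nested is purely illustrative):

```julia
using ForwardDiff

# d/dx [ x * d/dy (x + y) ] = d/dx [ x * 1 ] = 1.
# If the inner call confused x's perturbation with y's, the result would be 2 instead.
nested(x) = x * ForwardDiff.derivative(y -> x + y, 1.0)

ForwardDiff.derivative(nested, 1.0)
```

In the OP's overload this situation never arises, since the inner calls only ever see ForwardDiff.value(a).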

Yup. Ideally solvers would implement this on their own and do it automagically. See https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205

I am not sure that this belongs in the solver; I think it is more generic. The solver helps me obtain x in f(x, a) = 0, but the derivative should be independent of the solver once I have the solution.

What I would find useful is a crude wrapper that would make AD work for a generic f(x, a) = 0 (including vector arguments), simply by providing a g: a \mapsto x (for which I would be free to pick the solver).
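A rough sketch of such a wrapper for ForwardDiff only, assuming f maps (x, a) to a vector of the same length as x and g returns a plain Vector{Float64}. The name implicit_solve is hypothetical; like the OP's overload, the inner Jacobians only ever see plain numbers.

```julia
using ForwardDiff, LinearAlgebra
using ForwardDiff: Dual, Partials, value, partials

# Hypothetical wrapper: g is any solver a -> x with f(g(a), a) ≈ 0.
# Plain-number path: just call the solver.
implicit_solve(g, f, a::AbstractVector{<:Real}) = g(a)

# Dual path: solve at the primal values, then push the partials of `a` through
# dx/da = -(∂f/∂x) \ (∂f/∂a), both Jacobians evaluated at the solution.
function implicit_solve(g, f, a::AbstractVector{Dual{T,V,N}}) where {T,V,N}
    a0 = value.(a)                                    # strip perturbations
    x0 = g(a0)                                        # solver sees plain numbers only
    Jx = ForwardDiff.jacobian(x -> f(x, a0), x0)      # ∂f/∂x
    Ja = ForwardDiff.jacobian(α -> f(x0, α), a0)      # ∂f/∂a
    dxda = -(Jx \ Ja)
    Pa = [partials(a[j])[k] for j in eachindex(a), k in 1:N]   # seeds carried by a
    Px = dxda * Pa                                             # chained partials of x
    return [Dual{T}(x0[i], Partials(ntuple(k -> Px[i, k], Val(N)))) for i in eachindex(x0)]
end

# Toy problem x .^ 3 .- a = 0, "solved" by a black box that happens to be cbrt
f(x, a) = x .^ 3 .- a
g(a) = cbrt.(a)

J = ForwardDiff.jacobian(a -> implicit_solve(g, f, a), [2.0, 3.0])
```

Here J should be diagonal with entries 1 / (3 x_i^2), matching the scalar formula.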

Similarly for x(a) = \arg\max_x f(x, a).
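The argmax case reduces to the same step: at an interior maximum the first-order condition f_1(x(a), a) = 0 holds (writing f_1 for \partial f / \partial x), so implicit differentiation of that condition gives, assuming f is twice differentiable and f_{11} \neq 0 at the optimum,

x'(a) = -\frac{f_{12}(x(a), a)}{f_{11}(x(a), a)},

where f_{11} and f_{12} are second partial derivatives of f.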

I am also interested in this kind of AD application in my field (topology optimization). But often I am not interested in x itself or its Jacobian; I am interested in another explicit function of x, h(x), and its gradient. Also, in my field the gradient of (h \circ x)(a) often admits cancellations and case-specific optimizations that are difficult to automate in the general case. Perhaps the simplest optimization I need is something like K^{-1}KK^{-1} \nabla \to K^{-1} \nabla, where the second K^{-1} comes from one Jacobian and K^{-1}K comes from the next. This can save one linear system solve when accumulating the gradient with reverse differentiation.

For this to be possible, the accumulation of the Jacobian mat-vecs needs to be lazy, and the cancellation has to be made somehow, either through dispatch (e.g. LazyArrays.jl) or through compiler optimizations. Perhaps if the Julia compiler could learn this optimization, Zygote might be able to use it, since it is a lazy source-to-source transformation IIUC. But I am far from an expert on Zygote internals, so I may be wrong. Perhaps Zygote and LazyArrays could be used together to trigger this kind of lazy optimization. @MikeInnes sorry for pinging you, but do you think this is feasible given how Zygote does things?
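To make the setting concrete, a minimal adjoint-method sketch for a hypothetical parametrized linear system K(a) x = b with objective h(x). Differentiating K x = b gives x'(a) = -K^{-1} (\partial K / \partial a) x, so dh/da = -\lambda^T (\partial K / \partial a) x with K^T \lambda = \nabla h(x): one forward solve plus one adjoint solve, and K^{-1} is never formed. The matrices K0, K1, b and the choice h(x) = |x|^2 / 2 are illustrative assumptions.

```julia
using LinearAlgebra

# Hypothetical model problem: K(a) x = b with K(a) = K0 + a*K1
K0 = [4.0 1.0; 1.0 3.0]
K1 = [1.0 0.0; 0.0 2.0]          # ∂K/∂a
b  = [1.0, 2.0]
h(x) = sum(abs2, x) / 2          # objective; ∇h(x) = x

K(a) = K0 + a * K1
xsol(a) = K(a) \ b

# Adjoint gradient: one forward solve, one transposed solve, no explicit inverse
function dh_da_adjoint(a)
    x = xsol(a)
    λ = K(a)' \ x                # solve K(a)ᵀ λ = ∇h(x)
    return -dot(λ, K1 * x)       # dh/da = -λᵀ (∂K/∂a) x
end

# Central finite difference, used only as a sanity check
dh_da_fd(a; ε = 1e-6) = (h(xsol(a + ε)) - h(xsol(a - ε))) / (2ε)
```

The two functions should agree to roughly the finite-difference accuracy.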

Personally, I am just manually deriving and hard-coding everything I need right now to get guaranteed performance, although I would love to explore AD frameworks more to modularize some of the mathematical derivations and automate some of the manual optimizations I am doing. But even without the above cancellation, this feature could still be useful for complex multidisciplinary physical systems like the ones discussed here.

I am not sure Zygote will do algebraic simplifications like this for you. AFAIK, if the AST transforms to something like A * inv(A), that will not be replaced by I, because with floating point the two may not be the same at all.
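A quick illustration of why such a rewrite is not semantics-preserving in floating point, using the badly conditioned 10×10 Hilbert matrix as an example:

```julia
using LinearAlgebra

# 10×10 Hilbert matrix, a classic ill-conditioned example
A = [1 / (i + j - 1) for i in 1:10, j in 1:10]

# In exact arithmetic this is the identity; in Float64 it is not
deviation = opnorm(A * inv(A) - I)
println("cond(A) = ", cond(A), ",  ‖A * inv(A) - I‖ = ", deviation)
```

The deviation grows with the condition number, so replacing A * inv(A) by I would silently change results.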

That said, it would be great if one could coax Zygote into differentiating through iterative solves like the one in the OP. I will package what I have for ForwardDiff today and will be happy to consider PRs.
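One possible route for Zygote and other ChainRules-aware AD systems is to express the OP's overload as a ChainRulesCore.rrule. A minimal sketch, reusing the toy f with a widened bisection bracket; it is not the package mentioned above.

```julia
using ChainRulesCore, ForwardDiff, Roots

f(x, a) = x^3 - a   # still treated as a black box

# widened bracket so bisection also works for 0 < a < 1 and a ≤ 0
solve_for_x(a) = find_zero(x -> f(x, a), (-(1 + abs(a)), 1 + abs(a)), Bisection())

# Reverse rule: solve at the primal value, then scale the cotangent by x'(a) = -f_2 / f_1
function ChainRulesCore.rrule(::typeof(solve_for_x), a::Real)
    x = solve_for_x(a)
    f1 = ForwardDiff.derivative(ξ -> f(ξ, a), x)   # ∂f/∂x at (x, a)
    f2 = ForwardDiff.derivative(α -> f(x, α), a)   # ∂f/∂a at (x, a)
    dxda = -f2 / f1
    solve_for_x_pullback(x̄) = (NoTangent(), x̄ * dxda)
    return x, solve_for_x_pullback
end
```

With the rule defined, Zygote.gradient(solve_for_x, 8.0) should use it instead of trying to differentiate through the bisection loop.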

The PR here https://github.com/JuliaDiff/ForwardDiff.jl/pull/165 might be of some use (or not).

Could this be better handled by switching over to ChainRules.jl?

Probably not. The main thing about that PR is just how to get the gradients to “propagate properly” (as in the last line of the overloaded function in the first post of this thread). And this has mostly to do with the AD implementation, not with the large list of derivatives.

I’ve been sketching a similar idea to generalize the adjoint of fixed-point methods which solve

T^* = f(T^*, \theta).

The adjoint is then (according to this paper)

\bar{\theta} = \sum_{n=0}^\infty \bar{T^*} \left[\frac{\partial f(T^*, \theta)}{\partial T^*} \right]^n \frac{\partial f(T^*, \theta)}{\partial \theta}

where we can avoid saving intermediate results and iterate until convergence.
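Where the series comes from: differentiating T^* = f(T^*, \theta) gives dT^* = \frac{\partial f}{\partial T^*} dT^* + \frac{\partial f}{\partial \theta} d\theta, hence

\bar{\theta} = \bar{T^*} \left[ I - \frac{\partial f(T^*, \theta)}{\partial T^*} \right]^{-1} \frac{\partial f(T^*, \theta)}{\partial \theta},

and expanding the inverse as a Neumann series, which is valid when the spectral radius of \partial f / \partial T^* is below 1 (e.g. when f is a contraction in T), recovers the sum above term by term.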

Anyway, for my case I thought it might make sense to define a package that is just a standard interface for fixed-point functions. A sketch can be found in this gist, where the adjoint looks like

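using Zygote, IterTools, LinearAlgebra

# `fixedpoint(f, guess, n, stopfun)` comes from the gist and is not reproduced here.
# Newer Zygote versions call `Zygote.forward` `Zygote.pullback`.
function fixedpointbackward(next, r, n)
    _, back = Zygote.pullback(next, r, n)
    back1 = x -> back(x)[1]   # Δ ↦ Δ ∂f/∂T* at the fixed point r
    back2 = x -> back(x)[2]   # Δ ↦ Δ ∂f/∂θ
    function backΔ(Δ) # as in 'Differential Programming Tensor Networks', arXiv 1903.09650
        grad = back2(Δ)   # n = 0 term of the series
        # add back2(back1^k(Δ)) for k = 1, 2, ... until the terms become negligible
        for g in IterTools.imap(back2, Iterators.drop(IterTools.iterated(back1, Δ), 1))
            grad += g
            norm(g) < 1e-12 && break
        end
        grad
    end
    return backΔ
end


Zygote.@adjoint function fixedpointAD(f, guess, n, stopfun)
    r = fixedpoint(f, guess, n, stopfun)
    # only the parameter n receives a gradient; f, guess and stopfun get `nothing`
    return r, Δ -> (nothing, nothing, fixedpointbackward(f, r, n)(Δ), nothing)
end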

So maybe the solution could be to have small packages that provide standard interfaces to certain functions, for which these alternative adjoints can be defined.

That’s a bad way to compute the adjoint, which really just amounts to solving a linear system. The formula is saying that (1-A)^{-1} = \sum_n A^n, which is a crude way to solve it (the series only converges when the spectral radius of A is below 1); a better one is a Krylov-type iterative method such as GMRES, as implemented in e.g. IterativeSolvers.jl. The same goes for the nonlinear case: solving T = f(T) by T_{n+1} = f(T_n) is bad, so use something better (e.g. Anderson acceleration, implemented in JuliaNLSolvers).
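A small sketch of the difference for a hypothetical linear map f(T, θ) = A T + B θ (so ∂f/∂T = A and ∂f/∂θ = B), in column-vector convention: the truncated series needs dozens of matrix-vector products, while one solve with (I - A)^T gives the same adjoint. Backslash on a tiny dense matrix stands in here for the Krylov solver one would use at scale.

```julia
using LinearAlgebra

# Hypothetical linear fixed-point map: ∂f/∂T = A, ∂f/∂θ = B
A = [0.5 0.2; 0.1 0.6]             # eigenvalues 0.7 and 0.4, so the series converges
B = [1.0 0.0 1.0; 0.0 2.0 1.0]
Tbar = [1.0, 2.0]                  # incoming cotangent for T*

# Neumann series: θbar = Σ_k B' (A')^k Tbar, truncated when terms are negligible
function adjoint_neumann(A, B, Tbar; tol = 1e-14, maxiter = 10_000)
    λ = copy(Tbar)
    θbar = B' * λ                  # k = 0 term
    for k in 1:maxiter
        λ = A' * λ                 # λ = (A')^k Tbar
        g = B' * λ
        θbar += g
        norm(g) < tol && return θbar, k
    end
    return θbar, maxiter
end

# Same quantity from a single linear solve
adjoint_solve(A, B, Tbar) = B' * ((I - A)' \ Tbar)

θn, iters = adjoint_neumann(A, B, Tbar)
θs = adjoint_solve(A, B, Tbar)
```

The number of series terms grows like log(tol) / log(ρ(A)), so it degrades badly as the spectral radius approaches 1.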

The point is that the code you have in the OP could be implemented generically as a diff rule for nlsolve(f,x0,a), where nlsolve(f,x0,a) solves f(x,a)=0 for x with initial guess x0. In the multivariate case, you would call GMRES or a similar Krylov method to perform the inversion using only matrix-vector products with df/dx, each of which can be computed with one forward-mode AD pass.
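A rough, self-contained sketch of that idea for scalar a: each matrix-vector product with ∂f/∂x is a single forward-mode directional derivative, and a bare-bones unrestarted GMRES (written out only to avoid extra dependencies; IterativeSolvers.jl provides a proper gmres) does the inversion. The names jvp, gmres_matfree, implicit_dxda and the test problem are all hypothetical.

```julia
using ForwardDiff, LinearAlgebra

# (∂f/∂x)(x, a) * v as one forward-mode directional derivative
jvp(f, x, a, v) = ForwardDiff.derivative(t -> f(x .+ t .* v, a), 0.0)

# Minimal unrestarted GMRES that only needs a function v -> A*v
function gmres_matfree(Aop, b; tol = 1e-12, maxiter = length(b))
    n = length(b)
    β = norm(b)
    β == 0 && return zeros(n)
    Q = zeros(n, maxiter + 1)             # orthonormal Krylov basis
    H = zeros(maxiter + 1, maxiter)       # upper Hessenberg matrix
    Q[:, 1] = b ./ β
    x = zeros(n)
    for k in 1:maxiter
        w = Aop(Q[:, k])                  # one matvec = one JVP
        for j in 1:k                      # modified Gram-Schmidt
            H[j, k] = dot(view(Q, :, j), w)
            w .-= H[j, k] .* view(Q, :, j)
        end
        H[k + 1, k] = norm(w)
        e1 = zeros(k + 1); e1[1] = β
        Hk = H[1:k + 1, 1:k]
        y = Hk \ e1                       # small least-squares problem
        x = Q[:, 1:k] * y
        # stop on a small residual or on (lucky) Arnoldi breakdown
        (norm(Hk * y - e1) <= tol * β || H[k + 1, k] <= tol * β) && break
        Q[:, k + 1] = w ./ H[k + 1, k]
    end
    return x
end

# dx/da at a solution x of f(x, a) = 0: solve (∂f/∂x) dx = -∂f/∂a matrix-free
function implicit_dxda(f, x, a)
    rhs = -ForwardDiff.derivative(α -> f(x, α), a)
    return gmres_matfree(v -> jvp(f, x, a, v), rhs)
end

# Hypothetical test problem, built so that f(x0, a0) = 0
B  = [0.0 0.5 0.0; 0.5 0.0 0.5; 0.0 0.5 0.0]
x0 = [1.0, 2.0, 3.0]
a0 = 1.0
c  = x0 .^ 3 .+ B * x0
f(x, a) = x .^ 3 .+ B * x .- a .* c

dx = implicit_dxda(f, x0, a0)
```

Only JVPs are ever formed, so the dense Jacobian is never built; for a real problem one would also warm-start and precondition the Krylov solve.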