Sometimes it would be useful to AD through an iterative algorithm, e.g. solving a parametric problem f(x, a) = 0
which implicitly defines x(a), but in practice is solved with an iterative rootfinding method.
The math is clear (implicit differentiation), and I am aware of the nice tricks for the multivariate case (adjoint method); this question is about a practical Julia implementation.
To keep the presentation simple, let's work with scalars and use f(x, a) = x^3 - a,
which we treat as a black box. We assume derivatives exist for our black box, and implicit differentiation of f(x(a), a) = 0,
x'(a) = -(∂f/∂a) / (∂f/∂x),
allows us to obtain x'(a); then we just apply the chain rule.
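Spelled out, the implicit differentiation step looks like this. It is only a sketch and assumes that f is continuously differentiable and that ∂f/∂x is nonzero at the solution; the last line specializes it to the cubic example, using x(a) = a^(1/3).

```latex
\begin{align*}
f(x(a), a) &= 0 \\
\frac{\partial f}{\partial x}\, x'(a) + \frac{\partial f}{\partial a} &= 0 \\
x'(a) &= -\frac{\partial f / \partial a}{\partial f / \partial x} \\
\text{for } f(x, a) = x^3 - a:\quad
x'(a) &= -\frac{-1}{3x^2} = \frac{1}{3x^2} = \tfrac{1}{3}\, a^{-2/3}
\end{align*}
```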
This is how I implemented it:
using ForwardDiff, Roots

f(x, a) = x^3 - a               # treat as a black box

# The bracket (-a, a) contains a sign change of f only when |a| > 1
solve_for_x(a) = find_zero(x -> f(x, a), (-a, a), Bisection())

# Dual input: solve on the primal value, then attach x'(a) = -(∂f/∂a) / (∂f/∂x)
function solve_for_x(a::ForwardDiff.Dual{T}) where {T}
    a0 = ForwardDiff.value(a)
    x0 = solve_for_x(a0)
    ∂x = ForwardDiff.derivative(x -> f(x, a0), x0)
    ∂a = ForwardDiff.derivative(a -> f(x0, a), a0)
    ForwardDiff.Dual{T}(x0, (-∂a/∂x) * ForwardDiff.partials(a))
end
It seems to work OK:
using Test

function test1(z)
    a = exp(z + 1)
    x = solve_for_x(a)
    x^2
end

@test ForwardDiff.derivative(test1, 0.7) ≈
    ForwardDiff.derivative(z -> abs2(cbrt(exp(z + 1))), 0.7)
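An independent sanity check that does not go through the Dual method at all is to compare the implicit formula against a central finite difference of the plain Float64 solver. This is only a sketch: the step size `h` and the tolerance are ad hoc choices of mine, and the names `implicit` and `central` are just for illustration.

```julia
using ForwardDiff, Roots

f(x, a) = x^3 - a   # same black box as above
solve_for_x(a) = find_zero(x -> f(x, a), (-a, a), Bisection())

a0 = exp(1.7)            # the value of a that test1 sees at z = 0.7
x0 = solve_for_x(a0)

# x'(a) from implicit differentiation, evaluated at the solution
implicit = -ForwardDiff.derivative(a -> f(x0, a), a0) /
            ForwardDiff.derivative(x -> f(x, a0), x0)

# central finite difference of the solver itself (h is an ad hoc choice)
h = 1e-6
central = (solve_for_x(a0 + h) - solve_for_x(a0 - h)) / (2h)

isapprox(implicit, central; rtol = 1e-6)
```

The finite difference only ever calls the solver with Float64 arguments, so agreement here says something about the formula itself rather than about the overloading mechanism.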
Questions
- I am assuming that the inner ForwardDiff.derivative takes care of perturbation confusion, is this correct?
- Anything else I could improve? It seems to be type stable.
- How would one generalize this to other AD libraries?
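To make the last question concrete: one candidate I can imagine is writing the rule once against ChainRulesCore instead of dispatching on ForwardDiff.Dual. The following is an unverified sketch under that assumption. The helper `dx_da` is a hypothetical name of mine, the partials of the black box are still computed with ForwardDiff, and I have not checked which AD backends would actually pick these rules up.

```julia
using ChainRulesCore, ForwardDiff, Roots

f(x, a) = x^3 - a   # same black box as above
solve_for_x(a) = find_zero(x -> f(x, a), (-a, a), Bisection())

# Hypothetical helper: x'(a) = -(∂f/∂a) / (∂f/∂x) at the solution x0
function dx_da(x0, a0)
    ∂x = ForwardDiff.derivative(x -> f(x, a0), x0)
    ∂a = ForwardDiff.derivative(a -> f(x0, a), a0)
    -∂a / ∂x
end

# Forward-mode rule: push the tangent Δa through x'(a)
function ChainRulesCore.frule((_, Δa), ::typeof(solve_for_x), a)
    x0 = solve_for_x(a)
    x0, dx_da(x0, a) * Δa
end

# Reverse-mode rule: pull the cotangent x̄ back through x'(a)
function ChainRulesCore.rrule(::typeof(solve_for_x), a)
    x0 = solve_for_x(a)
    d = dx_da(x0, a)
    solve_for_x_pullback(x̄) = (NoTangent(), d * unthunk(x̄))
    x0, solve_for_x_pullback
end

# Calling the rules by hand at a = 8, where x = 2 and x'(a) = 1/12
x_fwd, ẋ = ChainRulesCore.frule((NoTangent(), 1.0), solve_for_x, 8.0)
x_rev, pullback = ChainRulesCore.rrule(solve_for_x, 8.0)
_, ā = pullback(1.0)
```

In the scalar case the forward and reverse rules carry the same number, so `ẋ` and `ā` should agree; the multivariate case is where the adjoint method would come in.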