Sorry for the late reply; here is an example of the implicit function theorem applied to second-order derivatives.
Here, we first compute x^a such that f(x^a, p) = 0. In the example below, f(x, p) = p_1 + p_2 x + p_3 x^2 + ..., so x^a is a root of the polynomial. We use the Roots.jl package to solve for this root.
Then, we aim to differentiate x^a(p) wrt p. As x^a is a scalar, we want the gradient vector \frac{\partial x^a(p)}{\partial p} and the Hessian matrix \frac{\partial^2 x^a(p)}{\partial p^2}. We use the implicit function theorem to define these derivatives without differentiating through the root solver.
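For reference, these are the identities the code below implements. Differentiating f(x^a(p), p) = 0 once wrt p and solving for the unknown derivative gives

\frac{\partial x^a}{\partial p} = -\frac{1}{\partial f/\partial x} \frac{\partial f}{\partial p}

and differentiating once more (p enters both explicitly and through x^a(p)) gives

\frac{\partial^2 x^a}{\partial p^2} = -\frac{1}{\partial f/\partial x} \left( \frac{\partial^2 f}{\partial p^2} + \frac{\partial^2 f}{\partial p \partial x} \left(\frac{\partial x^a}{\partial p}\right)^T + \frac{\partial x^a}{\partial p} \left(\frac{\partial^2 f}{\partial p \partial x}\right)^T + \frac{\partial^2 f}{\partial x^2} \frac{\partial x^a}{\partial p} \left(\frac{\partial x^a}{\partial p}\right)^T \right)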
Finally, we enforce these rules using the ForwardDiffChainRules.jl package, so that x^a can be used in, say, another function g(x^a, p) that we want to differentiate wrt p using ForwardDiff.
using Roots
using ForwardDiff
using FiniteDiff
using ForwardDiffChainRules
using ForwardDiffChainRules.ChainRulesCore
using LinearAlgebra
# Nth order polynomial evaluated at x
eval_poly(x, p) = sum(p[i]*x^(i-1) for i in eachindex(p))
# Polynomial root (this is the slow part that should never be called with Dual numbers)
function root(p::AbstractVector{T}) where T <: AbstractFloat
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))
    return xᵅ
end
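# Quick sanity check (illustrative values, not from the original): for the linear
# polynomial f(x, p) = p₁ + p₂x with p = [-1, 1], the root is xᵅ = -p₁/p₂ = 1
root([-1.0, 1.0])  # ≈ 1.0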
# Gradient of the root wrt the polynomial coefficients
function ∂root∂p(p::AbstractVector{T}) where T <: AbstractFloat
    # Find the root
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))
    # Implicit function theorem (see https://juliamath.github.io/Roots.jl/dev/roots/#Sensitivity)
    # Derivatives are computed using ForwardDiff
    ∂f∂x = ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ)
    ∂f∂p = ForwardDiff.gradient(p -> eval_poly(xᵅ, p), p)
    return -∂f∂p/∂f∂x
end
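# For the same linear example (illustrative check): xᵅ = -p₁/p₂,
# so ∂xᵅ/∂p = [-1/p₂, p₁/p₂²]
∂root∂p([-1.0, 1.0])  # ≈ [-1.0, -1.0]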
# Hessian of the root wrt the polynomial coefficients
function ∂²root∂p²(p::AbstractVector{T}) where T <: AbstractFloat
    # Find the root
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))
    # Second order implicit function theorem
    ∂f∂x = ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ)
    ∂²f∂x² = ForwardDiff.derivative(x -> ForwardDiff.derivative(x -> eval_poly(x, p), x), xᵅ)
    ∂f∂p = ForwardDiff.gradient(p -> eval_poly(xᵅ, p), p)
    ∂²f∂p² = ForwardDiff.hessian(p -> eval_poly(xᵅ, p), p)
    ∂²f∂p∂x = ForwardDiff.gradient(p -> ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ), p)
    ∂xᵅ∂p = -∂f∂p/∂f∂x
    return -(∂²f∂p² + ∂²f∂p∂x*∂xᵅ∂p' + ∂xᵅ∂p*∂²f∂p∂x' + ∂²f∂x²*∂xᵅ∂p*∂xᵅ∂p') / ∂f∂x
end
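# And the Hessian for the same linear example (illustrative check): its entries are
# ∂²xᵅ/∂p₁² = 0, ∂²xᵅ/∂p₁∂p₂ = 1/p₂², ∂²xᵅ/∂p₂² = -2p₁/p₂³
∂²root∂p²([-1.0, 1.0])  # ≈ [0.0 1.0; 1.0 2.0]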
# Test
p = [0.4595890823651114,
     -0.12268155438606576,
     0.2555413922391766,
     -0.445873981165829,
     0.3981791064025203,
     -0.14840698579393818,
     -0.09402206963799742]
root(p)
# Verify the gradient with finite differences (this should be approximately zero)
∂root∂p(p) .- FiniteDiff.finite_difference_gradient(root, p)
# Verify the hessian with finite differences (this should be approximately zero)
∂²root∂p²(p) .- FiniteDiff.finite_difference_hessian(root, p)
# Now enforce these hand-coded gradient and hessian functions via custom rules
function ChainRulesCore.frule((_, Δp), ::typeof(root), p)
    return root(p), dot(∂root∂p(p), Δp)
end
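# A rule for ∂root∂p is needed too: ForwardDiff.hessian nests Dual numbers, so the
# frule above ends up calling ∂root∂p with Dual-valued p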
function ChainRulesCore.frule((_, Δp), ::typeof(∂root∂p), p)
    return ∂root∂p(p), ∂²root∂p²(p) * Δp
end
@ForwardDiff_frule root(p::AbstractVector{<:ForwardDiff.Dual})
@ForwardDiff_frule ∂root∂p(p::AbstractVector{<:ForwardDiff.Dual})
# This is now using the defined rules (note that the root function is never called with Dual numbers)
ForwardDiff.gradient(root, p)
ForwardDiff.hessian(root, p)
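As a final illustration of the motivation above (this g is a hypothetical example, not from the original code), the root can now be composed into another function of p and differentiated as usual:

# A g(xᵅ, p) that uses the root internally
g(p) = sin(root(p)) + sum(abs2, p)
ForwardDiff.gradient(g, p)  # the rules above are used; root never sees Dual numbers
ForwardDiff.hessian(g, p)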
Note that this code could benefit from several performance improvements, such as:
- using static arrays for small p vectors
- manually coding ∂f∂x, ∂f∂p, …
- having a single function returning the root, gradient and hessian
- reusing the previous root as an initial solution
- finding a way to enforce the rules so that the gradient and hessian can be returned in a single call (ForwardDiff requires several calls with changing partials to accumulate the gradient and hessian).