Custom rule for differentiating through Newton solver using ForwardDiff: works for gradient, fails for hessian

Hi all,

I have a function obj(p) that solves for the parameter t of a Bezier curve B(t) such that the arclength equals some arbitrary value (2 in this kind-of-minimal example). It uses Newton’s method.
The Bezier curve is parametrized by 4 3D points, so the input vector p has 12 elements.
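In other words, obj(p) returns the parameter t satisfying the condition below, which is exactly the equation fun(x) = 0 solved in the code:

\int_0^{t} \lVert B'(s) \rVert \, ds - 2 = 0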

using LinearAlgebra
using StaticArrays
using GeometryBasics
using ForwardDiff
import ForwardDiff.value
using FiniteDiff
using QuadGK

# Solves the equation f(x) = 0 for x using Newton's method
function newton(f, df, x₀, difftol=1e-12, abstol=1e-16, maxiter=100)
	x = x₀
	it = 0
	fx = f(x)
	Δ = typemax(x)
	while abs(fx) > abstol && abs(Δ) > difftol && it < maxiter
		Δ = -fx/df(x)
		x += Δ
		fx = f(x)
		it += 1
	end
	return it == maxiter ? error("Newton's method did not converge") : x
end
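
# Quick sanity check of the solver: solve x^2 - 2 = 0 starting from x₀ = 1
newton(x -> x^2 - 2, x -> 2x, 1.0) # ≈ 1.4142135623730951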



const Gx7, Wx7 = gauss(7) .|> SVector{7}

# Integrate a function f between a and b using 7 point Gauss quadrature
function integrate(f, a, b)
	x = (b-a)/2 * Gx7 .+ (b+a)/2
	w = (b-a)/2 * Wx7
	return mapreduce((xi, wi) -> f(xi) * wi, +, x, w)
end
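
# Quick check of the quadrature rule: ∫₀^π sin(x) dx (should be ≈ 2)
integrate(sin, 0.0, π)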


# Computes the arclength of a curve between t₀ and t, given its derivative df
arclength(df, t₀, t) = integrate(x -> norm(df(x)), t₀, t)
# The derivative of the arclength function wrt t
d_arclength_dt(df, t) = norm(df(t))


# Find the arclength parameter t so that the arclength of the Bezier curve B(t) between 0 and t is 2.
function obj(p)

	# Bezier curve control points
	P₁ = Point3(p[1], p[2], p[3])
	P₂ = Point3(p[4], p[5], p[6])
	P₃ = Point3(p[7], p[8], p[9])
	P₄ = Point3(p[10], p[11], p[12])

	# Bezier curve equation
	B(t) = (1-t)^3*P₁ + 3*(1-t)^2*t*P₂ + 3*(1-t)*t^2*P₃ + t^3*P₄
	# Derivative wrt t
	dB(t) = -3*(1-t)^2*P₁ + 3*(1-t)^2*P₂ - 6*(1-t)*t*P₂ + 6*(1-t)*t*P₃ - 3*t^2*P₃ + 3*t^2*P₄

	# In order to find the parameter t that solves arclength = 2., we need to solve fun(x) = 0 with:
	fun(x) = arclength(dB, 0., x) - 2.
	d_fun_dx(x) = d_arclength_dt(dB, x)

	# We can use Newton's method to solve fun(x) = 0
	xstar = newton(fun, d_fun_dx, 0.)

	return xstar

end

x0 = rand(12)
obj(x0)

I now want to compute the gradient and hessian of this function wrt the input vector p. This works fine with plain ForwardDiff, since the Newton solver is differentiable, but much better performance can be achieved by defining a custom rule based on the implicit function theorem, as described in the Roots.jl docs: An overview of Roots · Roots
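For reference, the rule I am implementing below is the first-order implicit function theorem: if fun(x, p) = 0 defines the root x^\star(p), then

\frac{\partial x^\star}{\partial p_i} = - \left. \frac{\partial fun / \partial p_i}{\partial fun / \partial x} \right|_{x = x^\star}

which is what the specialized method computes via xdual = -∂f∂p/∂f∂x.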

# Specialized version of obj(p) for ForwardDiff.Dual numbers using the implicit function theorem
function obj(p::AbstractVector{T}) where T <: ForwardDiff.Dual

	# control points
	P₁ = Point3(p[1], p[2], p[3])
	P₂ = Point3(p[4], p[5], p[6])
	P₃ = Point3(p[7], p[8], p[9])
	P₄ = Point3(p[10], p[11], p[12])

	# Bezier curve equation
	B(t) = (1-t)^3*P₁ + 3*(1-t)^2*t*P₂ + 3*(1-t)*t^2*P₃ + t^3*P₄
	# Derivative wrt t
	dB(t) = -3*(1-t)^2*P₁ + 3*(1-t)^2*P₂ - 6*(1-t)*t*P₂ + 6*(1-t)*t*P₃ - 3*t^2*P₃ + 3*t^2*P₄
	# Derivative wrt t - real valued version
	dB_real(t) = -3*(1-t)^2*value.(P₁) + 3*(1-t)^2*value.(P₂) - 6*(1-t)*t*value.(P₂) + 6*(1-t)*t*value.(P₃) - 3*t^2*value.(P₃) + 3*t^2*value.(P₄)

	# In order to find the parameter t that solves arclength = 2., we need to solve fun(x) = 0 with:
	fun(x) = arclength(dB_real, 0., x) - 2.
	d_fun_dx(x) = d_arclength_dt(dB_real, x)

	# Use the implicit function theorem to differentiate through the root finding process
	# See https://juliamath.github.io/Roots.jl/dev/roots/#Sensitivity
	xstar = newton(fun, d_fun_dx, 0.)
	∂f∂x = d_fun_dx(xstar)
	∂f∂p = arclength(dB, 0., xstar) - 2. # Evaluating fun with dB (which carries duals of the input vector p) instead of dB_real gives the derivative we are looking for
	xdual = -∂f∂p/∂f∂x

	return xdual

end

This also works quite well and can be checked using FiniteDiff:

ForwardDiff.gradient(obj, x0) - FiniteDiff.finite_difference_gradient(obj, x0) # should be approx. zero - OK

However something strange happens with the hessian: when using the default value of 12 for the chunk size, everything looks fine:

cfg = ForwardDiff.HessianConfig(obj, x0, ForwardDiff.Chunk(12))
ForwardDiff.hessian(obj, x0, cfg) - FiniteDiff.finite_difference_hessian(obj, x0) # should be approx. zero - OK!

But the result depends on the chunk size. For example, this fails:

cfg = ForwardDiff.HessianConfig(obj, x0, ForwardDiff.Chunk(1))
ForwardDiff.hessian(obj, x0, cfg) - FiniteDiff.finite_difference_hessian(obj, x0) # should be approx. zero - NOT OK!

Note that I didn’t define a custom rule for the hessian, so it goes through the second obj function. It could still be improved, but I don’t understand why the hessian is wrong when the chunk size is smaller than the length of the input vector.

Does anyone know what is wrong here?

Also, if you could point me to an expression for the custom hessian rule using the implicit function theorem, that would be great! I could only find the expression for the case where p is a scalar.

Thanks!


You may find ImplicitDifferentiation.jl (https://github.com/gdalle/ImplicitDifferentiation.jl) and ForwardDiffChainRules.jl (https://github.com/ThummeTo/ForwardDiffChainRules.jl) useful.


Thanks for the recommendations! I ended up manually coding the gradient and hessian of the function and using the ForwardDiffChainRules.jl package to enforce these rules, and now it’s working.

I do not understand what you mean. Are you saying that you implemented the hessian of the function defined by the implicit function theorem?

I would be very interested in an example!

Sorry for the late reply, here is an example of the implicit function theorem for second order derivatives.

Here, we first compute x^a such that f(x^a, p) = 0. In the example below, f(x, p) = p_1 + p_2 x + p_3 x^2 + ..., so x^a is the root of the polynomial. We use the Roots.jl package to solve for this root.

Then, we aim to differentiate x^a(p) wrt p. As x^a is a scalar, we want the gradient \frac{\partial x^a(p)}{\partial p} vector and the hessian \frac{\partial^2 x^a(p)}{\partial p^2} matrix. We use the implicit function theorem to define these derivatives without differentiating through the root solver.
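Explicitly, writing f_x, f_p, f_{xx}, f_{px}, f_{pp} for the partial derivatives of f evaluated at (x^a, p), the formulas implemented in ∂root∂p and ∂²root∂p² below are

\frac{\partial x^a}{\partial p} = -\frac{f_p}{f_x}, \qquad \frac{\partial^2 x^a}{\partial p^2} = -\frac{1}{f_x} \left( f_{pp} + f_{px} \left(\frac{\partial x^a}{\partial p}\right)^\top + \frac{\partial x^a}{\partial p} f_{px}^\top + f_{xx} \frac{\partial x^a}{\partial p} \left(\frac{\partial x^a}{\partial p}\right)^\top \right)

where the second formula follows from differentiating f_p + f_x \frac{\partial x^a}{\partial p} = 0 once more wrt p.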

Finally, we enforce these rules using the ForwardDiffChainRules.jl package so that x^a can be used in, say, another function g(x^a, p) that we want to differentiate wrt p using ForwardDiff.

using Roots
using ForwardDiff
using FiniteDiff
using ForwardDiffChainRules
using ForwardDiffChainRules.ChainRulesCore
using LinearAlgebra

# Nth order polynomial evaluated at x
eval_poly(x, p) = sum(p[i]*x^(i-1) for i in eachindex(p))

# Polynomial root (this is the slow part that should never be called with Dual numbers)
function root(p::AbstractVector{T}) where T <: AbstractFloat
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))
    return xᵅ
end

# Gradient of the root wrt the polynomial coefficients
function ∂root∂p(p::AbstractVector{T}) where T <: AbstractFloat

    # Find the root
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))

    # Implicit function theorem (see https://juliamath.github.io/Roots.jl/dev/roots/#Sensitivity)
    # Derivatives are computed using ForwardDiff
    ∂f∂x = ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ)
    ∂f∂p = ForwardDiff.gradient(p -> eval_poly(xᵅ, p), p)

    return -∂f∂p/∂f∂x

end

# Hessian of the root wrt the polynomial coefficients
function ∂²root∂p²(p::AbstractVector{T}) where T <: AbstractFloat
    
    # Find the root
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))

    # Second order implicit function theorem
    ∂f∂x = ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ)
    ∂²f∂x² = ForwardDiff.derivative(x -> ForwardDiff.derivative(x -> eval_poly(x, p), x), xᵅ)
    ∂f∂p = ForwardDiff.gradient(p -> eval_poly(xᵅ, p), p)
    ∂²f∂p² = ForwardDiff.hessian(p -> eval_poly(xᵅ, p), p)
    ∂²f∂p∂x = ForwardDiff.gradient(p -> ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ), p)

    ∂xᵅ∂p = -∂f∂p/∂f∂x
    return -(∂²f∂p² + ∂²f∂p∂x*∂xᵅ∂p' + ∂xᵅ∂p*∂²f∂p∂x' + ∂²f∂x²*∂xᵅ∂p*∂xᵅ∂p') / ∂f∂x
    
end


# Test
p = [0.4595890823651114,
    -0.12268155438606576,
    0.2555413922391766,
    -0.445873981165829,
    0.3981791064025203,
    -0.14840698579393818,
    -0.09402206963799742]


root(p)

# Verify with finite differences (this should be zero)
∂root∂p(p) .- FiniteDiff.finite_difference_gradient(root, p)

# Verify with finite differences (this should be zero)
∂²root∂p²(p) .- FiniteDiff.finite_difference_hessian(root, p)



# Now enforce these coded gradient and hessian functions via defined rules

function ChainRulesCore.frule((_, Δp), ::typeof(root), p)
    return root(p), dot(∂root∂p(p), Δp)
end

function ChainRulesCore.frule((_, Δp), ::typeof(∂root∂p), p)
    return ∂root∂p(p), ∂²root∂p²(p) * Δp
end

@ForwardDiff_frule root(p::AbstractVector{<:ForwardDiff.Dual})
@ForwardDiff_frule ∂root∂p(p::AbstractVector{<:ForwardDiff.Dual})


# This is now using the defined rules (note that the root function is never called with Dual numbers)
ForwardDiff.gradient(root, p)
ForwardDiff.hessian(root, p)
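
# Optional sanity check of the rules against finite differences (both should be ≈ 0)
ForwardDiff.gradient(root, p) .- FiniteDiff.finite_difference_gradient(root, p)
ForwardDiff.hessian(root, p) .- FiniteDiff.finite_difference_hessian(root, p)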

Note that this code could benefit from several performance improvements, such as:

  • using static arrays for small p vectors
  • manually coding ∂f∂x, ∂f∂p, … (see the sketch after this list)
  • having a single function returning the root, gradient and hessian
  • reusing the previous root as an initial solution
  • finding a way to enforce the rules so that the gradient and hessian can be returned in a single call (ForwardDiff requires several calls with changing partials to accumulate the gradient and hessian).
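
For instance, a minimal sketch of the second item for this particular polynomial (d_eval_poly_dx and d_eval_poly_dp are hypothetical helper names, not defined anywhere above):

# Hand-coded partials of eval_poly(x, p) = Σᵢ p[i] x^(i-1), replacing the inner ForwardDiff calls
# (hypothetical helpers, only valid for this polynomial form)
d_eval_poly_dx(x, p) = sum((i-1) * p[i] * x^(i-2) for i in eachindex(p) if i > 1)
d_eval_poly_dp(x, p) = [x^(i-1) for i in eachindex(p)]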

We just added support for ForwardDiff as an extension under Julia 1.9. Can you see if it works as desired for your use case?

What you are doing is pretty much the same as this simple example: https://gdalle.github.io/ImplicitDifferentiation.jl/dev/examples/1_unconstrained_optimization/#Higher-order-differentiation.

Also, seeing this thread, I released v0.3 of ImplicitDifferentiation.jl, which should allow you to access second-order derivatives seamlessly.

The new version of Roots.jl on Julia 1.9, with the custom rule for Duals, works as expected for the first-order derivative (gradient) but not for the hessian:

function root(p::AbstractVector{T}) where T <: Real
    xᵅ = find_zero(eval_poly, (0., 10.), Bisection(), p)
    return xᵅ
end

ForwardDiff.gradient(root, p) #OK

ForwardDiff.hessian(root, p)

ERROR: MethodError: no method matching _mul_partials(::ForwardDiff.Partials{7, Float64}, ::ForwardDiff.Partials{1, ForwardDiff.Dual{ForwardDiff.Tag{typeof(root), Float64}, Float64, 7}}, ::ForwardDiff.Dual{ForwardDiff.Tag{typeof(root), Float64}, Float64, 7}, ::Float64)

That looks promising!
But I cannot make it work even for the first order:


using Roots
using ForwardDiff
using ForwardDiffChainRules # provides @ForwardDiff_frule
using ImplicitDifferentiation
using FiniteDiff

eval_poly(x, p) = sum(p[i]*x^(i-1) for i in eachindex(p))

# Polynomial root (this is the slow part that should never be called with Dual numbers)
function root(p)
    f(x) = eval_poly(x, p)
    xᵅ = find_zero(f, 0.)
    return xᵅ
end


function zero_gradient(p, x)
    f(x) = eval_poly(x[1], p)
    df(x) = ForwardDiff.gradient(f, x)
    return df(x)
end

implicit = ImplicitFunction(p -> [root(p)], zero_gradient)

@ForwardDiff_frule (f::typeof(implicit))(p::AbstractArray{<:ForwardDiff.Dual}; kwargs...)

p = [0.4595890823651114,
    -0.12268155438606576,
    0.2555413922391766,
    -0.445873981165829,
    0.3981791064025203,
    -0.14840698579393818,
    -0.09402206963799742]

ForwardDiff.gradient(p -> implicit(p)[1], p) # Not equal to the verification below
FiniteDiff.finite_difference_gradient(root, p)

Disregard this, the zero_gradient function is obviously wrong. It works fine now:

function zero_gradient(p, x)
    return [eval_poly(x[1], p)]
end

linear_solver(A, b) = (Matrix(A) \ b, (solved=true,))
implicit = ImplicitFunction(p -> [root(p)], zero_gradient, linear_solver)

@ForwardDiff_frule (f::typeof(implicit))(p::AbstractArray{<:ForwardDiff.Dual}; kwargs...)

D(p) = ForwardDiff.gradient(p -> implicit(p)[1], p)
DD(p) = ForwardDiff.hessian(p -> implicit(p)[1], p)
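
# Quick check against finite differences, reusing p from above (both should be ≈ 0)
D(p) .- FiniteDiff.finite_difference_gradient(root, p)
DD(p) .- FiniteDiff.finite_difference_hessian(root, p)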

However, performance is not that great; this is probably more suited to problems with larger dimensions.

It would be interesting to dig into why performance is not great. Do you have any point of comparison?

Sure, let’s compare with the manually implemented versions:


# Gradient of the root wrt the polynomial coefficients
function ∂root∂p(p::AbstractVector{T}) where T 

    # Find the root
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))

    # Implicit function theorem (see https://juliamath.github.io/Roots.jl/dev/roots/#Sensitivity)
    # Derivatives are computed using ForwardDiff
    ∂f∂x = ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ)
    ∂f∂p = ForwardDiff.gradient(p -> eval_poly(xᵅ, p), p)

    return -∂f∂p/∂f∂x

end

# Hessian of the root wrt the polynomial coefficients
function ∂²root∂p²(p::AbstractVector{T}) where T 
    
    # Find the root
    xᵅ = find_zero(x -> eval_poly(x, p), zero(T))

    # Second order implicit function theorem
    ∂f∂x = ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ)
    ∂²f∂x² = ForwardDiff.derivative(x -> ForwardDiff.derivative(x -> eval_poly(x, p), x), xᵅ)
    ∂f∂p = ForwardDiff.gradient(p -> eval_poly(xᵅ, p), p)
    ∂²f∂p² = ForwardDiff.hessian(p -> eval_poly(xᵅ, p), p)
    ∂²f∂p∂x = ForwardDiff.gradient(p -> ForwardDiff.derivative(x -> eval_poly(x, p), xᵅ), p)

    ∂xᵅ∂p = -∂f∂p/∂f∂x
    return -(∂²f∂p² + ∂²f∂p∂x*∂xᵅ∂p' + ∂xᵅ∂p*∂²f∂p∂x' + ∂²f∂x²*∂xᵅ∂p*∂xᵅ∂p') / ∂f∂x
    
end

using BenchmarkTools

@btime D($p)
  74.413 μs (561 allocations: 35.00 KiB)
@btime ∂root∂p($p)
  2.599 μs (9 allocations: 1.78 KiB)

@btime DD($p)
  629.330 μs (4478 allocations: 328.50 KiB)
@btime ∂²root∂p²($p)
  9.117 μs (37 allocations: 22.75 KiB)

In my case, x is a scalar, so it doesn’t require solving a linear system, which probably helps.

Yeah, I’m not very surprised that plain Roots.jl + ForwardDiff.jl is very efficient. ImplicitDifferentiation.jl is most interesting when it saves you the trouble of autodiff-ing an iterative procedure in reverse mode.

Also to be fair I’m comparing the manual implementations to the custom rules defined for ForwardDiff, which require multiple function evaluations for accumulation.

A better comparison would be:

function ChainRulesCore.frule((_, Δp), ::typeof(root), p)
    return root(p), dot(∂root∂p(p), Δp)
end

function ChainRulesCore.frule((_, Δp), ::typeof(∂root∂p), p)
    return ∂root∂p(p), ∂²root∂p²(p) * Δp
end

@ForwardDiff_frule root(p::AbstractVector{<:ForwardDiff.Dual})
@ForwardDiff_frule ∂root∂p(p::AbstractVector{<:ForwardDiff.Dual})

@btime ForwardDiff.gradient($root, $p)
  51.985 μs (148 allocations: 19.25 KiB)  # D($p) was 74.413 μs
@btime ForwardDiff.hessian($root, $p)
  1.229 ms (3842 allocations: 1.38 MiB)  # DD($p) was 629.330 μs 

That’s a significant slowdown though, ~30x in the first order case and ~70x in the second order case. @touste please open an issue in ImplicitDifferentiation so we can look into it. We should be able to match the performance of your manual implementation, or at least be close to it.

It would be nice to profile the ID.jl implementation. I suspect most of the overhead is coming from the matrix conversion part of Matrix(A) \ b.


Yeah, that’s what I think too, although in the one-dimensional case it shouldn’t matter too much. Which brings us back to “Support several inputs instead of just `x`” (Issue #33 · gdalle/ImplicitDifferentiation.jl · GitHub) for supporting autodiff-able linear solvers.

So I profiled the call to D(p); here are the results:

A couple of remarks:

  • There seem to be some type instabilities in the custom forward rule for ImplicitFunction
  • I don’t understand why, but when I execute D(p), the root function is called 7 times with the same input vector each time. I guess this is due to this loop here. Is there any way to optimize this?

Full code for reference:


using Roots
using ForwardDiff
using ForwardDiffChainRules # provides @ForwardDiff_frule
using ImplicitDifferentiation
using BenchmarkTools

# Nth order polynomial evaluated at x
eval_poly(x, p) = sum(p[i]*x^(i-1) for i in eachindex(p))

# Polynomial root (this is the slow part that should never be called with Dual numbers)
function root(p::AbstractVector{T}) where T <: Real
    f(x) = eval_poly(x, p)
    xᵅ = find_zero(f, 0.)
    return xᵅ
end

function zero_gradient(p, x)
    return [eval_poly(x[1], p)]
end

linear_solver(A, b) = (Matrix(A) \ b, (solved=true,))
implicit = ImplicitFunction(p -> [root(p)], zero_gradient, linear_solver)

@ForwardDiff_frule (f::typeof(implicit))(p::AbstractArray{<:ForwardDiff.Dual}; kwargs...)

D(p) = ForwardDiff.gradient(p -> implicit(p)[1], p)
DD(p) = ForwardDiff.hessian(p -> implicit(p)[1], p)


# Test
p = [0.4595890823651114,
    -0.12268155438606576,
    0.2555413922391766,
    -0.445873981165829,
    0.3981791064025203,
    -0.14840698579393818,
    -0.09402206963799742]

root(p)

@profview @btime D($p)