ChainRulesCore and ForwardDiff

Thanks. I managed to adapt the operator-overloading example from the ChainRulesCore docs so that ForwardDiff uses a custom `frule`.

This MWE uses bisection (to keep it simple) to solve for $x > 0$ in

$$x^{\varepsilon_1} + x^{\varepsilon_2} = \theta, \qquad \varepsilon_1, \varepsilon_2, \theta > 0.$$

If any argument is a `ForwardDiff.Dual`, `solve` invokes the `frule` instead of differentiating through the bisection iterations. This works, and I don't see any obvious problems with type inference, but suggestions on how to hook into ForwardDiff more cleanly are welcome (the rest is really just scaffolding for the MWE).

using ForwardDiff, ChainRulesCore, FiniteDifferences
using ForwardDiff: value, partials, Dual

# A primitive bisection method, just to make the MWE self-contained.
"""
    bisection(f, a::T, b::T; xtol = abs(b - a) * √eps(T), maxiter = 100) where {T<:AbstractFloat}

Find a root of `f` between `a` and `b` by repeatedly halving the bracket.

`f(a)` and `f(b)` must have strictly opposite signs, unless one of them is exactly
zero, in which case that endpoint is returned immediately. Iteration stops when the
bracket is no wider than `xtol`, or when `f` is exactly zero at the midpoint.

The default `xtol` is the initial bracket width scaled by `√eps(T)`, i.e. relative
machine precision. The tolerance therefore does not depend on the magnitude of the
endpoints.

Returns the tuple `(x, f(x))`. Throws an `ErrorException` with message
`"not bracketed"` if the endpoint signs do not differ, or `"too many iterations"` if
the tolerance is not reached within `maxiter` halvings.
"""
function bisection(f, a::T, b::T;
                   xtol = abs(b - a) * √eps(T),
                   maxiter::Integer = 100) where {T <: AbstractFloat}
    fa = f(a)
    fb = f(b)
    iszero(fa) && return a, fa
    iszero(fb) && return b, fb
    # Compare signs rather than testing `fa * fb < 0`: for tiny function values the
    # product underflows to zero and would misreport the bracket.
    sign(fa) * sign(fb) < 0 || error("not bracketed")
    for _ in 1:maxiter
        m = (a + b) / 2
        fm = f(m)
        (iszero(fm) || abs(a - b) ≤ xtol) && return m, fm
        # Keep the half whose endpoints still have opposite signs; again compare
        # signs so that tiny values cannot underflow into the wrong branch.
        if sign(fm) == sign(fa)
            a, fa = m, fm
        else
            b = m
        end
    end
    error("too many iterations")
end

# Convenience method: convert both endpoints to floating point, promote them to a
# common type (e.g. an integer endpoint becomes Float64), and forward to the typed
# method together with any keyword arguments.
# NOTE(review): if the promoted endpoints are not `AbstractFloat` (e.g. complex),
# no typed method matches and this method calls itself again without end.
function bisection(f, a, b; kwargs...)
    lo, hi = promote(float(a), float(b))
    return bisection(f, lo, hi; kwargs...)
end

# Analytic partial derivatives of the root x(ϵ1, ϵ2, θ) of
#     F(x) = x^ϵ1 + x^ϵ2 - θ = 0,
# from the implicit function theorem: ∂x/∂p = -(∂F/∂p) / (∂F/∂x).
# `x` is expected to be the root already; `θ` is accepted only so that all three
# functions share the same signature.

# ∂x/∂θ = 1 / (∂F/∂x), since ∂F/∂θ = -1.
function ∂x∂θ(ϵ1, ϵ2, θ, x)
    ∂F∂x = ϵ1 * x^(ϵ1 - 1) + ϵ2 * x^(ϵ2 - 1)
    return 1 / ∂F∂x
end

# ∂x/∂ϵ1 = -(∂F/∂ϵ1) · ∂x/∂θ, with ∂F/∂ϵ1 = log(x) · x^ϵ1.
function ∂x∂ϵ1(ϵ1, ϵ2, θ, x)
    ∂F∂ϵ1 = log(x) * x^ϵ1
    return -∂F∂ϵ1 * ∂x∂θ(ϵ1, ϵ2, θ, x)
end

# F is symmetric in (ϵ1, ϵ2), so ∂x/∂ϵ2 is ∂x/∂ϵ1 with the exponents swapped.
∂x∂ϵ2(ϵ1, ϵ2, θ, x) = ∂x∂ϵ1(ϵ2, ϵ1, θ, x)

"""
    _solve(ϵ1, ϵ2, θ)

Primal solver: return the root `x` of `x^ϵ1 + x^ϵ2 = θ`, found by bisection on
`[0, max(θ^(1/ϵ1), θ^(1/ϵ2))]`.

Must not receive `ForwardDiff.Dual` arguments; those are routed through `solve` and
the `frule` instead. Throws a `DomainError` unless `θ`, `ϵ1` and `ϵ2` are all
positive.
"""
function _solve(ϵ1, ϵ2, θ)
    # Guard against accidentally differentiating through the bisection loop.
    if (ϵ1 + ϵ2 + θ) isa ForwardDiff.Dual
        error("sanity check: wrong code path")
    end
    if !(θ > 0 && ϵ1 > 0 && ϵ2 > 0)
        throw(DomainError((; θ, ϵ1, ϵ2)))
    end
    residual(x) = x^ϵ1 + x^ϵ2 - θ
    # residual(0) = -θ < 0. At `upper` both powers are (up to rounding) ≥ θ, so
    # residual(upper) > 0 and [0, upper] brackets the root.
    upper = max(θ^(1/ϵ1), θ^(1/ϵ2))
    root, _ = bisection(residual, 0, upper)
    return root
end

# Forward-mode rule for `_solve`. Returns the root `x` together with its tangent:
# the input tangents weighted by the implicit-function partials ∂x∂θ, ∂x∂ϵ1, ∂x∂ϵ2.
# The tangents only need scalar `*` and `+`, which is why ForwardDiff `Partials`
# can be passed in directly.
function ChainRulesCore.frule((_, dϵ1, dϵ2, dθ), ::typeof(_solve), ϵ1, ϵ2, θ)
    # The primal goes through `solve` rather than `_solve`: primal values that are
    # themselves `Dual`s (nested differentiation) are dispatched to the dual path
    # instead of tripping `_solve`'s sanity check.
    x = solve(ϵ1, ϵ2, θ)
    sθ = ∂x∂θ(ϵ1, ϵ2, θ, x)
    sϵ1 = ∂x∂ϵ1(ϵ1, ϵ2, θ, x)
    sϵ2 = ∂x∂ϵ2(ϵ1, ϵ2, θ, x)
    ẋ = (sθ * dθ + sϵ1 * dϵ1) + sϵ2 * dϵ2
    return x, ẋ
end

# adapted from
# https://juliadiff.org/ChainRulesCore.jl/dev/autodiff/operator_overloading.html
"""
    _solve_dual(::Type{T}, dual_args...) where {T<:Dual}

Evaluate `_solve` on ForwardDiff dual numbers. Primal values and partials are split
out, pushed through the `frule` for `_solve`, and reassembled into a `T`.
`T` is the common (promoted) dual type of the arguments; `solve` passes
`promote_type` of its argument types.
"""
function _solve_dual(::Type{T}, dual_args...) where {T<:Dual}
    # Convert every argument to the common dual type first. A plain real argument
    # would otherwise contribute an empty `Partials{0}`, which cannot be added to
    # the `Partials{N}` of the dual arguments inside the frule.
    duals = map(arg -> convert(T, arg), dual_args)
    # `_solve` is a plain function with nothing to differentiate, hence `NoTangent()`
    # (the replacement for the removed `NO_FIELDS`).
    ȧrgs = (NoTangent(), map(partials, duals)...)
    args = (_solve, map(value, duals)...)
    y, ẏ = frule(ȧrgs, args...)
    return T(y, ẏ)
end

"""
    solve(ϵ1, ϵ2, θ)

Return the root `x` of `x^ϵ1 + x^ϵ2 = θ`.

If the promoted type of the arguments is a ForwardDiff `Dual`, derivatives are
propagated through the `frule` for `_solve` (via `_solve_dual`). Otherwise the plain
bisection solver `_solve` is called.
"""
function solve(ϵ1::T1, ϵ2::T2, θ::T3) where {T1,T2,T3}
    common = promote_type(T1, T2, T3)
    common <: Dual || return _solve(ϵ1, ϵ2, θ)
    return _solve_dual(common, ϵ1, ϵ2, θ)
end

# Rudimentary check: the derivative ForwardDiff obtains through the frule should
# agree with a 5-point central finite-difference estimate.
f(x) = solve(0.3 + x, 0.2 + x, 1.2 + x)
d1, d2 = central_fdm(5, 1)(f, 0.0), ForwardDiff.derivative(f, 0.0)
@show isapprox(d1, d2; atol = 1e-4)
10 Likes