Thanks. I managed to adapt the docs example to use the `frule` with `ForwardDiff`.
This MWE uses bisection (to keep it simple) to solve
$$x^{\varepsilon_1} + x^{\varepsilon_2} = \theta, \qquad \varepsilon_1, \varepsilon_2, \theta > 0,$$
for $x$.
If it sees a `ForwardDiff.Dual`, it just invokes the `frule`. This works and I don't see any obvious problems with inference, but suggestions to improve how I hook into `ForwardDiff` are welcome (the rest is really just an MWE).
using ForwardDiff, ChainRulesCore, FiniteDifferences
using ForwardDiff: value, partials, Dual
"""
    bisection(f, a::T, b::T; xtol = abs(b - a) * √eps(max(a, b)), maxiter = 100)

Find a root of `f` in the bracket `[a, b]` by bisection and return `(x, f(x))`.
This is a deliberately primitive method, kept simple so the MWE is self-contained.

`f(a)` and `f(b)` must have opposite signs; otherwise `error("not bracketed")` is
thrown. If either endpoint (or any midpoint) is an exact root, it is returned
immediately. Iteration stops once the bracket is no wider than `xtol`; if that has
not happened after `maxiter` halvings, `error("too many iterations")` is thrown.
"""
function bisection(f, a::T, b::T;
                   xtol = abs(b - a) * √eps(max(a, b)),
                   maxiter::Integer = 100) where {T <: AbstractFloat}
    fa = f(a)
    fb = f(b)
    iszero(fa) && return a, fa
    iszero(fb) && return b, fb
    # Compare signs rather than testing `fa * fb < 0`: the product of two tiny
    # values can underflow to zero and wrongly report the root as not bracketed.
    # (A NaN value still fails this test, as before.)
    sign(fa) * sign(fb) < 0 || error("not bracketed")
    for _ in 1:maxiter
        m = (a + b) / 2
        fm = f(m)
        (iszero(fm) || abs(a - b) ≤ xtol) && return m, fm
        # Keep the half whose endpoints still straddle the sign change; again
        # compare signs so that tiny same-sign values cannot underflow to 0.
        if sign(fm) == sign(fa)
            a, fa = m, fm
        else
            b = m
        end
    end
    error("too many iterations")
end
# Untyped entry point: turn both endpoints into floating-point numbers of one common
# type and forward to the typed method above, passing every keyword through.
function bisection(f, a, b; kwargs...)
    lo, hi = promote(float(a), float(b))
    return bisection(f, lo, hi; kwargs...)
end
# Analytic derivatives of the root x(ϵ1, ϵ2, θ) of F(x) = x^ϵ1 + x^ϵ2 - θ = 0,
# via the implicit function theorem: ∂x/∂p = -(∂F/∂p) / (∂F/∂x).
# `θ` does not appear in the formulas; it is kept so all three share one signature.
function ∂x∂θ(ϵ1, ϵ2, θ, x)
    ∂F∂x = ϵ1 * x^(ϵ1 - 1) + ϵ2 * x^(ϵ2 - 1)
    return 1 / ∂F∂x
end

function ∂x∂ϵ1(ϵ1, ϵ2, θ, x)
    return -log(x) * x^ϵ1 * ∂x∂θ(ϵ1, ϵ2, θ, x)
end

# F is symmetric in (ϵ1, ϵ2), so swap the exponents and reuse ∂x∂ϵ1.
∂x∂ϵ2(ϵ1, ϵ2, θ, x) = ∂x∂ϵ1(ϵ2, ϵ1, θ, x)
"""
    _solve(ϵ1, ϵ2, θ)

Solve `x^ϵ1 + x^ϵ2 = θ` for `x` by bisection on `[0, b]`. Plain (non-`Dual`)
numbers only: `solve` sends `ForwardDiff.Dual` inputs to `_solve_dual` instead.

Throws a `DomainError` unless `θ`, `ϵ1` and `ϵ2` are all positive.
"""
function _solve(ϵ1, ϵ2, θ)
    # Duals must never reach this method.
    (ϵ1 + ϵ2 + θ) isa ForwardDiff.Dual && error("sanity check: wrong code path")
    (θ > 0 && ϵ1 > 0 && ϵ2 > 0) ||
        throw(DomainError((; θ, ϵ1, ϵ2), "θ, ϵ1 and ϵ2 must all be positive"))
    # f(0) = -θ < 0, so the root is positive and each term x^ϵi is below θ, i.e.
    # x < θ^(1/ϵi) for both i: either power is a valid upper bracket. For θ ≥ 1 both
    # powers are ≥ 1 and the smaller is least likely to overflow to Inf (which would
    # make the default tolerance Inf); for θ < 1 both lie in (0, 1) and the larger is
    # least likely to underflow to 0 (which would break the bracket).
    p1, p2 = θ^(1 / ϵ1), θ^(1 / ϵ2)
    b = θ ≥ 1 ? min(p1, p2) : max(p1, p2)
    x, _ = bisection(x -> x^ϵ1 + x^ϵ2 - θ, 0, b)
    return x
end
# Forward-mode rule for `_solve`: the root x comes from the solver, and its tangent is
# assembled from the analytic partial derivatives. The tangents only need `*` by a
# scalar and `+`, so ForwardDiff.Partials work here directly. The tangent for the
# function itself (Δself) is ignored.
# NOTE: the primal goes through the dispatching `solve`, not `_solve`, whose sanity
# check rejects Dual values.
function ChainRulesCore.frule((Δself, Δϵ1, Δϵ2, Δθ), ::typeof(_solve), ϵ1, ϵ2, θ)
    x = solve(ϵ1, ϵ2, θ)
    sθ = ∂x∂θ(ϵ1, ϵ2, θ, x)
    sϵ1 = ∂x∂ϵ1(ϵ1, ϵ2, θ, x)
    sϵ2 = ∂x∂ϵ2(ϵ1, ϵ2, θ, x)
    Δx = sθ * Δθ + sϵ1 * Δϵ1 + sϵ2 * Δϵ2
    return x, Δx
end
# adapted from
# https://juliadiff.org/ChainRulesCore.jl/dev/autodiff/operator_overloading.html
#
# Evaluate `_solve` on Dual numbers by pushing their partials through the `frule`.
# `T` is the promoted Dual type of all arguments. Each argument is first converted to
# `T`, so a plain `Real` argument becomes a Dual with zero partials. Otherwise its
# `partials` would be zero-length and could not be combined with the length-N partials
# of the genuine Duals, e.g. for `solve(0.3, 0.2, dual)`.
function _solve_dual(::Type{T}, dual_args...) where {T<:Dual}
    lifted = map(arg -> convert(T, arg), dual_args)
    # NoTangent() for the function itself (replaces the deprecated NO_FIELDS).
    ȧrgs = (NoTangent(), partials.(lifted)...)
    args = (_solve, value.(lifted)...)
    y, ẏ = frule(ȧrgs, args...)
    return T(y, ẏ)
end
# Public entry point for solving x^ϵ1 + x^ϵ2 = θ. If the promoted argument type is a
# `ForwardDiff.Dual` (i.e. any argument is a Dual), derivatives are propagated by
# `_solve_dual` through the `frule`; otherwise the plain solver `_solve` is called.
# The branch depends only on argument types, so it is resolved at compile time.
function solve(ϵ1::T1, ϵ2::T2, θ::T3) where {T1,T2,T3}
    S = promote_type(T1, T2, T3)
    return S <: Dual ? _solve_dual(S, ϵ1, ϵ2, θ) : _solve(ϵ1, ϵ2, θ)
end
# Rudimentary check: the ForwardDiff derivative (which goes through the frule) should
# agree with a 5-point central finite-difference estimate at x = 0.
function f(x)
    return solve(0.3 + x, 0.2 + x, 1.2 + x)
end
d1 = central_fdm(5, 1)(f, 0.0)        # finite-difference reference
d2 = ForwardDiff.derivative(f, 0.0)   # forward-mode AD via the frule
@show isapprox(d1, d2; atol = 1e-4)