ChainRulesCore and ForwardDiff

If I want to define derivatives for a function to work with ForwardDiff, is it sufficient to define a ChainRulesCore.frule, or is there something else?

The rule does not seem to be used; instead, the original function is called with ForwardDiff.Dual arguments.

No, ForwardDiff does not currently use ChainRules, so you will have to add your rules as done in ForwardDiff.jl/dual.jl at 909976d719fdbd5fec91a159c6e2d808c45a770f · JuliaDiff/ForwardDiff.jl · GitHub.

There is https://github.com/YingboMa/ForwardDiff.jl which does use ChainRules, but I don’t think it is still being actively developed.
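
For readers who have not written one before, a hand-written ForwardDiff rule is just a method that accepts a Dual and returns a Dual. The following is a minimal sketch, assuming a made-up scalar function `mycube` (not from this thread) standing in for a function that ForwardDiff cannot differentiate by tracing.

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical primitive; imagine its body is opaque to ForwardDiff.
mycube(x::Real) = x^3

# Method for Dual inputs: compute the value on the underlying number and
# scale the incoming partials by the analytic derivative 3x^2.
function mycube(x::Dual{T}) where {T}
    v = value(x)
    Dual{T}(mycube(v), 3v^2 * partials(x))
end

d = ForwardDiff.derivative(mycube, 2.0)
@show d
```

Keeping the tag parameter `T` of the incoming Dual in the result is what lets ForwardDiff match the output to the derivative it is computing.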

Thanks. I managed to adapt the docs example to use the frule with ForwardDiff.

This MWE uses bisection (to keep it simple) to solve

$$x^{\varepsilon_1} + x^{\varepsilon_2} = \theta, \qquad \varepsilon_1, \varepsilon_2, \theta > 0$$

If it sees ForwardDiff.Dual, it just invokes the frule. This works and I don’t see any obvious problems with inference, but suggestions to improve how I hook into ForwardDiff are welcome (the rest is really just an MWE).
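
For reference, the analytic partials used in the MWE below can be obtained by implicit differentiation of the defining equation. This is the standard derivation, written out here as an aid and not quoted from the post. Taking the total differential of

$$F(x, \varepsilon_1, \varepsilon_2, \theta) = x^{\varepsilon_1} + x^{\varepsilon_2} - \theta = 0$$

gives

$$\left(\varepsilon_1 x^{\varepsilon_1 - 1} + \varepsilon_2 x^{\varepsilon_2 - 1}\right) dx + x^{\varepsilon_1} \log x \, d\varepsilon_1 + x^{\varepsilon_2} \log x \, d\varepsilon_2 - d\theta = 0,$$

so that

$$\frac{\partial x}{\partial \theta} = \frac{1}{\varepsilon_1 x^{\varepsilon_1 - 1} + \varepsilon_2 x^{\varepsilon_2 - 1}}, \qquad \frac{\partial x}{\partial \varepsilon_i} = -x^{\varepsilon_i} \log x \, \frac{\partial x}{\partial \theta}, \quad i = 1, 2.$$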

using ForwardDiff, ChainRulesCore, FiniteDifferences
using ForwardDiff: value, partials, Dual

# A primitive bisection method, just to make the MWE self-contained.
function bisection(f, a::T, b::T;
                   xtol = abs(b - a) * √eps(max(a, b))) where {T <: AbstractFloat}
    fa = f(a)
    fb = f(b)
    fa * fb < 0 || error("not bracketed")
    for _ in 1:100
        m = (a + b) / 2
        fm = f(m)
        abs(a - b) ≤ xtol && return m, fm
        if fm * fa > 0
            a, fa = m, fm
        else
            b, fb = m, fm
        end
    end
    error("too many iterations")
end

bisection(f, a, b; kwargs...) = bisection(f, promote(float(a), float(b))...; kwargs...)

# analytically derived partial derivatives
∂x∂θ(ϵ1, ϵ2, θ, x) = 1 / (ϵ1 * x^(ϵ1 - 1) + ϵ2 * x^(ϵ2 - 1))
∂x∂ϵ1(ϵ1, ϵ2, θ, x) = -log(x)*x^ϵ1 * ∂x∂θ(ϵ1, ϵ2, θ, x)
∂x∂ϵ2(ϵ1, ϵ2, θ, x) = ∂x∂ϵ1(ϵ2, ϵ1, θ, x)

# solve using bisection
function _solve(ϵ1, ϵ2, θ)
    (ϵ1 + ϵ2 + θ) isa ForwardDiff.Dual && error("sanity check: wrong code path")
    (θ > 0 && ϵ1 > 0 && ϵ2 > 0) || throw(DomainError((; θ, ϵ1, ϵ2)))
    b = max(θ^(1/ϵ1), θ^(1/ϵ2))
    x, _ = bisection(x -> x^ϵ1 + x^ϵ2 - θ, 0, b)
    x
end

function ChainRulesCore.frule((Δself, Δϵ1, Δϵ2, Δθ),
                              ::typeof(_solve), ϵ1, ϵ2, θ)
    x = solve(ϵ1, ϵ2, θ)
    Δx = (∂x∂θ(ϵ1, ϵ2, θ, x) * Δθ +   # this works with ForwardDiff.Partials …
          ∂x∂ϵ1(ϵ1, ϵ2, θ, x) * Δϵ1 + # … because they have + and * defined.
          ∂x∂ϵ2(ϵ1, ϵ2, θ, x) * Δϵ2)
    return x, Δx
end

# adapted from
# https://juliadiff.org/ChainRulesCore.jl/dev/autodiff/operator_overloading.html
function _solve_dual(::Type{T}, dual_args...) where {T<:Dual}
    # NoTangent() replaces NO_FIELDS, which was removed in ChainRulesCore 1.0
    ȧrgs = (NoTangent(), partials.(dual_args)...)
    args = (_solve, value.(dual_args)...)
    y, ẏ = frule(ȧrgs, args...)
    T(y, ẏ)
end

function solve(ϵ1::T1, ϵ2::T2, θ::T3) where {T1,T2,T3}
    T = promote_type(T1, T2, T3)
    if T <: Dual
        _solve_dual(T, ϵ1, ϵ2, θ)
    else
        _solve(ϵ1, ϵ2, θ)
    end
end

# rudimentary checks
f(x) = solve(0.3 + x, 0.2 + x, 1.2 + x)
d1 = central_fdm(5, 1)(f, 0.0)
d2 = ForwardDiff.derivative(f, 0.0)
@show isapprox(d1, d2; atol = 1e-4)

Nice!
Do you think we could make a macro to do that?
Like a @import_frule foo or something?

I haven’t found a clean way to add an AD-method for ForwardDiff.Duals, because in order to dispatch on the AD method I need a trait like the promoted type above.

Maybe some helper functions could make this easier and cleaner than a macro. Also, obtaining partials for non-duals is OK but would be wasteful for arrays and similar, where specializing on them not being duals would work better.

Yeah, I am pretty convinced one wants to always keep the dual on the outside:
a Dual of Arrays, not an Array of Duals.
That makes this easy, and also means you can use BLAS etc.

ForwardDiff2 did that

If ForwardDiff2 is not looking like it’s going to become a full replacement, is there anything else in the pipeline for forward-mode AD beyond ForwardDiff? I am in general a very happy user of ForwardDiff, but I sometimes have to do some hacky things that are less elegant than @Tamas_Papp’s solution here. I have a pattern I’ve ended up using a few times of writing methods like

function myfunction(x::Dual{T,V,N}) where {T,V<:AbstractFloat,N}
    # manual first derivative of my function here...
end

which are definitely kind of scary because I don’t really know what I’m doing and have gotten partials wrong in non-obvious ways before. I should probably start using Tamas’ method here instead, but in general I’d be curious to hear about forward-mode stuff. A lot of my problems are, say, 10 dimensional, and so the extra overhead of reverse-mode makes it less appealing.
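
One way to catch wrong partials in hand-written Dual methods is to compare the result against finite differences, as the MWE above does. Below is a minimal sketch of such a check for a roughly 10-dimensional problem. The function `softplus` and the `objective` built on it are hypothetical examples, not code from this thread.

```julia
using ForwardDiff, FiniteDifferences
using ForwardDiff: Dual, value, partials

# Hypothetical scalar function with a hand-written derivative.
softplus(x::Real) = log1p(exp(x))

function softplus(x::Dual{T}) where {T}
    v = value(x)
    # d/dv log(1 + exp(v)) = 1 / (1 + exp(-v))
    Dual{T}(softplus(v), partials(x) / (1 + exp(-v)))
end

# A 10-dimensional objective that calls the hand-written method internally.
objective(x) = sum(softplus, x .* (1:length(x)))

x0 = collect(range(-1.0, 1.0; length = 10))
g_ad = ForwardDiff.gradient(objective, x0)
g_fd = FiniteDifferences.grad(central_fdm(5, 1), objective, x0)[1]
@show isapprox(g_ad, g_fd; rtol = 1e-6)
```

Running such a comparison at a few different points is a cheap safeguard. A mistake in the hand-written derivative usually shows up as a clear mismatch and not as a small numerical difference.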

Note that my solution above is making ForwardDiff use existing partials defined by ChainRulesCore.frule — in other words, you still need to get them right, but hopefully in one place only without duplicating code.

Diffractor (@keno) has both forward and reverse mode.
Its forward mode uses ChainRulesCore.frule natively
(and its reverse mode uses rrule).
Diffractor has not been released yet, but I get the impression it is getting close.

Hi all,

I am looking for a “default procedure” to get ForwardDiff working together with already defined frules. The code example from @Tamas_Papp looks pretty good, but I need a more general way to handle array arguments. With array arguments, ForwardDiff.Dual becomes Vector{ForwardDiff.Dual{...}}, and T = promote_type(T1, T2, T3) gives T = Any if I mix scalar and vector arguments.

I have been playing around all day, but I can’t figure out a way to solve the problem for arbitrary functions, meaning functions with mixed scalar and vector arguments.

Any suggestions would make my day :)
Thanks in advance!

PS: If there is a good tutorial on how to build custom AD rules for ForwardDiff for functions with multiple and/or vector arguments, it would be a nice hint, too.

This is a rather late reply, hope it is still useful. For something similar to the code above to work, you need to extract the eltype of array arguments and promote those.
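
A minimal sketch of that idea follows. The helper names `scalar_type`, `promoted_scalar_type` and `any_dual` are made up for illustration and are not part of ForwardDiff or ChainRulesCore. The sketch also assumes arrays with a concrete element type, so a Vector{Any} holding Duals would not be detected.

```julia
using ForwardDiff
using ForwardDiff: Dual

# Hypothetical helpers: reduce every argument to its scalar type first.
scalar_type(x::Number) = typeof(x)
scalar_type(x::AbstractArray) = eltype(x)

# Promote the scalar types, so a Dual in any argument (scalar or array) wins.
promoted_scalar_type(args...) = promote_type(map(scalar_type, args)...)

# Trait-like check that can replace `T <: Dual` in the dispatch above.
any_dual(args...) = promoted_scalar_type(args...) <: Dual

d = Dual(1.0, 1.0)  # a Dual with value 1.0 and a single partial

@show any_dual(0.5, [1.0, 2.0])
@show any_dual(0.5, [d, d])
@show any_dual(d, [1.0, 2.0])
```

Promoting the scalar types avoids the `Any` that results from promoting a scalar type with an array type directly.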

Thanks for the reply.
Yes, we solved it in a similar way in a prototype (checking the array type and converting every element). As soon as I spend more time on this again, I will try to build and post a generic solution.

Just stumbled upon this thread, and I seem to recall that @mohamed82008 had an answer in the works.

Check NonconvexUtils.jl/runtests.jl at main · JuliaNonconvex/NonconvexUtils.jl · GitHub