ChainRulesCore and ForwardDiff

If I want to define derivatives for a function to work with ForwardDiff, is it sufficient to define a ChainRulesCore.frule, or is there something else?

The rule does not seem to be used, instead the original function is called with ForwardDiff.Dual arguments.

No, ForwardDiff does not currently use ChainRules, so you will have to add your rules the way it is done in ForwardDiff.jl/dual.jl (JuliaDiff/ForwardDiff.jl on GitHub, commit 909976d719fdbd5fec91a159c6e2d808c45a770f).
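Here is why the frule is not picked up. Forward-mode packages like ForwardDiff propagate derivatives by calling ordinary Julia methods on a dual-number type, so a custom rule has to be a method for that type. The sketch below shows the mechanism with a hypothetical toy dual type. ToyDual, myexp and toy_derivative are invented for illustration and are not ForwardDiff API. In ForwardDiff itself, the analogous method would dispatch on ForwardDiff.Dual, as in dual.jl.

```julia
# Hypothetical toy dual number (a value plus one derivative), used only to
# illustrate dispatch; it is NOT ForwardDiff.Dual.
struct ToyDual
    value::Float64
    deriv::Float64
end

# A function written only for Float64, e.g. because it wraps a solver
# that cannot propagate dual numbers.
myexp(x::Float64) = exp(x)

# The "rule" is just a method for the dual type: compute the primal value
# with the plain method, then apply the chain rule exp'(x) = exp(x).
function myexp(d::ToyDual)
    v = myexp(d.value)
    ToyDual(v, v * d.deriv)
end

# Seed the input with derivative 1 and read off the output derivative.
toy_derivative(f, x::Float64) = f(ToyDual(x, 1.0)).deriv

toy_derivative(myexp, 1.0)   # equals exp(1.0)
```

Without the myexp(::ToyDual) method, this call would fail with a MethodError. With ForwardDiff, the usual outcome is different: the generic method simply runs on Dual arguments. That matches the behaviour described in the question.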

There is YingboMa/ForwardDiff.jl ("Forward Mode Automatic Differentiation for Julia", on GitHub), which does use ChainRules, but I don't think it is still being actively developed.


Thanks. I managed to adapt the docs example to use the frule with ForwardDiff.

This MWE uses bisection (to keep it simple) to solve for x in

$$x^{\varepsilon_1} + x^{\varepsilon_2} = \theta, \qquad \varepsilon_1, \varepsilon_2, \theta > 0$$
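The analytic partials ∂x∂θ, ∂x∂ϵ1 and ∂x∂ϵ2 used in the code follow from the implicit function theorem applied to F(x; ε₁, ε₂, θ) = 0. The last line is the pushforward that the frule returns as Δx:

```latex
\begin{aligned}
F(x; \varepsilon_1, \varepsilon_2, \theta) &= x^{\varepsilon_1} + x^{\varepsilon_2} - \theta = 0, \\
\frac{\partial F}{\partial x} &= \varepsilon_1 x^{\varepsilon_1 - 1} + \varepsilon_2 x^{\varepsilon_2 - 1}, \qquad
\frac{\partial F}{\partial \theta} = -1, \qquad
\frac{\partial F}{\partial \varepsilon_1} = x^{\varepsilon_1} \log x, \\
\frac{\partial x}{\partial \theta} &= -\frac{\partial F / \partial \theta}{\partial F / \partial x}
  = \frac{1}{\varepsilon_1 x^{\varepsilon_1 - 1} + \varepsilon_2 x^{\varepsilon_2 - 1}}, \\
\frac{\partial x}{\partial \varepsilon_1} &= -\frac{\partial F / \partial \varepsilon_1}{\partial F / \partial x}
  = -\log(x)\, x^{\varepsilon_1} \frac{\partial x}{\partial \theta}, \qquad
\frac{\partial x}{\partial \varepsilon_2} = -\log(x)\, x^{\varepsilon_2} \frac{\partial x}{\partial \theta}, \\
\dot{x} &= \frac{\partial x}{\partial \theta} \dot{\theta}
  + \frac{\partial x}{\partial \varepsilon_1} \dot{\varepsilon}_1
  + \frac{\partial x}{\partial \varepsilon_2} \dot{\varepsilon}_2 .
\end{aligned}
```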

If solve sees a ForwardDiff.Dual argument, it invokes the frule instead of running the bisection on duals. This works and I don't see any obvious problems with inference, but suggestions to improve how I hook into ForwardDiff are welcome (the rest is really just an MWE).

using ForwardDiff, ChainRulesCore, FiniteDifferences
using ForwardDiff: value, partials, Dual

# A primitive bisection method, just to make the MWE self-contained.
function bisection(f, a::T, b::T;
                   xtol = abs(b - a) * √eps(max(a, b))) where {T <: AbstractFloat}
    fa = f(a)
    fb = f(b)
    fa * fb < 0 || error("not bracketed")
    for _ in 1:100
        m = (a + b) / 2
        fm = f(m)
        abs(a - b) ≤ xtol && return m, fm
        if fm * fa > 0          # same sign as f(a): root lies in [m, b]
            a, fa = m, fm
        else                    # otherwise the root lies in [a, m]
            b, fb = m, fm
        end
    end
    error("too many iterations")
end

bisection(f, a, b; kwargs...) = bisection(f, promote(float(a), float(b))...; kwargs...)

# analytically derived partial derivatives (implicit function theorem)
∂x∂θ(ϵ1, ϵ2, θ, x) = 1 / (ϵ1 * x^(ϵ1 - 1) + ϵ2 * x^(ϵ2 - 1))
∂x∂ϵ1(ϵ1, ϵ2, θ, x) = -log(x) * x^ϵ1 * ∂x∂θ(ϵ1, ϵ2, θ, x)
∂x∂ϵ2(ϵ1, ϵ2, θ, x) = ∂x∂ϵ1(ϵ2, ϵ1, θ, x)

# solve using bisection; only ever called with plain (non-Dual) numbers
function _solve(ϵ1, ϵ2, θ)
    (ϵ1 + ϵ2 + θ) isa ForwardDiff.Dual && error("sanity check: wrong code path")
    (θ > 0 && ϵ1 > 0 && ϵ2 > 0) || throw(DomainError((; θ, ϵ1, ϵ2)))
    b = max(θ^(1/ϵ1), θ^(1/ϵ2))
    x, _ = bisection(x -> x^ϵ1 + x^ϵ2 - θ, 0, b)
    x                           # return the root only, not (root, residual)
end

function ChainRulesCore.frule((Δself, Δϵ1, Δϵ2, Δθ),
                              ::typeof(_solve), ϵ1, ϵ2, θ)
    x = _solve(ϵ1, ϵ2, θ)
    Δx = (∂x∂θ(ϵ1, ϵ2, θ, x) * Δθ +   # this works with ForwardDiff.Partials …
          ∂x∂ϵ1(ϵ1, ϵ2, θ, x) * Δϵ1 + # … because they have + and * defined.
          ∂x∂ϵ2(ϵ1, ϵ2, θ, x) * Δϵ2)
    return x, Δx
end

# adapted from
function _solve_dual(::Type{T}, dual_args...) where {T<:Dual}
    # NoTangent() replaces the NO_FIELDS constant removed in ChainRulesCore 1.0
    ȧrgs = (NoTangent(), partials.(dual_args)...)
    args = (_solve, value.(dual_args)...)
    y, ẏ = frule(ȧrgs, args...)
    T(y, ẏ)
end

function solve(ϵ1::T1, ϵ2::T2, θ::T3) where {T1,T2,T3}
    T = promote_type(T1, T2, T3)
    if T <: Dual
        _solve_dual(T, ϵ1, ϵ2, θ)
    else
        _solve(ϵ1, ϵ2, θ)
    end
end

# rudimentary checks
f(x) = solve(0.3 + x, 0.2 + x, 1.2 + x)
d1 = central_fdm(5, 1)(f, 0.0)
d2 = ForwardDiff.derivative(f, 0.0)
@show isapprox(d1, d2; atol = 1e-4)
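The analytic partials can also be checked one at a time, without any packages. This standalone sketch assumes a plain fixed-iteration bisection. solve_x and fd_partial are hypothetical helpers, not part of the MWE. It compares ∂x∂ϵ1, ∂x∂ϵ2 and ∂x∂θ against central differences at the same point the check above uses, (0.3, 0.2, 1.2).

```julia
# Hypothetical fixed-iteration bisection for x^ϵ1 + x^ϵ2 = θ (not the MWE's solver).
function solve_x(ϵ1, ϵ2, θ; iters = 200)
    lo, hi = 0.0, max(θ^(1 / ϵ1), θ^(1 / ϵ2))   # F(lo) < 0 ≤ F(hi)
    for _ in 1:iters
        mid = (lo + hi) / 2
        if mid^ϵ1 + mid^ϵ2 - θ > 0
            hi = mid
        else
            lo = mid
        end
    end
    (lo + hi) / 2
end

# Same analytic partials as in the MWE.
∂x∂θ(ϵ1, ϵ2, θ, x) = 1 / (ϵ1 * x^(ϵ1 - 1) + ϵ2 * x^(ϵ2 - 1))
∂x∂ϵ1(ϵ1, ϵ2, θ, x) = -log(x) * x^ϵ1 * ∂x∂θ(ϵ1, ϵ2, θ, x)
∂x∂ϵ2(ϵ1, ϵ2, θ, x) = ∂x∂ϵ1(ϵ2, ϵ1, θ, x)

# Central difference of x along coordinate i of p = (ϵ1, ϵ2, θ).
function fd_partial(p, i; h = 1e-6)
    e = ntuple(j -> j == i ? h : 0.0, 3)
    (solve_x((p .+ e)...) - solve_x((p .- e)...)) / (2h)
end

p = (0.3, 0.2, 1.2)
x = solve_x(p...)
analytic = (∂x∂ϵ1(p..., x), ∂x∂ϵ2(p..., x), ∂x∂θ(p..., x))
numeric = ntuple(i -> fd_partial(p, i), 3)
```

Checking each partial separately makes it easier to locate a wrong formula than comparing only the combined derivative along one direction.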

Do you think we could write a macro to do that,
like @import_frule foo or something?


I haven’t found a clean way to add an AD-method for ForwardDiff.Duals, because in order to dispatch on the AD method I need a trait like the promoted type above.

Maybe some helper functions could make this easier and cleaner than a macro. Also, obtaining partials for non-duals is OK but would be wasteful for arrays and similar, where specializing on them not being duals would work better.


Yeah, I am pretty convinced one always wants to keep the dual on the outside:
a Dual of Arrays, not an Array of Duals.
That makes this easy, and also means you can use BLAS etc.
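Here is a minimal sketch of the "Dual of Arrays" idea, assuming a single tangent direction. ArrayDual and tanh_layer are hypothetical names, not ForwardDiff or ForwardDiff2 types. The value and the tangent are stored as two ordinary Float64 arrays, so a matrix product on the pair reduces to two ordinary matrix products.

```julia
using LinearAlgebra

# Hypothetical "dual of arrays": primal array and tangent (one directional
# derivative) as two dense arrays, instead of an array of dual numbers.
struct ArrayDual{A<:AbstractArray}
    value::A
    tangent::A
end

# A linear map acts on value and tangent separately; both products are
# plain Float64 matrix-vector products, which can use BLAS.
Base.:*(M::AbstractMatrix, x::ArrayDual) = ArrayDual(M * x.value, M * x.tangent)

# Elementwise nonlinearity with its derivative: tanh'(u) = 1 - tanh(u)^2.
function tanh_layer(x::ArrayDual)
    y = tanh.(x.value)
    ArrayDual(y, (1 .- y .^ 2) .* x.tangent)
end

M = [1.0 2.0; 3.0 4.0]
x = ArrayDual([0.1, 0.2], [1.0, 0.0])   # tangent seeds the first input direction
out = tanh_layer(M * x)                 # out.tangent = directional derivative
```

With an Array of Duals, M * x would instead be a generic matrix-vector product over Dual elements, which does not go through BLAS.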

ForwardDiff2 did that


If ForwardDiff2 is not looking like it's going to become a full replacement, is there anything else in the pipeline for forward-mode AD beyond ForwardDiff? In general I am a very happy user of ForwardDiff, but I sometimes have to do some hacky things that are less elegant than @Tamas_Papp's solution here. One pattern I've ended up using a few times is writing methods like

function myfunction(x::Dual{T,V,N}) where {T,V<:AbstractFloat,N}
    # manual first derivative of my function here...
end

which are definitely kind of scary because I don’t really know what I’m doing and have gotten partials wrong in non-obvious ways before. I should probably start using Tamas’ method here instead, but in general I’d be curious to hear about forward-mode stuff. A lot of my problems are, say, 10 dimensional, and so the extra overhead of reverse-mode makes it less appealing.
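To make the chain rule in such a hand-written method concrete, here is a sketch using a hypothetical toy dual type NDual{N} that carries N partials. It loosely mirrors the value-plus-partials layout of Dual{T,V,N`}, but the names and constructor are assumptions for illustration, not ForwardDiff's API. The key point is that every incoming partial gets multiplied by the same scalar derivative f′(value).

```julia
# Hypothetical toy dual with N partials (value + NTuple of partials);
# NOT ForwardDiff.Dual.
struct NDual{N}
    value::Float64
    partials::NTuple{N,Float64}
end

# Product rule, applied to each partial.
Base.:*(a::NDual{N}, b::NDual{N}) where {N} =
    NDual{N}(a.value * b.value, a.partials .* b.value .+ a.value .* b.partials)

# The function we want a hand-written derivative for.
mysin(x::Float64) = sin(x)

# Chain rule: scale ALL incoming partials by f'(value) = cos(value).
# Common mistakes are scaling only one partial, or returning f'(value)
# alone and dropping the incoming partials.
mysin(d::NDual{N}) where {N} = NDual{N}(mysin(d.value), cos(d.value) .* d.partials)

# Gradient of g(x1, x2) = sin(x1 * x2), seeding one partial per input.
g(x1, x2) = mysin(x1 * x2)
x1 = NDual{2}(0.5, (1.0, 0.0))
x2 = NDual{2}(2.0, (0.0, 1.0))
grad = g(x1, x2).partials
```

At (0.5, 2.0), grad equals (x2·cos(x1·x2), x1·cos(x1·x2)) = (2cos 1, 0.5cos 1).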

Note that my solution above makes ForwardDiff use the existing partials defined by ChainRulesCore.frule; in other words, you still need to get them right, but hopefully in one place only, without duplicating code.


Diffractor (@keno) has both forward and reverse mode.
Its forward mode uses ChainRulesCore.frule natively
(and its reverse mode uses rrule).
Diffractor has not been released yet, but I get the impression it is getting close.


Hi all,

I am looking for a "default procedure" to get ForwardDiff working together with already defined frules. The code example from @Tamas_Papp looks pretty good, but I need a more general way to solve this for array arguments. With array arguments, ForwardDiff.Dual becomes Vector{ForwardDiff.Dual{...}}, and T = promote_type(T1, T2, T3) results in T = Any if I mix scalar and vector arguments.

I have been playing around all day, but I can't figure out a way to solve the problem for arbitrary functions, meaning functions with mixed scalar and vector arguments.

Any suggestions would make my day 🙂
Thanks in advance!

PS: If there is a good tutorial on how to build custom AD rules for ForwardDiff for functions with multiple and/or vector arguments, a pointer to it would be welcome, too.


This is a rather late reply, hope it is still useful. For something similar to the code above to work, you need to extract the eltype of array arguments and promote those.
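Here is a sketch of that suggestion. The helper names (scalar_type, common_scalar_type) are hypothetical, and a toy stand-in type FakeDual replaces ForwardDiff.Dual. The idea is to take the element type for array arguments and the type itself for scalars, and then promote those. The promote_rule below is similar in spirit to the promotion rules ForwardDiff defines between Dual and ordinary reals.

```julia
# The problem: there is no promotion rule between a scalar type and an array
# type, so promote_type falls back to their common supertype.
promote_type(Float64, Vector{Float64})   # → Any

# Hypothetical helpers: element type for arrays, the type itself for scalars.
scalar_type(x::AbstractArray) = eltype(x)
scalar_type(x) = typeof(x)
common_scalar_type(args...) = promote_type(map(scalar_type, args)...)

# Hypothetical stand-in for a dual number type, with a promotion rule
# against Float64.
struct FakeDual <: Real
    value::Float64
    deriv::Float64
end
Base.promote_rule(::Type{FakeDual}, ::Type{Float64}) = FakeDual

# Mixed scalar and vector arguments now promote to the dual type.
T = common_scalar_type(1.2, [FakeDual(0.3, 1.0), FakeDual(0.2, 0.0)])
```

In the solve pattern from the MWE, replacing promote_type(T1, T2, T3) with such a helper would select the dual path even when some arguments are vectors of duals. The rule itself still has to handle the array partials.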


Thanks for the reply.
Yes, we (prototypically) solved it in a similar way (checking the array type and converting every element). As soon as I have more time for this again, I will try to build and post a generic solution.

Just stumbled upon this thread, and I seem to recall that @mohamed82008 had an answer in the works

Check NonconvexUtils.jl/runtests.jl on the main branch of JuliaNonconvex/NonconvexUtils.jl (GitHub).


This works excellently!
It should be added to ForwardDiff; there is already an open issue addressing this: "Automatic ChainRules compatibility", issue #579 in JuliaDiff/ForwardDiff.jl (GitHub).

Update: The macro from NonconvexUtils was moved to the lightweight package ThummeTo/ForwardDiffChainRules.jl (GitHub), which is now also registered.