[ANN] ForwardDiffPullbacks.jl, ForwardDiff-based ChainRulesCore pullbacks

I’m happy to announce ForwardDiffPullbacks.jl. It’s been around for a bit, but it wasn’t quite ready to announce until now (that is, until v0.2.1).

ForwardDiffPullbacks implements pullbacks compatible with ChainRulesCore that are calculated via ForwardDiff.

The package provides fwddiff. If wrapped around a function (i.e. fwddiff(f)), it will cause ChainRules pullbacks to be calculated using ForwardDiff (i.e. by evaluating the original function with ForwardDiff.Dual numbers, possibly multiple times). The pullback returned by ChainRulesCore.rrule will return a ChainRulesCore thunk for each argument of the function.

So Zygote.gradient(fwddiff(f), xs...) should yield the same result as Zygote.gradient(f, xs...), but will typically be substantially faster for a function that has a comparatively small number of arguments, especially if the function runs a deep calculation.

ForwardDiffPullbacks also comes with broadcasting support: fwddiff(f).(args...) will use ForwardDiff to differentiate each iteration in the broadcast separately.
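To make "differentiate each iteration separately" concrete, here is a minimal, Base-only sketch. It uses a hypothetical toy dual-number type `ElemDual` and a hypothetical helper `elementwise_derivative` (ForwardDiffPullbacks itself uses ForwardDiff.Dual); each broadcast element gets its own seeded dual number and its own forward pass, and no Jacobian across elements is ever formed.

```julia
# Hypothetical toy dual number: a value plus one derivative part.
# (ForwardDiff.Dual is far more general; this is only for illustration.)
struct ElemDual
    val::Float64
    der::Float64
end
Base.:*(a::ElemDual, b::ElemDual) = ElemDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(a::Real, b::ElemDual) = ElemDual(a * b.val, a * b.der)
Base.exp(a::ElemDual) = ElemDual(exp(a.val), exp(a.val) * a.der)

# Differentiate a scalar function separately for every element of xs: each
# element gets its own dual number, seeded with derivative 1, and one forward pass.
elementwise_derivative(fun, xs) = (x -> fun(ElemDual(x, 1.0)).der).(xs)

q(x) = 2 * x * exp(x)    # made-up example; dq/dx = 2 * exp(x) * (1 + x)

xs = [0.0, 1.0, 2.0]
ds = elementwise_derivative(q, xs)
```

Because every element is handled independently, the cost grows linearly with the number of elements and the forward passes carry only one partial each.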

Currently, ForwardDiffPullbacks supports functions whose arguments and result(s) are statically sized, like Real, Tuple, StaticArrays.StaticArray and (nested) NamedTuples and plain structs. Dynamic arrays are not really supported yet.

Here’s an example of a situation in which explicit mixed-mode differentiation via ForwardDiffPullbacks will significantly outperform Zygote-only differentiation. Given a very simple statistical model, a parameter vector X and some random data

using Distributions, ForwardDiffPullbacks, Zygote, BenchmarkTools

model(X) = Exponential.(X)
loglike(X, data) = sum(logpdf.(model(X), data))

model_fd(X) = fwddiff(Exponential).(X)
loglike_fd(X, data) = sum(fwddiff(logpdf).(model_fd(X), data))

X = 3.0 * rand(1000)
data = rand.(model(X))

the gradients of the log-likelihood implementations loglike (Zygote-only) and loglike_fd (uses ForwardDiffPullbacks for model and likelihood calculation) are (approximately) equal:

@assert Zygote.gradient(loglike, X, data)[1] ≈ Zygote.gradient(loglike_fd, X, data)[1]
@assert Zygote.gradient(loglike, X, data)[2] ≈ Zygote.gradient(loglike_fd, X, data)[2]

The performance of Zygote alone is quite horrible, though (in cases like this, not in general):

julia> @benchmark Zygote.gradient(loglike, $X, $data)
BenchmarkTools.Trial: 1213 samples with 1 evaluation.
 Range (min … max):  3.039 ms … 16.362 ms  ┊ GC (min … max): 0.00% … 68.64%
 Time  (median):     3.767 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.115 ms ±  1.408 ms  ┊ GC (mean ± σ):  4.27% ±  9.60%

  ▄███████▅▅▄▄▃▃▂▃▂▂▂▂▂▂▂▁▂▁▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▁▁▂▁▂▂▂▂ ▃
  3.04 ms        Histogram: frequency by time        12.3 ms <

 Memory estimate: 1.00 MiB, allocs estimate: 32612.

Zygote has Zygote.forwarddiff to force mixed-mode A/D, but it’s Zygote-specific and inconvenient to use with multi-argument functions. It also fails in this case (and would likely be slower here, as there’s no broadcast optimization for it):

julia> model_zf(X) = (x -> Zygote.forwarddiff(Exponential, x)).(X)
julia> loglike_zf(X, data) = broadcast(q -> Zygote.forwarddiff(args -> logpdf(args...), q), zip(model_zf(X), data))
julia> Zygote.gradient(loglike_zf, X, data)
ERROR: MethodError: no method matching extract(::Exponential{ForwardDiff.Dual{Nothing, Float64, 1}})

Explicit mixed-mode auto-differentiation via ForwardDiffPullbacks performs much better:

julia> @benchmark Zygote.gradient(loglike_fd, $X, $data)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  39.857 μs …   3.668 ms  ┊ GC (min … max): 0.00% … 97.09%
 Time  (median):     46.856 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   60.861 μs ± 117.596 μs  ┊ GC (mean ± σ):  8.08% ±  4.22%

  ▁▇█▇▆▄▅▅▂▂▄▄▃▄▄▃▃▃▁▁▁▁▁ ▁▁▁▁▁▁▁                              ▂
  ███████████████████████████████▇██▇▇█▆▇█▇▇▇▇▇▇▆▅▅▅▆▆▆▅▆▅▅▅▄▅ █
  39.9 μs       Histogram: log(frequency) by time       144 μs <

 Memory estimate: 103.48 KiB, allocs estimate: 23.

ForwardDiffPullbacks is not limited to Zygote: it should work with any Julia auto-diff framework that supports ChainRulesCore.


Just to satisfy my curiosity, can you more generally explain what this package does?

I understand AD and do use ForwardDiff.jl. I guess what I don’t understand is what “ChainRules pullbacks” are.

I’ll try … (AD-specialists, please excuse the lack of mathematical rigor):

To calculate the gradient of a function g(f(x)), reverse-mode AD needs to evaluate
v_i = dg/df_i and then v_i * df_i/dx_j, summed over i (chain rule). So the key element is evaluating vector-Jacobian products like v_i * df_i/dx_j.
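For concreteness, a small numeric sketch in plain Julia (no AD package; `f_inner`, `g_outer` and `jacobian_inner` are made up for illustration): the gradient of g(f(x)) is the vector-Jacobian product of v = dg/df with the Jacobian of f, cross-checked here against central finite differences.

```julia
# Made-up inner function f: R^2 -> R^2 and outer function g: R^2 -> R.
f_inner(x) = [x[1] * x[2], sin(x[1])]
g_outer(y) = sum(abs2, y)

# Hand-written Jacobian of f_inner: J[i, j] = df_i/dx_j.
jacobian_inner(x) = [x[2]       x[1];
                     cos(x[1])  0.0]

x0 = [0.5, 2.0]
v = 2 .* f_inner(x0)              # v_i = dg/df_i, since g(y) = sum(y.^2)
grad = jacobian_inner(x0)' * v    # VJP: grad_j = sum_i v_i * df_i/dx_j

# Cross-check against central finite differences of g_outer ∘ f_inner.
δ = 1e-6
basis = ([1.0, 0.0], [0.0, 1.0])
fd = [(g_outer(f_inner(x0 .+ δ .* e)) - g_outer(f_inner(x0 .- δ .* e))) / (2δ) for e in basis]
```

Note that the full Jacobian is only built here for illustration; a pullback computes `grad` directly from `v` without ever materializing J.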

Zygote and the upcoming Diffractor do this by “rewriting” the source code of f to generate vector-Jacobian-product functions (so-called pullbacks). It’s the function returned as vjp = Zygote.pullback(f, x)[2], so vjp(v) will evaluate v_i * df_i/dx_j, but in a generalized fashion, as the arguments and return values of functions don’t need to be scalars or vectors.

Zygote’s compiler can do amazing stuff, but it comes with an overhead that can be significant for “cheap” functions, especially if they perform many calculation steps. There are also cases that the compiler can’t handle, etc. So one can provide custom-written pullbacks (vector-Jacobian-product functions) by specializing ChainRulesCore.rrule for the given function. If rrule(f, args...) is defined, Zygote and similar frameworks will use it instead of trying to derive their own pullback.
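The shape of such a custom rule can be sketched without any AD package. With ChainRulesCore, one would add a method to `ChainRulesCore.rrule` and return `NoTangent()` for the function itself; the sketch below uses a hypothetical stand-in `my_rrule` and returns `nothing` instead, so it runs with Base Julia alone. `scale_sin` is a made-up example function.

```julia
scale_sin(a, b) = a * sin(b)

# Stand-in for ChainRulesCore.rrule: return the primal result together with a
# pullback closure that maps an output cotangent ȳ to one cotangent per input.
function my_rrule(::typeof(scale_sin), a, b)
    y = scale_sin(a, b)
    function scale_sin_pullback(ȳ)
        ∂self = nothing          # ChainRulesCore would use NoTangent() here
        ∂a = ȳ * sin(b)          # dy/da = sin(b)
        ∂b = ȳ * a * cos(b)      # dy/db = a * cos(b)
        return (∂self, ∂a, ∂b)
    end
    return y, scale_sin_pullback
end

y, back = my_rrule(scale_sin, 3.0, 0.5)
grads = back(1.0)
```

The pullback closes over the inputs (and could close over intermediate results) from the forward pass, which is what lets the backward pass avoid re-deriving them.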

The ChainRulesCore docs provide an in-depth explanation of how rrule works.

AFAIK, ChainRulesCore rrules are supported by the AD-frameworks Diffractor.jl, Nabla.jl, Yota.jl and Zygote.jl (ReverseDiff.jl has indirect support for rrule via ReverseDiff.@grad_from_chainrules).


Maybe I should add:

ForwardDiff uses forward-mode AD, which computes a Jacobian-vector product (JVP, pushforward) instead of a vector-Jacobian product (VJP, pullback). But to calculate (e.g.) a gradient, we need a VJP. To “emulate” a VJP, one needs to evaluate JVPs for each argument of the function, which seems costly (and can be). But forward-mode AD has very little overhead, so for functions with few arguments and deep calculations, this can still be a lot faster than a VJP generated by Zygote & friends.
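A minimal plain-Julia sketch of emulating a VJP with JVPs. The dual type `JDual`, the function `f_two` and the helpers `jvp`/`vjp_via_jvps` are hypothetical illustrations, not ForwardDiff or ForwardDiffPullbacks code. With a single derivative part per dual, this needs one forward pass per input dimension (ForwardDiff can carry several partials per pass, which reduces the number of passes).

```julia
# Hypothetical toy dual number carrying one directional derivative.
struct JDual
    val::Float64
    der::Float64
end
Base.:+(a::JDual, b::JDual) = JDual(a.val + b.val, a.der + b.der)
Base.:*(a::JDual, b::JDual) = JDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::JDual) = JDual(sin(a.val), cos(a.val) * a.der)

# Made-up function f: R^2 -> R^2, written generically so it accepts duals.
f_two(x1, x2) = (x1 * x2, sin(x1) + x2)

# JVP: push the tangent t through f in a single forward pass, giving J * t.
function jvp(fun, x::NTuple{2,Float64}, t::NTuple{2,Float64})
    out = fun(JDual(x[1], t[1]), JDual(x[2], t[2]))
    return map(o -> o.der, out)
end

# VJP emulated via JVPs: one forward pass per input dimension j yields
# column j of J, which is then contracted with v: sum_i v_i * df_i/dx_j.
function vjp_via_jvps(fun, x::NTuple{2,Float64}, v::NTuple{2,Float64})
    seeds = ((1.0, 0.0), (0.0, 1.0))
    return map(seeds) do t
        col = jvp(fun, x, t)
        sum(v .* col)
    end
end

x = (0.5, 2.0)
v = (1.0, 3.0)
vjp_result = vjp_via_jvps(f_two, x, v)
```

The number of forward passes scales with the number of inputs, not outputs, which is why this pays off for functions with few (statically sized) arguments.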

f_fd = ForwardDiffPullbacks.fwddiff(f) wraps f in a function f_fd that just forwards all arguments to f. So f_fd(args...) == f(args...). But f_fd comes with a specialization of rrule(f_fd, args...) which uses (potentially multiple) ForwardDiff passes to compute the required VJP.
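To illustrate the mechanism, here is a toy, Base-only sketch. `FwdWrapped`, `toy_rrule`, `WDual` and `mul_exp` are hypothetical names made up for this example, not ForwardDiffPullbacks’ API: the real package uses ForwardDiff.Dual, ChainRulesCore.rrule and ChainRulesCore thunks, and returns NoTangent() for the function itself where this sketch returns `nothing`. It only handles scalar Float64 arguments and a scalar result.

```julia
# Hypothetical toy dual number with one derivative part.
struct WDual
    val::Float64
    der::Float64
end
Base.:*(a::WDual, b::WDual) = WDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.exp(a::WDual) = WDual(exp(a.val), exp(a.val) * a.der)

# Hypothetical wrapper: calling it just forwards all arguments to the wrapped function.
struct FwdWrapped{F}
    f::F
end
(w::FwdWrapped)(args...) = w.f(args...)

# Hypothetical stand-in for an rrule on the wrapper. Each cotangent is returned
# as a zero-argument closure ("thunk"), so the forward-mode pass for argument i
# only runs if that cotangent is actually requested.
function toy_rrule(w::FwdWrapped, args::Float64...)
    y = w(args...)
    function pullback(ȳ)
        thunks = ntuple(length(args)) do i
            () -> begin
                # Seed only argument i with derivative 1, all others with 0.
                duals = ntuple(j -> WDual(args[j], j == i ? 1.0 : 0.0), length(args))
                ȳ * w.f(duals...).der
            end
        end
        return (nothing, thunks...)      # `nothing`: no tangent for the wrapper itself
    end
    return y, pullback
end

mul_exp(a, b) = a * exp(b)
y, back = toy_rrule(FwdWrapped(mul_exp), 2.0, 0.5)
cotangents = back(1.0)
∂a = cotangents[2]()   # runs the forward pass for argument 1
∂b = cotangents[3]()   # runs the forward pass for argument 2
```

Returning lazy thunks means that if a framework only needs the cotangent of some arguments, the forward passes for the others are never run.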


Thanks for the explanation. So this has some conceptual similarity to ReverseDiff.@forward then?

Yes, and to Zygote.forwarddiff. But in contrast to them, it’s not framework-specific. And code that defines functions which may need explicit mixed-mode AD doesn’t have to take on such heavy dependencies - ForwardDiffPullbacks is quite lightweight (on top of ForwardDiff, which is less lightweight, but still):

julia> @time using ForwardDiff
  1.101606 seconds (3.29 M allocations: 236.277 MiB, 6.13% gc time, 38.53% compilation time)

julia> @time using ForwardDiffPullbacks
  0.042586 seconds (108.04 k allocations: 6.513 MiB)

julia> @time using ReverseDiff
  3.463774 seconds (17.74 M allocations: 1.198 GiB)

julia> @time using Zygote
  2.930755 seconds (6.24 M allocations: 382.616 MiB, 10.61% gc time, 90.92% compilation time)

A few of us in the Julia world authored a paper a while back (arXiv 1810.08297, “Dynamic Automatic Differentiation of GPU Broadcast Kernels”) which includes some relevant background/applications w.r.t. this kind of technique, for anybody who might be interested :slight_smile:


Nice to see this registered!

I had not seen 1810.08297 somehow, thanks. As noted there, the gradient of broadcast has to decide whether to store the derivatives, or to call f again. Zygote at present chooses the former (for simple enough functions), and ForwardDiffPullbacks makes it easy to choose the latter instead:

julia> f2(x) = abs2(@show x);

julia> y, bk = Zygote.pullback(x -> sum(f2.(x)), [1,2,3]);  # exactly 3 evaluations, stores dual parts until backward pass
x = Dual{Nothing}(1,1)
x = Dual{Nothing}(2,1)
x = Dual{Nothing}(3,1)

julia> bk(1.0)
([2.0, 4.0, 6.0],)

julia> y2, bk2 = Zygote.pullback(x -> sum(fwddiff(f2).(x)), [1,2,3]);  # forward pass only
x = 1
x = 2
x = 3

julia> bk2(1.0)  # 3 more evaluations for backward pass:
x = Dual{ForwardDiff.Tag{Tuple{typeof(f2), Val{1}}, Int64}}(1,1)
x = Dual{ForwardDiff.Tag{Tuple{typeof(f2), Val{1}}, Int64}}(2,1)
x = Dual{ForwardDiff.Tag{Tuple{typeof(f2), Val{1}}, Int64}}(3,1)
([2.0, 4.0, 6.0],)
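The store-versus-recompute trade-off can be sketched in plain Julia. The two strategies below are toy stand-ins (hypothetical `RDual`, `pullback_store`, `pullback_recompute`), not Zygote’s or ForwardDiffPullbacks’ actual code; a counter records how often the function is evaluated in each pass, mirroring the @show output above.

```julia
# Hypothetical toy dual number, just enough for abs2.
struct RDual
    val::Float64
    der::Float64
end
Base.abs2(a::RDual) = RDual(abs2(a.val), 2 * a.val * a.der)

const CALLS = Ref(0)                  # counts evaluations of sq
sq(x) = (CALLS[] += 1; abs2(x))

# Strategy 1 ("store"): evaluate with duals in the forward pass and keep the
# derivative parts until the backward pass.
function pullback_store(fun, xs)
    ds = [fun(RDual(x, 1.0)) for x in xs]
    y = sum(d.val for d in ds)
    back(ȳ) = ȳ .* [d.der for d in ds]
    return y, back
end

# Strategy 2 ("recompute"): plain forward pass; the backward pass calls fun
# again with duals.
function pullback_recompute(fun, xs)
    y = sum(fun.(xs))
    back(ȳ) = ȳ .* [fun(RDual(x, 1.0)).der for x in xs]
    return y, back
end

xs = [1.0, 2.0, 3.0]
CALLS[] = 0
y1, back1 = pullback_store(sq, xs)
calls_fwd_store = CALLS[]
CALLS[] = 0
g1 = back1(1.0)
calls_bwd_store = CALLS[]
CALLS[] = 0
y2, back2 = pullback_recompute(sq, xs)
calls_fwd_recompute = CALLS[]
CALLS[] = 0
g2 = back2(1.0)
calls_bwd_recompute = CALLS[]
```

Storing costs memory proportional to the number of elements but avoids extra evaluations; recomputing keeps the forward pass cheap and memory-light at the price of a second round of (dual-number) evaluations.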

Maybe a remark for newcomers to the Julia AD-scene: Zygote has good heuristics and will often switch to forward mode, esp. for broadcasts, automatically. ForwardDiffPullbacks is meant to help with cases where those heuristics “fail” (like in the example above) - Zygote is great, but it can’t do magic, and those heuristics have to keep a balance and can’t be too aggressive in regard to forward mode. I hope that’s kind of an accurate description.


I’ve been thinking about using a single pass for pullback calls with a single function argument, or making this configurable (maybe something like fwddiff(f, SinglePass()) or so). Using thunks (current ForwardDiffPullbacks behavior) should be helpful esp. for multi-arg functions when we finally bring Zygote PR 966 home (help very welcome). Also, some functions (lots of distributions-related stuff) don’t support dual numbers for all their arguments.

Making this more manually configured might be a good move, possibly it should even be opt-in within Zygote.

Zygote’s use of ForwardDiff is occasionally surprising, and can bite you by ignoring closed-over parameters in the function. An example where it manages to figure this out (and revert to a slower all-Zygote path) is this:

julia> Zygote.gradient([1,2,3], 4) do xs, y
         f3 = x -> abs2(@show(x)/y)
         sum(fwddiff(f3).(xs))  # this cannot track gradient w.r.t. y
       end
x = 1
x = 2
x = 3
x = Dual{ForwardDiff.Tag{Tuple{var"#84#86"{Int64}, Val{1}}, Int64}}(1,1)
x = Dual{ForwardDiff.Tag{Tuple{var"#84#86"{Int64}, Val{1}}, Int64}}(2,1)
x = Dual{ForwardDiff.Tag{Tuple{var"#84#86"{Int64}, Val{1}}, Int64}}(3,1)
([0.125, 0.25, 0.375], nothing)

julia> Zygote.gradient([1,2,3], 4) do xs, y
         f3 = x -> abs2(@show(x)/y)
         sum(f3.(xs))  # reverts to slower generic broadcast, no Dual
       end
x = 1
x = 2
x = 3
([0.125, 0.25, 0.375], -0.4375)

Should go without saying - contributions to ForwardDiffPullbacks are very welcome, of course. I’d be happy to see this develop into a community package. I’m not really an AD-specialist and expert eyes will certainly spot quite a few things that can be done better/cleaner or in a more general fashion…

That’s a very good point, people should be aware of this: ForwardDiffPullbacks does exactly the same, the tangent of the function itself is always NoTangent() currently. I’m not sure if closure arguments can be easily/safely replaced by dual numbers in a generic fashion, or what it would take to do that.
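For one specific closure, replacing a captured variable by a dual number can be done by hand, as in this plain-Julia sketch (hypothetical `CDual` and `make_f3`, mirroring the f3 example above). It only works because we know exactly how the closure was built; doing this generically for arbitrary closures is the open question.

```julia
# Hypothetical toy dual number with the operations needed for abs2(x / y).
struct CDual
    val::Float64
    der::Float64
end
Base.:/(a::CDual, b::Real) = CDual(a.val / b, a.der / b)
Base.:/(a::Real, b::CDual) = CDual(a / b.val, -a * b.der / b.val^2)
Base.abs2(a::CDual) = CDual(abs2(a.val), 2 * a.val * a.der)

# The closure from the example: y is captured, not passed as an argument.
make_f3(y) = x -> abs2(x / y)

xs = [1.0, 2.0, 3.0]
y0 = 4.0

# Derivative w.r.t. each x, with y held constant (all that fwddiff(f3) sees).
dxs = [make_f3(y0)(CDual(x, 1.0)).der for x in xs]

# Derivative w.r.t. the closed-over y: rebuild the closure with y as a dual
# number seeded with 1, then sum the derivative parts over all elements.
f3_dual = make_f3(CDual(y0, 1.0))
dy = sum(f3_dual(x).der for x in xs)
```

These reproduce the all-Zygote result above, ([0.125, 0.25, 0.375], -0.4375), including the y-gradient that the fwddiff version drops.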

Thanks for the link, jrevels/MixedModeBroadcastAD.jl is quite interesting as well! Do you think the techniques described in the paper are still relevant and could be extended to use cases like differentiating fused GPU broadcasts?

This video was very helpful to me for understanding Automatic Differentiation: