I’m happy to announce ForwardDiffPullbacks.jl. It’s been around for a bit, but wasn’t quite ready to announce until now (v0.2.1, that is).

ForwardDiffPullbacks implements pullbacks compatible with ChainRulesCore that are calculated via ForwardDiff.

The package provides `fwddiff`. If wrapped around a function (i.e. `fwddiff(f)`), it will cause ChainRules pullbacks to be calculated using ForwardDiff (i.e. by evaluating the original function with `ForwardDiff.Dual` numbers, possibly multiple times). The pullback (`ChainRulesCore.rrule`) will return a ChainRulesCore thunk for each argument of the function.

So `Zygote.gradient(fwddiff(f), xs...)` should yield the same result as `Zygote.gradient(f, xs...)`, but will typically be substantially faster for a function that has a comparatively small number of arguments, especially if the function runs a deep calculation.
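For example (a minimal sketch with an arbitrary toy function, not from the package docs):

```
using ForwardDiffPullbacks, Zygote

f(x, y) = x^2 + sin(y)

# Both calls should yield (4.0, cos(3.0)); the fwddiff variant evaluates
# the pullback via ForwardDiff.Dual numbers instead of Zygote's reverse mode.
Zygote.gradient(f, 2.0, 3.0)
Zygote.gradient(fwddiff(f), 2.0, 3.0)
```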

ForwardDiffPullbacks comes with broadcasting support: `fwddiff(f).(args...)` will use ForwardDiff to differentiate each iteration of the broadcast separately.
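A toy sketch of what that looks like (the function and data here are made up for illustration):

```
using ForwardDiffPullbacks, Zygote

xs = rand(100)

# Each broadcast iteration is differentiated separately with ForwardDiff;
# the result should match the plain Zygote gradient, i.e. cos.(xs).
Zygote.gradient(x -> sum(fwddiff(sin).(x)), xs)
```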

Currently, ForwardDiffPullbacks supports functions whose arguments and result(s) are statically sized, like `Real`, `Tuple`, `StaticArrays.StaticArray`, and (nested) `NamedTuple`s and plain structs. Dynamic arrays are not really supported yet.
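As an illustration, a hedged sketch with a `StaticArrays.SVector` argument (the function `g` is made up):

```
using ForwardDiffPullbacks, StaticArrays, Zygote

g(v) = sum(abs2, v)  # statically sized argument, scalar result

# Gradient with respect to the SVector argument, expected to be 2v:
Zygote.gradient(fwddiff(g), SVector(1.0, 2.0, 3.0))
```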

Here’s an example of a situation in which explicit mixed-mode differentiation via ForwardDiffPullbacks will significantly outperform Zygote-only differentiation. Given a very simple statistical model, a parameter vector `X`, and some random `data`:

```
using Distributions, ForwardDiffPullbacks, Zygote, BenchmarkTools
model(X) = Exponential.(X)
loglike(X, data) = sum(logpdf.(model(X), data))
model_fd(X) = fwddiff(Exponential).(X)
loglike_fd(X, data) = sum(fwddiff(logpdf).(model_fd(X), data))
X = 3.0 * rand(1000)
data = rand.(model(X))
```

the gradients of the log-likelihood implementations `loglike` (Zygote-only) and `loglike_fd` (which uses ForwardDiffPullbacks for the model and likelihood calculation) are (approximately) equal:

```
@assert Zygote.gradient(loglike, X, data)[1] ≈ Zygote.gradient(loglike_fd, X, data)[1]
@assert Zygote.gradient(loglike, X, data)[2] ≈ Zygote.gradient(loglike_fd, X, data)[2]
```

The performance of Zygote alone is quite horrible, though (in cases like this, not in general):

```
julia> @benchmark Zygote.gradient(loglike, $X, $data)
BenchmarkTools.Trial: 1213 samples with 1 evaluation.
Range (min … max): 3.039 ms … 16.362 ms ┊ GC (min … max): 0.00% … 68.64%
Time (median): 3.767 ms ┊ GC (median): 0.00%
Time (mean ± σ): 4.115 ms ± 1.408 ms ┊ GC (mean ± σ): 4.27% ± 9.60%
▁█▇▄▂
▄███████▅▅▄▄▃▃▂▃▂▂▂▂▂▂▂▁▂▁▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▁▁▂▁▂▂▂▂ ▃
3.04 ms Histogram: frequency by time 12.3 ms <
Memory estimate: 1.00 MiB, allocs estimate: 32612.
```

Zygote has `Zygote.forwarddiff` to force mixed-mode A/D, but it’s Zygote-specific and inconvenient to use with multi-argument functions. It also fails in this case (and would likely be slower here, as there’s no broadcast optimization for it):

```
julia> model_zf(X) = (x -> Zygote.forwarddiff(Exponential, x)).(X)
julia> loglike_zf(X, data) = broadcast(q -> Zygote.forwarddiff(args -> logpdf(args...), q), zip(model_zf(X), data))
julia> Zygote.gradient(loglike_zf, X, data)
ERROR: MethodError: no method matching extract(::Exponential{ForwardDiff.Dual{Nothing, Float64, 1}})
```

Explicit mixed-mode auto-differentiation via ForwardDiffPullbacks performs much better:

```
julia> @benchmark Zygote.gradient(loglike_fd, $X, $data)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 39.857 μs … 3.668 ms ┊ GC (min … max): 0.00% … 97.09%
Time (median): 46.856 μs ┊ GC (median): 0.00%
Time (mean ± σ): 60.861 μs ± 117.596 μs ┊ GC (mean ± σ): 8.08% ± 4.22%
▁▇█▇▆▄▅▅▂▂▄▄▃▄▄▃▃▃▁▁▁▁▁ ▁▁▁▁▁▁▁ ▂
███████████████████████████████▇██▇▇█▆▇█▇▇▇▇▇▇▆▅▅▅▆▆▆▅▆▅▅▅▄▅ █
39.9 μs Histogram: log(frequency) by time 144 μs <
Memory estimate: 103.48 KiB, allocs estimate: 23.
```

ForwardDiffPullbacks is not limited to Zygote; it should work with any Julia auto-diff framework that supports ChainRulesCore.
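For instance, the rule can be exercised directly at the ChainRulesCore level, without committing to any particular framework (a minimal sketch based on the rrule behavior described above):

```
using ForwardDiffPullbacks, ChainRulesCore

f(x, y) = x * y
y, pullback = ChainRulesCore.rrule(fwddiff(f), 2.0, 3.0)
_, dx, dy = pullback(1.0)

# The tangents come back as thunks, one per argument:
ChainRulesCore.unthunk(dx)  # ≈ 3.0
ChainRulesCore.unthunk(dy)  # ≈ 2.0
```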