I’m going to chime in and offer a different perspective on this. I’m the Enzyme lead, so obvious biases, and I’ll use Enzyme for some examples since I have more experience with it.
I think the overall workflow is similar to what you describe, with some key differences:
User Code/Applications → Derivative Rules → AD Framework → AD-based algorithms/operators
- First we have user code/applications: the code you want to apply some derivative-based algorithm to. Perhaps a neural network model, or a scientific code. Depending on the restrictions of the downstream tools, your code will need to have certain properties. For example, some code isn’t legal to write in Julia due to the Julia language / compiler (e.g. unsafe GC operations, Python syntax, etc). More commonly, you’ll see restrictions from the other parts of the workflow. If Zygote is your AD framework, you can’t use mutation. Similarly, if ChainRules is a rule provider, you cannot change the values of any arrays between the forward pass and the pullback, even if the tool supports mutation [more on this later]! And similarly, perhaps your AD-based algorithm assumes all values are positive or real, so your code needs to have this property as well.
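As a minimal sketch of the kind of framework-imposed restriction I mean (my own example, not from anyone’s docs; the function name is made up and the error text abbreviated): Zygote’s reverse mode rejects array mutation outright, while a mutation-aware tool can handle the same function.

using Zygote

function inplace_normalize!(x)
    x ./= sum(x)          # in-place mutation of the input array
    return sum(abs2, x)
end

# Zygote.gradient(inplace_normalize!, [1.0, 2.0, 3.0])
# ERROR: Mutating arrays is not supported ...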
- Derivative rules. The most widely known set in Julia is ChainRules. It’s a great project by some great people, but it has its limitations. Many AD tools, like Enzyme, Zygote, etc., have their own rule systems as well, which better adapt to the needs of the AD tool. This allows one to define more efficient operations and guarantee that the rules conform with the language semantics. Many tools (including Enzyme) support the ability to import rule definitions from ChainRules, instead of using their own internal definitions or, in Enzyme’s case, its EnzymeRules interface (see Enzyme.jl/src/Enzyme.jl at a21e60ddbdab383f7caba147514afc95a8bdb150 · EnzymeAD/Enzyme.jl · GitHub).
So why don’t we just use ChainRules for everything?
Part of it is mutation, like you say, but it goes somewhat deeper. Enzyme also supports avoiding unnecessary differentiation when a value is not requested by the user or can otherwise be proven to be zero (activity). By default, many chain rules will densely compute all values, making code take significantly longer.
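As a rough sketch of what activity annotations look like on Enzyme’s side (my example, not from the docs): marking an argument Const tells Enzyme not to compute or store a derivative for it at all.

using Enzyme

g(x, y) = sum(x .* y)

x  = [1.0, 2.0, 3.0]
dx = zero(x)                 # shadow storage for the gradient with respect to x
y  = [4.0, 5.0, 6.0]         # marked constant below: no derivative work is done for it

Enzyme.autodiff(Reverse, g, Active, Duplicated(x, dx), Const(y))
@show dx                     # == y, the gradient with respect to x only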
Moreover, use of ChainRules within Enzyme (or another mutation-aware AD tool) may often result in incorrect code, even if the function the rule is defined on is mutation free.
For example, consider the following function and custom rule (written from memory rather than run, so treat it as illustrative).
using ChainRules

sumsq(x) = sum(x .* x)

function ChainRules.rrule(::typeof(sumsq), x)
    Y = sumsq(x)
    function sumsq_pullback(Ȳ)
        return ChainRules.NoTangent(), 2 .* x .* Ȳ
    end
    return Y, sumsq_pullback
end

_, pb = ChainRules.rrule(sumsq, [2.0, 3.0])
@show pb(1.0)
# (ChainRulesCore.NoTangent(), [4.0, 6.0])
We compute the sum of squares, so naturally the derivative is 2 * x * dout, which is what our pullback returns. What happens, however, if we mutate x before running the pullback? For example:
function sumsqzero(x)
    res = sumsq(x)
    x .= 0
    return res
end
If we were to use ChainRules in a reverse-mode AD tool (it doesn’t matter which), then, assuming it doesn’t crash, you’ll get something similar to the following:
function sumsqzero(x, dout)
    res, sumsq_pb = ChainRules.rrule(sumsq, x)
    x .= 0
    grad_set_x_zero() # defined by the AD tool
    return sumsq_pb(dout)
end
The value of x from our original program has been overwritten, so we’ll be multiplying by zero instead. Since a lot of ChainRules-based ADs assume nothing can mutate (including outside of the chain rule itself), this isn’t a problem for them. But it does mean that most of the chain rules out there are actually wrong in the presence of mutation. A fix for this case would be defining a rule like the following:
function ChainRules.rrule(::typeof(sumsq), x)
    Y = sumsq(x)
    copy_x = copy(x)
    function sumsq_pullback(Ȳ)
        return ChainRules.NoTangent(), 2 .* copy_x .* Ȳ
    end
    return Y, sumsq_pullback
end
Now even if x is overwritten by subsequent code, we’ll get the correct result!
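To make that concrete, here is a quick check (mine, not from the original post) that the cached-copy version really is robust to the mutation:

x = [2.0, 3.0]
_, pb = ChainRules.rrule(sumsq, x)
x .= 0                    # overwrite the input after building the pullback
@show pb(1.0)             # still (NoTangent(), [4.0, 6.0]) thanks to copy_x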
Of course, if many of the rules defined by a system would cause bugs for a tool because they were written under a different set of assumptions, it doesn’t make sense to import the wrong things, especially if the rule system also lacks other information like activity (hence Enzyme preferring its own internal rules).
Of course we want to keep this gotcha away from users, so if you read the Enzyme docs for import_rrule you’ll see more information on the assumptions that let us avoid this in some cases and make auto-importing easy.
import_rrule(::fn, tys...)
Automatically import a ChainRules.rrule as a custom reverse mode EnzymeRule. When called in batch mode, this
will end up calling the primal multiple times which results in slower code. This macro assumes that the underlying
function to be imported is read-only, and returns a Duplicated or Const object. This macro also assumes that the
inputs permit a .+= operation and that the output has a valid Enzyme.make_zero function defined. It also assumes
that overwritten(x) accurately describes if there is any non-preserved data from forward to reverse, not just
the outermost data structure being overwritten as provided by the specification.
Finally, this macro falls back to almost always caching all of the inputs, even if it may not be needed for the
derivative computation.
As a result, this auto importer is also likely to be slower than writing your own rule, and may also be slower
than not having a rule at all.
Use with caution.
Enzyme.@import_rrule(typeof(Base.sort), Any);
- AD Frameworks (like Zygote, Enzyme, ReverseDiff, ForwardDiff, etc). Like you say, these have been touched on elsewhere so I won’t discuss them much more. However, as a slightly biased update, we just finished adding significant improvements to handle most (though not all) type-unstable code, BLAS including some cuBLAS (but not LAPACK), nicer error messages, and more. There is still a lot to go, but if you look at my GitHub history of multiple commits a day, including weekends, you’ll see that we’re slowly but steadily making things easier to use and faster to run.
These tend to provide low-level utilities (for example Enzyme.autodiff, which handles arbitrary argument counts, differentiating with respect to some but not other variables, etc.), as well as high-level utilities (like Enzyme.gradient and Enzyme.jacobian, or Zygote.gradient / ForwardDiff.gradient).
I’m also going to lump the interface packages, like AbstractDifferentiation.jl and DifferentiationInterface.jl, in here. The AD packages all have slightly different conventions for how to call their high- and low-level APIs, so these packages intend to make it easy to try out one tool versus another.
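To give a flavor (my sketch, not from either package’s docs, and these APIs move quickly, so double-check the current signatures): with DifferentiationInterface the backend becomes a value you pass in, which is what makes swapping tools easy.

using DifferentiationInterface
using ADTypes: AutoForwardDiff, AutoZygote   # backend descriptor types
import ForwardDiff, Zygote                   # the actual backends

sq(x) = sum(abs2, x)
x0 = [1.0, 2.0, 3.0]

@show DifferentiationInterface.gradient(sq, AutoForwardDiff(), x0)  # forward mode
@show DifferentiationInterface.gradient(sq, AutoZygote(), x0)       # reverse mode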
However, I will caution that this is not free. For example, every year there are numerous papers that try to improve “abstraction without regret” (see Google Scholar), including some of my own, for example on automatically improving tensor arithmetic (https://arxiv.org/pdf/1802.04730).
Fundamentally, any interface (be it for differentiation or something else) needs to decide what type of API to provide to users. A higher-level API can be easier to use (but may miss optimizations and thus result in slower code), while lower-level utilities give users both the power and the burden of more effective usage. This tradeoff is actually why most AD tools provide both (e.g. pushforward/pullback and gradient for ForwardDiff/ReverseDiff-like tools, and autodiff / gradient / jacobian for Enzyme-like ones).
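Here is a small sketch (mine) of that high-level versus low-level split using Zygote: the one-shot gradient call versus the pullback, which also returns the primal value and lets you choose the seed yourself.

using Zygote

h(x) = sum(abs2, x)
x0 = [1.0, 2.0, 3.0]

g_hi = Zygote.gradient(h, x0)[1]     # high level: just give me the gradient

y, back = Zygote.pullback(h, x0)     # low level: primal value plus a pullback closure
g_lo = back(1.0)[1]                  # seed the pullback explicitly

@show g_hi == g_lo                   # true; both are 2 .* x0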
The interface packages have a definite purpose in exploring AD tools as well as making them easier to use, but I will offer a word of caution before blindly applying them. As an example, whereas in the past AbstractDifferentiation.jl was quite successful in exploring which AD tools one can use, it wasn’t adopted in some large applications due to issues with its Tuple handling. I can’t find the link offhand, but I remember a discussion where a big application wanted to switch to AbstractDifferentiation.jl, but compared with using Zygote.jl directly, its tuple handling caused significant slowdowns, making machine learning (or whatever the domain was) unviable.

DifferentiationInterface is a great package which is trying to lift some of the AbstractDifferentiation.jl limitations, but it similarly comes with its own set of limitations. For example, it (presently) assumes that inputs are scalars/arrays (making things difficult for direct use in ML models), and does not have support for functions with multiple arguments, or for enabling/disabling differentiation with respect to some variables. All of these are serious (but potentially resolvable) limitations. For example, some code which could be efficiently differentiated with a tool directly might be significantly slower when using DI, or possibly not be differentiable at all! The same goes for AbstractDifferentiation.jl. It’s a bit out of scope here, but Switch to DifferentiationInterface by gdalle · Pull Request #29 · tpapp/LogDensityProblemsAD.jl · GitHub describes some issues with the creation of closures (which may be created by some of the interface packages); a sketch of that pattern follows below.
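Roughly, the single-argument restriction forces a closure over any extra arguments (my sketch with made-up names; the loss and its arguments are hypothetical), which is exactly the pattern the linked PR discusses.

using DifferentiationInterface
using ADTypes: AutoZygote
import Zygote

loss(θ, data) = sum(abs2, θ .- data)    # a two-argument function

data = [1.0, 2.0, 3.0]
θ0 = zeros(3)

# The single-argument interface forces us to close over `data`; this closure is
# rebuilt (and possibly re-specialized) every time `data` changes.
@show DifferentiationInterface.gradient(θ -> loss(θ, data), AutoZygote(), θ0)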
To be clear, both Guillaume and Adrian on the DI side are working diligently to identify which of these limitations can be lifted, but I think it’s critical that we don’t put the cart before the horse here and avoid a situation like what happened with Diffractor, in which it was expected to solve everybody’s problems (and then couldn’t fulfill those goals). In that regard, one should definitely follow both DifferentiationInterface.jl and AbstractDifferentiation.jl closely, but ultimately select the tool that matches their use case best (which can be changed as things develop). For example, core SciML libraries like SciMLSensitivity and Optimization.jl, as well as ML libs, are explicitly not using DI.jl yet because it doesn’t have support for multiple arguments / structured data, and it doesn’t make sense to risk a lack of support for user code or a drop in performance as a result. Other users, like ContinuousNormalizingFlows.jl, just take a gradient of a single-argument function which takes an array, which matches the DI abstraction perfectly!
The same can be said of Enzyme.jl as well. It supports a specific set of code whose features have been growing over time. Any time it comes up I try to caution what is expected to be supported, and as it has gained more feature support, it has been growing in adoption in places that make sense (for example, with recent BLAS and other support we’ve now been added as an AD backend within the ML libraries Flux/Lux). At its start it only had support for GPUCompiler-compatible code that looks like a kernel: definitely nothing with GC-able objects, type instabilities, or more. It has grown significantly over time, adding support for parallelism, forward mode, batching, GC, type instabilities, BLAS, and most recently cuBLAS.
- Derivative-based algorithms/optimizations. This is your standard backprop or Bayesian inference, etc. Plenty has been said about this elsewhere so I’ll leave it for now, but these similarly have their own set of limitations and expectations (often along the lines that it is legal to re-run the code multiple times, perhaps preventing someone from using a function which reads from a file or increments a counter).
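As a tiny illustration (mine) of that re-run assumption: a function like the one below is not safe to evaluate twice, so any algorithm that re-evaluates it (a line search, a finite-difference check, multiple AD passes) silently changes what it computes.

const ncalls = Ref(0)

function impure_loss(x)
    ncalls[] += 1               # external state mutated on every call
    return sum(abs2, x) * ncalls[]
end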
As a side note, @stevengj and @RS-Coop, getting back to your original question of an HVP: since this seems sufficiently useful for folks, I went ahead and just tagged a release of Enzyme with a helper hvp function (as well as an in-place hvp! and even hvp_and_gradient!).
using Enzyme
using LinearAlgebra
f(x) = norm(x)^3 * x[end] + x[1]^2 # example ℝⁿ→ℝ function
x = collect(Float64, 1:5)
v = collect(Float64, 6:10)
@show Enzyme.hvp(f, x, v)
# 5-element Vector{Float64}:
# 1164.881764812144
# 1749.5486430921133
# 2346.2155213720825
# 2942.882399652052
# 6431.86668789933
using Zygote, ForwardDiff
function Hₓ(x, v)
    ∇f(y) = Zygote.gradient(f, y)[1]
    return ForwardDiff.derivative(α -> ∇f(x + α*v), 0)
end
@show Hₓ(x, v)
# 5-element Vector{Float64}:
# 1164.881764812144
# 1749.5486430921133
# 2346.2155213720825
# 2942.8823996520514
# 6431.86668789933
If you look at the definition, it’s pretty simple and similar to @stevengj’s one from above (with some minor changes to try to get better performance). If you want a version which supports more arguments, or has a constant input, you’d just pass those through to the corresponding autodiff calls.
@inline function gradient_deferred!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F}
    make_zero!(dx)
    # Reverse pass: accumulate the gradient of f at x into dx
    autodiff_deferred(Reverse, f, Active, Duplicated(x, dx))
    dx
end

@inline function hvp!(res::X, f::F, x::X, v::X) where {F, X}
    grad = make_zero(x)
    # Forward-over-reverse: push the direction v through the gradient computation,
    # writing the Hessian-vector product into res
    Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v))
    return nothing
end

@inline function hvp(f::F, x::X, v::X) where {F, X}
    res = make_zero(x)
    hvp!(res, f, x, v)
    return res
end