ANN: XDiff.jl - an expression differentiation package

Kind of. ReverseDiff is an operator-overloading implementation of AD, so it intercepts function calls, not expressions. It can, in fact, intercept any broadcast call and differentiate it as a single primitive by differentiating the broadcast kernel with forward mode (an example of where mixed-mode AD really shines). Here, fusion is actually a performance benefit, because it reduces the total number of primitives recorded in the graph.
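Roughly, the "broadcast as one primitive" idea looks like the toy sketch below. It is written in Python purely for illustration and does not use any ReverseDiff/ForwardDiff code; `Dual`, `broadcast_primitive`, and `kernel` are hypothetical names. The forward pass runs the scalar kernel on dual numbers to get the outputs plus every elementwise partial, and the reverse pass is just an elementwise product, since each output depends only on the same index of each input. For simplicity it seeds one argument per forward pass; a real implementation would more likely carry all partials in a single pass.

```python
class Dual:
    """Forward-mode dual number val + eps*ε, with ε² = 0 (toy, untagged)."""
    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    def _lift(self, other):
        # Treat plain numbers as constants (zero perturbation).
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = self._lift(other)
        return Dual(self.val + other.val, self.eps + other.eps)
    __radd__ = __add__

    def __mul__(self, other):
        other = self._lift(other)
        return Dual(self.val * other.val, self.val * other.eps + self.eps * other.val)
    __rmul__ = __mul__


def broadcast_primitive(kernel, *args):
    """Record `kernel.(args...)` as ONE reverse-mode node (hypothetical helper).

    Forward pass: evaluate the scalar kernel on dual numbers, seeding one
    argument at a time, to get the outputs and all elementwise partials.
    Returns (outputs, pullback)."""
    n = len(args[0])
    outputs = []
    partials = [[0.0] * n for _ in args]  # partials[k][i] = d kernel / d args[k] at index i
    for i in range(n):
        for k in range(len(args)):
            duals = [Dual(a[i], 1.0 if j == k else 0.0) for j, a in enumerate(args)]
            out = kernel(*duals)
            partials[k][i] = out.eps
        outputs.append(out.val)  # the value part is the same for every seeding

    def pullback(out_bar):
        # The Jacobian w.r.t. each argument is diagonal, so the
        # vector-Jacobian product is an elementwise multiply.
        return [[ob * p for ob, p in zip(out_bar, col)] for col in partials]

    return outputs, pullback


def kernel(x, y):
    # Hypothetical scalar kernel: f(x, y) = x*x*y + 3*x
    return x * x * y + 3 * x


vals, back = broadcast_primitive(kernel, [1.0, 2.0], [3.0, 4.0])
grads = back([1.0, 1.0])  # seed for loss = sum(vals)
print(vals)   # [6.0, 22.0]
print(grads)  # [[9.0, 19.0], [1.0, 4.0]]  (∂f/∂x = 2xy + 3, ∂f/∂y = x²)
```

The whole broadcast shows up as a single node with one pullback, no matter how many scalar operations the kernel contains, which is where the savings from fusion come from.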

Unfortunately, applying this technique automatically can be unsafe because of potential perturbation confusion. Thus, ReverseDiff uses overly strict dispatch to ensure that the technique is only applied to “known” kernels (e.g. *, +, ^, etc.) or to kernels the user has annotated with @forward (meaning “I’ll allow ReverseDiff to use forward-mode AD for this function”).
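To make "perturbation confusion" concrete, here is the classic failure case, sketched in Python with a deliberately naive, untagged dual number (a hypothetical toy, not ForwardDiff's type). The function is d/dx [x · d/dy (x + y)] at x = 1. The inner derivative is 1, so the correct answer is 1, but because both derivative calls share the same ε, the inner call also picks up the outer perturbation of x:

```python
class Dual:
    """Untagged dual number val + eps*ε (ε² = 0). Every Dual shares one ε."""
    def __init__(self, val, eps):
        self.val, self.eps = val, eps

    def _lift(self, other):
        return other if isinstance(other, Dual) else Dual(other, 0.0)

    def __add__(self, other):
        other = self._lift(other)
        return Dual(self.val + other.val, self.eps + other.eps)
    __radd__ = __add__

    def __mul__(self, other):
        other = self._lift(other)
        return Dual(self.val * other.val, self.val * other.eps + self.eps * other.val)
    __rmul__ = __mul__


def derivative(f, x):
    # Seed x with a perturbation and read off the ε coefficient.
    return f(Dual(x, 1.0)).eps


# d/dx [ x * d/dy (x + y) ] at x = 1; the exact answer is 1.
confused = derivative(lambda x: x * derivative(lambda y: x + y, 1.0), 1.0)
print(confused)  # 2.0: the inner call mixed x's perturbation into its result
```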

Now that ForwardDiff has a tagging system (i.e. it provides confusion-safe dual numbers), however, we can loosen this restriction and apply the technique automatically to any broadcast call! In general, this should be more efficient than unfusing the calls and differentiating each one separately.
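A minimal sketch of how tagging avoids that, again in Python and assuming plain integer tags (a hypothetical simplification; ForwardDiff encodes its tags in the dual number's type, so this is only an analogy). Each derivative call gets a fresh tag, duals with different tags nest instead of mixing, and a call only reads off the ε coefficient carrying its own tag:

```python
import itertools

_tags = itertools.count(1)  # fresh, increasing tag for each derivative call


class Dual:
    """Tagged dual number; val/eps may themselves be Duals with smaller tags."""
    def __init__(self, tag, val, eps):
        self.tag, self.val, self.eps = tag, val, eps

    def __add__(self, other):
        t, a, b = _align(self, other)
        return Dual(t, a.val + b.val, a.eps + b.eps)
    __radd__ = __add__

    def __mul__(self, other):
        t, a, b = _align(self, other)
        return Dual(t, a.val * b.val, a.val * b.eps + a.eps * b.val)
    __rmul__ = __mul__


def _align(a, b):
    """Make both operands Duals of the newest (largest) tag involved.

    An operand without that tag is constant w.r.t. that perturbation, so it
    is wrapped with eps = 0 and its own perturbations stay nested inside."""
    ta = a.tag if isinstance(a, Dual) else 0
    tb = b.tag if isinstance(b, Dual) else 0
    t = max(ta, tb)
    if ta != t:
        a = Dual(t, a, 0.0)
    if tb != t:
        b = Dual(t, b, 0.0)
    return t, a, b


def derivative(f, x):
    tag = next(_tags)
    out = f(Dual(tag, x, 1.0))
    # Only the perturbation this call introduced counts.
    return out.eps if isinstance(out, Dual) and out.tag == tag else 0.0


fixed = derivative(lambda x: x * derivative(lambda y: x + y, 1.0), 1.0)
print(fixed)  # 1.0
```

Because the newer tag always ends up outermost, the inner call sees the outer x as a constant with respect to its own perturbation, and nested derivatives (e.g. the second derivative of z³ at 2, which is 12) also come out right.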
