As many of the folks here know, there’s a lot of ambitious work being done on various Julia AD tools these days. A key component of each AD tool’s implementation is the mechanism or framework it utilizes for defining, querying, and executing “differentiation rules” (also referred to as “primitives”).
For a long while now, the DiffRules package (originally derived from the symbolic differentiation rules within Calculus.jl) has served as a common dependency for this purpose. DiffRules is extremely limited in scope: it only supports scalar real-to-real derivative rules on expressions. At this point, it’s apparent that the Julia AD world would benefit from a common rule framework that supports far more:
- first-class complex differentiation
- custom perturbation/sensitivity propagation
- linear algebraic and general array primitives
- mixed-mode composability of rule definitions
- function-based (as opposed to expression-based) rule specification
- decoupling rule specification and input/output value specialization
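To make the expression-based vs. function-based distinction concrete, here is a minimal sketch. It uses Python purely for illustration, so it is neither Julia nor ChainRules' API, and the names `RULES`, `register`, and `forward` are hypothetical. Each rule is an ordinary function that returns the primal value together with a pushforward closure. Because a rule is code rather than a standalone derivative expression of x, it can reuse intermediate results, such as the primal value of `exp`.

```python
import math

# Hypothetical rule registry (not ChainRules' API): maps a primal function
# to a rule, i.e. a plain function of the primal input x that returns
# (y, pushforward), where pushforward maps an input perturbation dx to dy.
RULES = {}


def register(f):
    """Decorator that records the decorated function as the rule for f."""
    def wrap(rule):
        RULES[f] = rule
        return rule
    return wrap


@register(math.sin)
def _sin_rule(x):
    return math.sin(x), lambda dx: math.cos(x) * dx


@register(math.exp)
def _exp_rule(x):
    y = math.exp(x)
    # Because the rule is ordinary code, it can reuse the primal result:
    # d/dx exp(x) = exp(x) = y.
    return y, lambda dx: y * dx


def forward(fs, x, dx=1.0):
    """Forward-mode derivative of fs[-1](...fs[0](x)) by chaining rules."""
    for f in fs:
        x, pf = RULES[f](x)
        dx = pf(dx)
    return x, dx
```

For example, `forward([math.sin, math.exp], 0.5)` returns `exp(sin(0.5))` together with its derivative `exp(sin(0.5)) * cos(0.5)`.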
While pursuing my own work on Capstan, I’ve cooked up an initial design and implementation for such a package: ChainRules.
This package is definitely a WIP, but I figured it’d be good to get eyes on it early. The framework/design is essentially there, but there are only a few toy rules right now, a bunch of TODOs, virtually no tests, etc. PRs welcome! Documentation is incoming, which should help if you’d like to contribute.
The package’s design is heavily inspired by various conversations with (and work done by) Will Tebbutt, @ssfrr, @MikeInnes, @denizyuret, and Ekin Akyürek; my hope is that with some elbow grease we can make ChainRules useful to all these folks!
I got confused by the use of \otimes: if I understand correctly, it denotes the Cartesian product here? That conflicts with the standard notation, where \otimes means the tensor product.
A standard notation is K to mean either R or C. Maybe it’s useful here, so that most rules only have to be written once for both the real and complex cases (in the holomorphic case, the rules are the same for R and C, with an extra conjugate in adjoint mode).
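A small sketch of the point about the extra conjugate, written in Python for illustration only (the function names are made up). For a holomorphic f, one derivative f'(z) drives both modes. The pullback applies conj(f'(z)), which makes it the adjoint of the pushforward under the real inner product <a, b> = Re(conj(a) b). That inner-product convention is an assumption here, and other AD systems may use different conventions.

```python
def square(z):
    return z * z


def square_deriv(z):
    # square is holomorphic, so a single complex derivative f'(z) = 2z
    # describes it on both R and C.
    return 2 * z


def pushforward(fprime, z, dz):
    # Forward mode: dy = f'(z) * dz (no conjugate).
    return fprime(z) * dz


def pullback(fprime, z, ybar):
    # Adjoint (reverse) mode: zbar = conj(f'(z)) * ybar.
    # For real z, conjugate() is the identity, so the real rule is recovered.
    return fprime(z).conjugate() * ybar


def inner(a, b):
    # Real inner product on C viewed as R^2: <a, b> = Re(conj(a) * b).
    return (a.conjugate() * b).real
```

With real inputs, `pullback(square_deriv, 3.0, 1.0)` and `pushforward(square_deriv, 3.0, 1.0)` both give `6.0`, so the same rule covers the R case.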
The @sig notation is just convenience syntax for when the signature is simple; it’s not too important as long as we keep it consistent and useful. The underlying (very, very tiny) markup language for the signatures is quite important, though, since that’s the thing rule authors will generally be interacting with.
Hmm. I’m down to add syntax for K in the future if we need it; you can already express this by writing the signature manually (without @sig).
However, I’m not sure it makes sense to define rules with less specific domains, since complex rules generally require a different output shape than real rules. For the holomorphic cases where complex rules reduce to real rules, it’s already easy to write well-specified, generic fallbacks by simply composing the rules. We can add more such fallbacks to cover more cases if we want (e.g. with guards to whitelist/blacklist specific functions).
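To illustrate why the output shape can differ, here is a hedged Python sketch using Wirtinger partials. This is one common convention, assumed here, and the post does not say which convention ChainRules uses. A non-holomorphic function such as |z|^2 needs two partials, df/dz and df/dzbar, whereas a holomorphic one has df/dzbar = 0. On real inputs, the two partials combine into the single real derivative.

```python
def abs2(z):
    # |z|^2 maps C to R and is not holomorphic.
    return (z * z.conjugate()).real


def abs2_partials(z):
    # Wirtinger partials (df/dz, df/dzbar) of z * conj(z): (conj(z), z).
    return z.conjugate(), z


def square_partials(z):
    # A holomorphic function has df/dzbar == 0, so only one partial carries information.
    return 2 * z, 0


def wirtinger_pushforward(partials, dz):
    # df = (df/dz) dz + (df/dzbar) conj(dz)
    d, dbar = partials
    return d * dz + dbar * dz.conjugate()
```

On the real input `3.0`, `wirtinger_pushforward(abs2_partials(3.0), 1.0)` gives `6.0`, the ordinary derivative 2x. The complex rule therefore carries strictly more information than the real one, which is the shape mismatch mentioned above.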
I think it’ll be okay for ChainRules to also provide rules for (and thus have a package dependency on) some noncontroversial non-stdlib packages (e.g. SpecialFunctions.jl).
Sure! Other packages implementing domain-specific kernels can depend on ChainRules and add whatever rules they want in order to opt in to the rest of the ecosystem. Alternatively, if e.g. “DomainSpecificKernels.jl” wants to support ChainRules but doesn’t want to depend on it by default, a separate “DSKChainRules.jl” adapter package could be created to hold the rule definitions.
Looks good; I’m looking forward to playing around with it some more. Will and I were just discussing adding thunks to Flux’s adjoint API for the same reason as ChainRules (if I’m understanding correctly), so it’s pleasing to see the convergence there.
This is obviously a fair bit more complex than DiffRules; it would be nice to understand the motivation for its signature interface, and to have some simple usage examples (e.g. how to get a list of rules, or the best way to wrap a rule into a Flux-style adjoint).
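One way to picture the thunk idea mentioned above is the following Python sketch. All names are hypothetical, and this is neither Flux's nor ChainRules' actual API. The pullback returns one thunk per argument, so each partial is only computed if the caller forces it. A small adapter then forces them all eagerly to produce a plain tuple of gradients. That this tuple matches what a Flux-style adjoint expects is an assumption.

```python
class Thunk:
    """Defers a computation until first called, then caches the result."""

    def __init__(self, fn):
        self._fn = fn
        self.forced = False
        self._value = None

    def __call__(self):
        if not self.forced:
            self._value = self._fn()
            self.forced = True
        return self._value


def mul_rule(a, b):
    """Rule for a * b: primal value plus a pullback whose partials are thunks."""
    def pullback(ybar):
        # Neither product is computed until the caller forces that thunk,
        # so an AD tool that only needs d/da never pays for d/db.
        return Thunk(lambda: ybar * b), Thunk(lambda: ybar * a)
    return a * b, pullback


def as_flux_style_adjoint(rule):
    """Hypothetical adapter: force every thunk so the pullback returns a plain
    tuple of gradients (the shape a Flux-style adjoint is assumed to expect)."""
    def adjoint(*args):
        y, pullback = rule(*args)
        return y, lambda ybar: tuple(t() for t in pullback(ybar))
    return adjoint
```

For instance, `as_flux_style_adjoint(mul_rule)(3.0, 4.0)` returns `12.0` and a pullback that maps `1.0` to `(4.0, 3.0)`.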
and only use (R → R) in the case of non-holomorphic functions?
Rules that apply in the complex domain always apply in the real domain (assuming that the function maps R to R), so it should be easy to write a fallback rule in that direction, rather than having to use a whitelist system going the other way (or duplicating rules).
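Here is a hedged Python sketch of the fallback direction described here. Rules for holomorphic functions are written once for complex inputs, in a hypothetical `COMPLEX_RULES` table, and the real rule is derived by evaluating the complex rule at a real point. The check on the imaginary part stands in for the “function maps R to R” assumption. It is an illustration, not how ChainRules implements fallbacks.

```python
import cmath

# Hypothetical table of derivative rules written once, for complex inputs,
# for holomorphic functions.
COMPLEX_RULES = {
    "exp": lambda z: cmath.exp(z),  # d/dz exp(z) = exp(z)
    "sin": lambda z: cmath.cos(z),  # d/dz sin(z) = cos(z)
}


def real_derivative(name, x):
    """Fallback real rule derived from the complex one.

    Valid only when f maps R to R, so f'(x) is real for real x and
    the (zero) imaginary part can be dropped.
    """
    d = COMPLEX_RULES[name](complex(x, 0.0))
    if abs(d.imag) > 1e-12:
        raise ValueError(f"{name} does not map R to R near x = {x}")
    return d.real
```

No real-valued rule for `exp` or `sin` has to be written separately. The guard is where a whitelist or blacklist could hook in if needed.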
Yes, there are several ways to implement the fallbacks without requiring new mechanisms. I’m still trying to decide, though, between going the route you described and adding a @holomorphic annotation/trait (e.g. so that downstream AD tools can determine whether a non-primitive is holomorphic without runtime checks, enabling a few optimizations IIUC).
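As a rough illustration of checking a holomorphic trait without runtime checks, here is a Python sketch. The `HOLOMORPHIC` table and `is_holomorphic` are hypothetical and are not the proposed `@holomorphic` design. The sketch relies on the fact that a composition of holomorphic functions is holomorphic. That gives a sufficient but not necessary condition.

```python
# Hypothetical trait table: which primitives are annotated as holomorphic.
HOLOMORPHIC = {
    "exp": True,
    "sin": True,
    "square": True,
    "conj": False,
    "abs2": False,
}


def is_holomorphic(pipeline):
    """Return True if every primitive in the composition is marked holomorphic.

    This is a sufficient, not necessary, condition: conj(conj(z)) == z is
    holomorphic, yet ["conj", "conj"] is rejected. Unknown primitives are
    conservatively treated as non-holomorphic.
    """
    return all(HOLOMORPHIC.get(name, False) for name in pipeline)
```

When the check succeeds, a tool could, for instance, skip computing the df/dzbar partial entirely. Whether that matches the optimizations alluded to above is an assumption.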
So, we definitely don’t (and aren’t going to) require duplicating rules for holomorphic functions, but I don’t yet see the benefit of a syntax that merges real and complex rules when they don’t otherwise reduce to each other. The current system already allows composition when it’s useful for specific rules. I’m open to proposals, though!