ChainRules: Replacing DiffRules in the Julia AD world

Hi!

As many of the folks here know, there’s a lot of ambitious work being done on various Julia AD tools these days. A key component of each AD tool’s implementation is the mechanism or framework it utilizes for defining, querying, and executing “differentiation rules” (also referred to as “primitives”).

For a long while now, the DiffRules package (originally derived from the symbolic differentiation rules within Calculus.jl) has served as a common dependency for this purpose. DiffRules is extremely limited in scope: it only supports scalar real-to-real derivative rules on expressions. At this point, it’s apparent that the Julia AD world would benefit from a common rule framework that supports far more:

  • first-class complex differentiation
  • custom perturbation/sensitivity propagation
  • linear algebraic and general array primitives
  • mixed-mode composability of rule definitions
  • function-based (as opposed to expression-based) rule specification
  • decoupling rule specification and input/output value specialization
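To make the expression-based vs. function-based distinction in the list above concrete, here is a minimal, dependency-free sketch of the two styles. `EXPR_RULES`, `eval_expr_rule`, and `frule_sin` are hypothetical names used only for illustration. They are not the actual DiffRules or ChainRules API.

```julia
# Expression-based (DiffRules-style, hypothetical): a rule is a quoted
# expression written in terms of a fixed variable name `x`.
const EXPR_RULES = Dict(:sin => :(cos(x)), :exp => :(exp(x)))

# Using such a rule means splicing the expression into generated code.
function eval_expr_rule(f::Symbol, xval)
    g = eval(:(x -> $(EXPR_RULES[f])))  # builds e.g. `x -> cos(x)`
    return Base.invokelatest(g, xval)   # call the freshly defined function
end

# Function-based (hypothetical): a rule is an ordinary Julia method that takes
# the primal input and a tangent and returns (output, output tangent). Because
# it is a method, it can dispatch on input types (Real vs. Complex, etc.).
frule_sin(x::Number, ẋ::Number) = (sin(x), cos(x) * ẋ)
```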

While pursuing my own work on Capstan, I’ve cooked up an initial design and implementation for such a package: ChainRules.

This package is definitely a WIP, but I figured it’d be good to get eyes on it early. The framework/design is essentially there, but there are only a few toy rules right now, a bunch of TODOs, virtually no tests, etc. PRs welcome! Documentation is incoming, which should help if you’d like to contribute.

The package’s design is heavily inspired by various conversations with (and work done by) Will Tebbutt, @ssfrr, @MikeInnes, @denizyuret, and Ekin Akyürek; my hope is that with some elbow grease we can make ChainRules useful to all these folks!

Best,
Jarrett

Sounds very good! The code is more complex and boilerplatey than DiffRules, so it might be useful to have “simple mode” DiffRules-style macros to make it easier for people to contribute.

Looking forward to that, this package looks like it solves a lot of issues that limited DiffRules.

Would porting rules from DiffRules help?

Great idea! I’ve just added @forward_rule and @reverse_rule to make it easier to port over simple real-domain rules, e.g.:

@forward_rule(R → R, sin(x), cos(x))
@forward_rule(R⊗R → R, *(x, y), (y, x))

@reverse_rule([R] → R, sum(x), ȳ, ȳ)
@reverse_rule([R]⊗[R] → [R], *(x, y), z̄, z̄ * y', x' * z̄)
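Read as plain math, these four rules say the following. The sketch below writes each one as an ordinary Julia function; the names `frule_sin`, `frule_mul`, `rrule_sum`, and `rrule_matmul` are hypothetical, and this is not what the macros actually expand to. Forward rules push an input tangent through the listed partials; reverse rules map an output cotangent back to input cotangents.

```julia
using LinearAlgebra

# Forward rule for sin: ẏ = cos(x) * ẋ
frule_sin(x, ẋ) = (sin(x), cos(x) * ẋ)

# Forward rule for scalar *(x, y): partials are (y, x), so ż = y*ẋ + x*ẏ
frule_mul(x, y, ẋ, ẏ) = (x * y, y * ẋ + x * ẏ)

# Reverse rule for sum: every element of x receives the output cotangent ȳ
# (materialized with `fill` here)
rrule_sum(x::AbstractArray) = (sum(x), ȳ -> fill(ȳ, size(x)))

# Reverse rule for matrix *(x, y): given z̄, x̄ = z̄ * y' and ȳ = x' * z̄
rrule_matmul(x::AbstractMatrix, y::AbstractMatrix) =
    (x * y, z̄ -> (z̄ * y', x' * z̄))
```

A quick sanity check for a reverse rule is the adjoint identity ⟨z̄, ż⟩ = ⟨x̄, ẋ⟩ + ⟨ȳ, ẏ⟩ with ż = ẋ·y + x·ẏ, which holds exactly for integer matrices and `LinearAlgebra.dot`.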

Definitely 🙂

Wonderful!

I got confused by the use of ⊗ (\otimes), which, if I understand correctly, means Cartesian product here? That conflicts with the standard use of ⊗ for tensor products.

A standard notation is K, meaning either R or C. Maybe it’s useful here, so that most rules only have to be written once for both the real and complex cases (in the holomorphic case, the rules are the same for R and C, with an extra conjugate in adjoint mode).
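A minimal sketch of that last point for `sin`, assuming the common convention where the cotangent of a real loss L with respect to z = a + bi is ∂L/∂a + i·∂L/∂b. `dsin`, `frule_sin`, and `rrule_sin` are hypothetical names. One derivative function serves both R and C; only the adjoint applies `conj`, which is a no-op for real inputs.

```julia
# Derivative shared by the real and complex cases (sin is holomorphic).
dsin(z::Number) = cos(z)

# Forward mode: push a tangent ż through y = sin(z).
frule_sin(z::Number, ż::Number) = (sin(z), dsin(z) * ż)

# Reverse (adjoint) mode: pull a cotangent ȳ back through y = sin(z).
# The holomorphic rule picks up a complex conjugate on the derivative.
rrule_sin(z::Number) = (sin(z), ȳ -> conj(dsin(z)) * ȳ)
```

Under that convention, the two rules are consistent when real(conj(ȳ)·ẏ) ≈ real(conj(z̄)·ż) for any tangent ż and cotangent ȳ.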

The @sig notation is just convenience syntax for when the signature is simple; it’s not too important as long as we keep it consistent and useful. The underlying (very, very tiny) markup language for the signatures is quite important, though, since that’s the thing rule authors will generally be interacting with.

Good point, let’s tilt it by a few degrees 😉

Hmm. I’m down to add a K syntax in the future if we have a need for it; you can already express this by writing the signature manually (without @sig).

However, I’m not sure it makes sense to define rules with less specific domains, since complex rules generally require a different output shape than real rules. For those holomorphic cases where complex rules reduce to real rules, it’s already easy to write well-specified, generic fallbacks by simply composing the rules. We can expand/add more such fallbacks to cover more cases, if we want (and e.g. add guards to whitelist/blacklist functions).

Now those are some pretty macros.

This sounds like a good idea.

Should this only contain rules for functions in Base? Would it be possible for other packages to use this, to provide derivative definitions in a flexible way?

I think it’ll be okay for ChainRules to also provide rules for (and thus have a package dependency on) some noncontroversial non-stdlib packages (e.g. SpecialFunctions.jl).

Sure! Other packages implementing domain-specific kernels can depend on ChainRules and add whatever rules they want to opt-in to the rest of the ecosystem. Alternatively, if e.g. “DomainSpecificKernels.jl” wants to support ChainRules, but doesn’t want to depend on it by default, a separate “DSKChainRules.jl” adapter package could be created to hold the rule definitions.
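A minimal single-file sketch of the adapter pattern, with modules standing in for packages. `RuleRegistry` is a hypothetical stand-in for ChainRules (its `frule` is not ChainRules' actual API), and `softplus` is a hypothetical kernel; `DomainSpecificKernels` and `DSKChainRules` are the placeholder names from the post. The adapter attaches the rule to the kernel, so neither of the other two modules has to depend on the other.

```julia
# Stand-in for a rule package (hypothetical API).
module RuleRegistry
    export frule
    # Generic function that other packages add rule methods to.
    function frule end
end

# A kernel package that knows nothing about rules.
module DomainSpecificKernels
    export softplus
    softplus(x::Real) = log1p(exp(x))
end

# Adapter package: depends on both and defines the rule.
module DSKChainRules
    using ..RuleRegistry, ..DomainSpecificKernels
    # d/dx log(1 + e^x) = 1 / (1 + e^-x), the logistic function
    RuleRegistry.frule(::typeof(DomainSpecificKernels.softplus), x::Real, ẋ::Real) =
        (DomainSpecificKernels.softplus(x), ẋ / (1 + exp(-x)))
end
```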

Sounds good, thanks!

Hmm… That one usually denotes the direct sum of subspaces of a vector space, whereas R is just a set. I guess it’s still better than the tensor product alternative…

We are indeed using it as the direct sum here.

Looks good; I’m looking forward to playing around with it some more. Will and I were only just discussing adding thunks to Flux’s adjoint API for the same reason as ChainRules (if I’m understanding correctly), so it’s pleasing to see the convergence there.
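For readers unfamiliar with the thunk idea: a reverse rule can return each input's adjoint as a delayed computation, so work for an input nobody asks about is never done. A minimal sketch; `Thunk`, `unthunk`, and `rrule_matmul_thunked` are hypothetical names, not the Flux or ChainRules API.

```julia
# A delayed computation wrapping a zero-argument closure.
struct Thunk{F}
    f::F
end

# Force a thunk; pass any other value through unchanged.
unthunk(t::Thunk) = t.f()
unthunk(x) = x

# Reverse rule for matrix multiply whose adjoints are computed only on demand.
function rrule_matmul_thunked(x::AbstractMatrix, y::AbstractMatrix)
    z = x * y
    pullback(z̄) = (Thunk(() -> z̄ * y'), Thunk(() -> x' * z̄))
    return z, pullback
end
```

If y is a constant, an AD tool would only call `unthunk` on the first adjoint, and the `x' * z̄` product is never computed.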

This is obviously a fair bit more complex than DiffRules; it would be nice to understand the motivation for its signature interface, and to have some simple usage examples (e.g. how do I get a list of rules, and how best to wrap a rule into a Flux-style adjoint).
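A minimal sketch of both usage questions, under a hypothetical convention where every reverse rule is a method of one generic function `rrule(f, args...)` that returns the output and a pullback closure (the same `(y, Δ -> ...)` shape a Flux-style adjoint uses). Listing rules is then plain method reflection, and an AD tool chains pullbacks. `rrule`, `list_rules`, and `grad_g` are illustrative assumptions, not names confirmed by the package.

```julia
# Hypothetical convention: one generic function holds every reverse rule.
function rrule end
rrule(::typeof(sin), x::Real) = (sin(x), ȳ -> (cos(x) * ȳ,))
rrule(::typeof(*), x::Real, y::Real) = (x * y, z̄ -> (z̄ * y, z̄ * x))

# Listing the available rules is method reflection.
list_rules() = methods(rrule)

# Gradient of g(x, y) = sin(x) * y by chaining the two rules by hand,
# the way a reverse-mode tool would.
function grad_g(x, y)
    a, back_sin = rrule(sin, x)
    out, back_mul = rrule(*, a, y)
    ā, ȳ = back_mul(one(out))   # seed the output sensitivity with 1
    (x̄,) = back_sin(ā)
    return out, (x̄, ȳ)
end
```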

Personally, I prefer plain \times (×) instead of \oplus (⊕).

For this usage plain times seems standard.

At least in math, × and ⊕ actually refer to the same operation for ChainRules’ usage (finite operands). It seems Discourse prefers ×, though, so I’ve switched to ×.

My original motivation for not picking × was that I thought it looked too similar to x. However, I now realize that x doesn’t parse as an infix operator anyway, which should hopefully avoid any confusion.
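A quick check of that parsing argument in plain Julia, with no packages: `×` and `→` are infix operators, so a written-out signature like `R × R → R` parses as an ordinary nested call expression, while `R x R` is a syntax error. How `@sig` actually consumes such an expression is not shown here.

```julia
# `×` binds tighter than `→`, so this parses as →(×(R, R), R).
sig = Meta.parse("R × R → R")
@show sig.head sig.args[1] sig.args[2]

# A plain `x` is an identifier, not an operator, so `R x R` fails to parse.
parses_ok = try
    Meta.parse("R x R")
    true
catch err
    false
end
@show parses_ok
```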

Wouldn’t it be better to write rules such as

@forward_rule(C → C, sin(x), cos(x))

and only use (R → R) in the case of non-holomorphic functions?

Rules that apply in the complex domain always apply in the real domain (assuming that the function maps R to R), so it should be easy to write a fallback rule in that direction, versus having to use a whitelist system going the other way (or duplicating rules).

Or, ideally, a syntax that allows writing both real to real and complex to complex cases as a single rule, as this is the most common.
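A minimal sketch of the fallback direction described above, with hypothetical names `complex_deriv` and `real_deriv`: the rule is written once for complex inputs, and real inputs reuse it by evaluating on the real axis. This assumes f is holomorphic and maps R to R. A function with no complex rule, such as `abs`, simply has no method and raises a `MethodError`.

```julia
# Hypothetical rules, written once for the complex (holomorphic) case.
complex_deriv(::typeof(sin), z::Complex) = cos(z)
complex_deriv(::typeof(exp), z::Complex) = exp(z)

# Fallback for real inputs: for a holomorphic f that maps R to R, the real
# derivative is the complex derivative evaluated on the real axis.
real_deriv(f, x::Real) = real(complex_deriv(f, complex(x)))
```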

Yes, there are several ways to implement the fallbacks without requiring new mechanisms. I’m still trying to decide, though, between going the route you described vs. adding a @holomorphic annotation/trait (e.g. so that downstream AD tools can compute whether a non-primitive is holomorphic without runtime checks, enabling a few optimizations IIUC).
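A minimal sketch of what a holomorphic trait could look like; `is_holomorphic` is a hypothetical name, and `ComposedFunction` (what `∘` returns) requires Julia 1.6 or later. Because the answer for a composition depends only on the types of its parts, the compiler can typically fold it to a constant, which is the "no runtime checks" property mentioned above.

```julia
# Default: treat a function as not holomorphic unless declared otherwise.
is_holomorphic(f) = false
is_holomorphic(::typeof(sin)) = true
is_holomorphic(::typeof(exp)) = true

# A composition of holomorphic functions is holomorphic, so the trait for a
# non-primitive is computed from its parts.
is_holomorphic(c::ComposedFunction) =
    is_holomorphic(c.outer) && is_holomorphic(c.inner)
```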

So, we definitely don’t (and aren’t going to) require duplicating rules for holomorphic functions, but I’m not sure I see the benefit yet of having a syntax that merges real/complex rules when they don’t reduce to each other otherwise. The current system already allows composition when it’s useful for specific rules. I’m open to proposals, though!

OK, possibly I was overthinking it.