What's the state of Automatic Differentiation in Julia January 2023?

Not exactly. There is a trivial extension of ChainRules to Enzyme rules, which is to ignore activity and assume all variables are active.

Let me go into a bit more detail on why Enzyme rules are more interesting, why they’re harder, and why they will lead to performance improvements. Take a look at an Enzyme call I wrote yesterday (Segfault with constant variables? "Enzyme cannot deduce type"? · Issue #571 · EnzymeAD/Enzyme.jl · GitHub):

...

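```julia
# Tf, Nx, Δx, α, λ, input_signal and diffusion_x! are assumed to be defined in the elided code above.
function heat_eq!(dx, x, p, t)
    time = t / Tf
    u = input_signal(time, p)

    diffusion_x!(dx, x, Nx, 1, Δx)

    dx .= α .* dx  # fused broadcast: scales dx in place without allocating a temporary
    dx[1] = dx[1] + α / (λ * Δx) * u
    return nothing  # mutating function: return nothing so no return value gets differentiated
end

Enzyme.autodiff(heat_eq!, Duplicated(dx, d_dx), Duplicated(x, d_x),
                Duplicated(p, d_p), Enzyme.Const(t));
```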

This shows Enzyme’s “activity states”. Duplicated means that dx is a variable to be differentiated, and its derivative will be written into d_dx. This allows Enzyme to be fully non-allocating when differentiating arrays, and note the mutation support. However, here I didn’t want to differentiate with respect to t, so I marked it as Enzyme.Const(t).
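To make the Duplicated/Const distinction concrete without depending on Enzyme itself, here is a hand-written reverse pass for a tiny mutating function. This is only a sketch of the idea, not the code Enzyme generates, and the names scale! and scale_reverse! are hypothetical. Each differentiated array travels with a shadow buffer, the adjoint is accumulated into that buffer in place (so nothing is allocated), and the constant α gets no shadow at all.

```julia
# Primal: overwrite y with α .* x (mutating, non-allocating).
function scale!(y, x, α)
    y .= α .* x
    return nothing
end

# Hypothetical hand-written reverse pass in the spirit of Duplicated/Const:
# (y, dy) and (x, dx) are primal/shadow pairs, while α is treated as constant,
# so no derivative with respect to α is computed or stored.
function scale_reverse!(y, dy, x, dx, α)
    # ∂y_i/∂x_i = α: accumulate the incoming adjoint dy into dx in place.
    dx .+= α .* dy
    # y was overwritten by the primal, so its previous value does not affect
    # the output: reset its shadow once the adjoint has been propagated.
    dy .= 0
    return nothing
end

x, y = [1.0, 2.0], zeros(2)
dx, dy = zeros(2), ones(2)  # dy seeds the adjoint of the output
scale!(y, x, 3.0)
scale_reverse!(y, dy, x, dx, 3.0)
```

After the reverse pass, dx holds [3.0, 3.0] and dy has been reset to zeros, with no new arrays allocated along the way.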

Zygote and Diffractor work by differentiating all of the code and hoping that dead code elimination is good enough to remove the unneeded branches. ChainRules sort of supports something like activity states through @thunk, but the AD systems currently ignore the thunks and expand them most of the time anyway, so in practice it doesn’t really exist (at least from the code generation perspective). Enzyme, by contrast, can generate code that is specialized to differentiate only some of the components. And there are many different activity states:

https://enzyme.mit.edu/julia/api/#Types-and-constants

Thus, for your rule to be well-defined, you need to define it for every combination of activity states. For example, a function f(x,y) can be called with (Duplicated, Const), (Const, Duplicated), etc., and you want a rule for each combination. Doesn’t that lead to a combinatorial explosion of required rules?

Yes. With v1 of the rules system (Add support for user-defined rules by vchuravy · Pull Request #177 · EnzymeAD/Enzyme.jl · GitHub), a full definition requires 6^n overloads for an n-argument function.
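To see how quickly that grows, the following snippet just counts the raw combinations, assuming the six annotation types documented for Enzyme.jl at the time (Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed). It only uses the names as symbols and does not load Enzyme; not every combination is meaningful for every argument type, so treat the count as an upper bound.

```julia
# Assumed list of Enzyme.jl activity annotations (names only, Enzyme is not loaded).
const ACTIVITY_STATES = (:Const, :Active, :Duplicated, :DuplicatedNoNeed,
                         :BatchDuplicated, :BatchDuplicatedNoNeed)

# Every combination of activity states for an n-argument function,
# i.e. one potential rule signature per tuple.
activity_combinations(n) =
    vec(collect(Iterators.product(ntuple(_ -> ACTIVITY_STATES, n)...)))

for n in 1:4
    println(n, " argument(s): ", length(activity_combinations(n)), " rule signatures")
end
```

This prints 6, 36, 216 and 1296 signatures for one through four arguments.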

But of course there are many different fallbacks you could use. For example, you could set up a system where a non-const array falls back to Duplicated and a number or struct falls back to Active, which reduces the number of rules needed. ChainRules would then define the “always active” versions. This would allow ChainRules to provide default overloads, which could then be improved with additional overloads for specific activity states. ChainRulesCore could adopt the activity state types as well, and then the two systems would map onto each other much better (and things like Diffractor would be able to use that information too).
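A minimal sketch of such a fallback, written with plain multiple dispatch. The annotation types ConstArg, ActiveArg and DupArg, and the functions default_activity and rule, are hypothetical stand-ins invented for this illustration, not Enzyme’s or ChainRules’ API. The generic method plays the role of the ChainRules-derived “always active” rule, and a more specific method for one activity combination takes over automatically when it exists (assuming an output seed of 1).

```julia
# Hypothetical activity annotations (NOT Enzyme's types), for illustration only.
abstract type Annotation{T} end
struct ConstArg{T} <: Annotation{T}
    val::T
end
struct ActiveArg{T} <: Annotation{T}
    val::T
end
struct DupArg{T} <: Annotation{T}
    val::T
    dval::T  # shadow buffer that receives the derivative
end

# Hypothetical defaults: non-const arrays are Duplicated, numbers are Active.
default_activity(x::AbstractArray) = DupArg(x, zero(x))
default_activity(x::Number) = ActiveArg(x)

# Example primal function.
f(x, y) = sum(x) * y

# Fallback ("always active"): compute every partial derivative,
# even ones nobody asked for. Always correct, possibly wasteful.
function rule(::typeof(f), x::Annotation, y::Annotation)
    dx = fill(y.val, length(x.val))  # ∂f/∂x_i = y
    dy = sum(x.val)                  # ∂f/∂y   = sum(x)
    return (:fallback, dx, dy)
end

# Specialized overload for (Duplicated, Const): accumulate ∂f/∂x into the
# shadow in place and skip ∂f/∂y entirely.
function rule(::typeof(f), x::DupArg, y::ConstArg)
    x.dval .+= y.val
    return (:specialized, x.dval, nothing)
end
```

Because Julia picks the most specific applicable method, `rule(f, DupArg([1.0, 2.0], zeros(2)), ConstArg(3.0))` hits the specialized method, while `rule(f, default_activity([1.0, 2.0]), default_activity(3.0))` and every other unspecialized combination keep using the always-active fallback. Adding a specialized overload never breaks the fallback.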

So, tl;dr: ChainRules doesn’t give enough information to fully optimize, while Enzyme asks for too much. What’s holding back the rules system is some kind of fallback mechanism, so that you don’t need to define 700 dispatches for every rule. Once such a fallback mechanism exists, ChainRules should be supported, though sub-optimally.
