Does Zygote differentiate symbolically?

Yup that was me.

Automatic differentiation is symbolic differentiation where instead of using substitution you use assignment (=).

It’s not 100% correct, but it gets you close enough that it’s a good rule of thumb. For example, if you have sin(f(x)), then with symbolic differentiation you get sin'(f(x))f'(x) and you evaluate that expression. But with automatic differentiation you generate code that effectively does:

fx = f(x)
dfx  = f'(x)
sinx = sin(fx)
dsinx  = cos(fx)
return dsinx * dfx 

If f is what’s known as a primitive, then f'(x) has been defined and you use that. If it hasn’t been defined in the AD system, then you look into its code and do this same process to all steps, etc.

But if you take what AD gives you and, instead of keeping it as separate operations, substitute everything to build a single final expression, then you get sin'(f(x))f'(x), i.e. the symbolic derivative expression.
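
To make the rule of thumb concrete, here is a minimal sketch in plain Julia. It assumes a hypothetical primitive f(x) = x^2 with a hand-written derivative df(x) = 2x; both names are made up for illustration and are not part of Zygote. The first function keeps one assignment per step, the second is the same computation after substituting everything into one expression.

```julia
# Hypothetical primitive and its hand-written derivative (illustration only).
f(x) = x^2
df(x) = 2x

# AD style: one assignment per step, derivative values carried alongside.
function dsinf_assign(x)
    fx = f(x)
    dfx = df(x)
    dsinx = cos(fx)        # sin' = cos, evaluated at f(x)
    return dsinx * dfx     # chain rule
end

# Symbolic style: the same steps substituted into a single expression.
dsinf_symbolic(x) = cos(x^2) * 2x

dsinf_assign(0.7) ≈ dsinf_symbolic(0.7)   # both compute the same derivative
```

The two versions agree because they perform the same arithmetic; the only difference is whether intermediate results get names.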

Thinking about Differentiation of Languages

One way of describing this then is that symbolic differentiation is limited by the semantics of “standard mathematical expressions”, and AD is simply rewriting it in a language that allows for assignment. AD is symbolic differentiation in the language of SSA IR, i.e. computer code. So in a sense I think it’s fine to say Zygote is doing symbolic differentiation on Julia code.

When we say “symbolic differentiation”, we normally mean differentiating in the language of mathematical expressions. That is, if you take Symbolics.jl, declare @variables x, and call f(x), it generates a mathematical expression without any computational aspects and then performs the differentiation in that purely mathematical language:

using Symbolics
@variables x
function f(x)
    out = one(x)
    for i in 1:5
        out *= x^i
    end
    out
end
sin(f(x)) # sin(x^15)

Evaluation with symbolic variables completely removes the “non-mathematical” computational expressions, and then we symbolically differentiate in this language:

Symbolics.derivative(sin(f(x)),x) # 15(x^14)*cos(x^15)

Note the expression blow-up: we take an entire computational expression, squash it down to a single mathematical formula, and differentiate it. The problem is that the expressions you’re building and differentiating can grow exponentially in size. This is the downside of symbolic differentiation.

function f(x)
    out = x
    for i in 1:5
        out *= sin(out)
    end
    out
end
sin(f(x)) # sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))))
Symbolics.derivative(sin(f(x)),x) # (sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x))) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + 
x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + x*(sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x))) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))))*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))))
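
The growth can be measured without Symbolics.jl. The sketch below uses only Julia's built-in Expr type and two hypothetical helpers (build_expr and count_nodes, named here for illustration). It substitutes the loop body into itself the way symbolic evaluation does, then counts the nodes of the resulting expression tree.

```julia
# Build the fully substituted expression for n iterations of `out *= sin(out)`.
function build_expr(n)
    out = :x
    for i in 1:n
        out = :($out * sin($out))   # the old expression is pasted in twice
    end
    return out
end

# Count every node (calls, function names, variables) in an expression tree.
count_nodes(ex) = ex isa Expr ? 1 + sum(count_nodes, ex.args) : 1

sizes = [count_nodes(build_expr(n)) for n in 0:5]   # [1, 6, 16, 36, 76, 156]
```

Each iteration roughly doubles the size (here s(n+1) = 2*s(n) + 4), while the loop that computes the same value stays a few lines long regardless of the iteration count.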

So then a good way to think about AD is that it’s doing differentiation directly on the language of computer programs. When doing this, you want to build expressions that carry forward the derivative calculation, and generate something that is a computation of the derivative, not a mathematical expression of it.

On that same example, this looks like:

function f(x)
    # x is a (value, derivative) pair
    out = x
    for i in 1:5
        # sin(out) => chain rule sin' = cos
        tmp = (sin(out[1]), out[2] * cos(out[1]))
        # out = out * tmp => product rule
        out = (out[1] * tmp[1], out[1] * tmp[2] + out[2] * tmp[1])
    end
    out
end
function outer(x)
    # sin(x) => chain rule sin' = cos
    out1, out2 = f(x)
    sin(out1), out2 * cos(out1)
end
dsinfx(x) = outer((x,1))[2]

f((1,1)) # (0.01753717849708632, 0.36676042682811677)
dsinfx(1) # 0.3667040292067162

Compare this with the symbolic version:

julia> substitute(sin(f(x)),x=>1)

julia> substitute(Symbolics.derivative(sin(f(x)),x),x=>1)

You can see the symbolic aspects in there: it uses the analytical derivative of sin being cos, and it uses the product rule in the code it generates. Those are the primitives. But you then use an extra variable to accumulate the derivative, because again you’re working in the language of computer programs with for loops and all, and you are taking the derivative of a computational expression to get another computational expression.

The advantage, of course, is that things like control flow, which have no simple representation in mathematical language, have a concise computational description, and you can avoid ever building the exponentially large mathematical expressions that are being described.


(forward-mode AD, at least; reverse-mode looks a little different)

Yeah, I guess the last piece to note in that monologue is that now that you’re generating computer programs for the derivative, the exact way that you do this is no longer unique. There are many different computer programs that will compute the derivative. Two commonly used ones are forward-mode (which I wrote out by hand above) and reverse mode (which builds a code that runs forward and then runs the derivative calculation code in reverse). But there’s also 2^n many different codes that are mixtures of the two. I just wrote down one example.
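
For contrast with the hand-written forward-mode code above, here is a rough sketch of what a reverse-mode program for the same loop could look like. It is written by hand and is not what Zygote actually generates; the function name sin_f_reverse and the outs buffer are assumptions made for illustration. The forward sweep records the intermediate values, and the reverse sweep walks them backwards while accumulating the derivative of the output with respect to each intermediate.

```julia
function sin_f_reverse(x)
    # Forward sweep: run the original program and record each value of `out`.
    out = float(x)
    outs = [out]
    for i in 1:5
        out = out * sin(out)
        push!(outs, out)
    end
    y = sin(out)

    # Reverse sweep: start from dy/dy = 1 and pull the derivative backwards.
    outbar = cos(out)                 # dy/d(final out), since sin' = cos
    for i in 5:-1:1
        prev = outs[i]                # value of `out` before step i
        # step i computed prev * sin(prev); product rule + chain rule:
        outbar *= sin(prev) + prev * cos(prev)
    end
    return y, outbar                  # (sin(f(x)), d/dx sin(f(x)))
end

sin_f_reverse(1)[2]   # ≈ 0.36670, agreeing with dsinfx(1) from the forward-mode code
```

With a single input and a single output the two modes do the same amount of work; the difference only matters once there are many inputs (where reverse mode is cheaper) or many outputs (where forward mode is).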

That is actually a very high-fact-density sentence. That would have been dynamite information. Much more informative than that “I am not familiar with the concept of AD” (no shit, sherlock, I just asked how it works). I got a lot of references to scientific articles and other resources, but I estimate that it took me more than 500 sentences before I was sure even about this part (and nowhere did it transpire that the notion of hybrid systems, or whatever upset whomever became upset, was “haram”). I read them all and this is also what I based my evidently uninformative table on. This was quite informative, but hardly a very friendly way to communicate this information.

This is a very good explanation! Kudos to Chris Rackauckas!

Another way to say it is that Zygote looks at the code which computes your function, and figures out how to write computer code that computes the value of the derivative of that function at a given point. It does this by knowing how to compute the derivative of primitive functions like multiplication, addition, or sqrt, and how to combine those things together. It operates on Julia code (more precisely, on its lowered SSA-form IR, as mentioned above), not on symbolic representations of mathematics.


Maybe tangential to the original question, but I’ll leave it here in any case:

This piece by Chris Olah is the clearest explanation of forward-mode and backward-mode automatic differentiation I’ve come across. Here, a computation or a function is represented as graphs with each node being a (binary) operation. This, to me, seems very different from the dual numbers approach.

It inspired me to implement this. And in doing so, I realized that once you differentiate along each edge in the graph (which would be done during the “forward pass”?), forward-mode and backward mode differentiation look quite similar. And, you can even differentiate intermediary calculations with respect to other intermediary calculations. That ability might be useful if we want to understand what nodes/operations are responsible for the largest/smallest changes in the ultimate output.
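
As a rough illustration of the graph view (not the implementation mentioned above), here is a sketch on a small hypothetical graph e = (a + b) * (b + 1), written out by hand with made-up variable names. The derivative along each edge is recorded during the forward pass; forward-mode and reverse-mode then differ only in the direction in which those edge derivatives are multiplied together.

```julia
# Hypothetical graph: c = a + b, d = b + 1, e = c * d.
a, b = 2.0, 1.0

# Forward pass: evaluate each node and record the derivative along each edge.
c = a + b
d = b + 1
e = c * d
dc_da, dc_db = 1.0, 1.0   # edges a -> c and b -> c
dd_db = 1.0               # edge b -> d
de_dc, de_dd = d, c       # edges c -> e and d -> e

# Forward-mode: pick one input (b) and push its derivative toward the output.
dc = dc_db * 1.0
dd = dd_db * 1.0
de_db_forward = de_dc * dc + de_dd * dd

# Reverse-mode: start at the output and pull derivatives back to every input.
ebar_c = de_dc * 1.0
ebar_d = de_dd * 1.0
de_da_reverse = ebar_c * dc_da
de_db_reverse = ebar_c * dc_db + ebar_d * dd_db
```

Here both modes give de/db = 5, and the reverse sweep also yields de/da = 2 at no extra cost. The recorded edge values double as derivatives between intermediates, e.g. de_dc is the sensitivity of e to the intermediate node c.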

(I’m not a computer scientist and I just started learning Julia. This may be nothing new and you’ll probably cringe at my code. Still thought I’d share.)


Technically, this definition of the derivative is exactly what AD is about, i.e. f'(x) = (f(x+delta)-f(x))/delta, which holds as an identity when delta is an infinitesimal. Infinitesimals can be modeled with dual numbers. The formula can be rewritten as f(x+delta) = f(x) + f'(x)*delta. A dual number is precisely a pair (a,b) which models the number a + b*delta (with delta^2 = 0), in much the same way as complex numbers are modeled as pairs. So AD with dual numbers just extends the basic operations to such pairs, i.e. to dual numbers. This is easy to do with Julia’s type system. To differentiate a function f at x, just evaluate f((x, 1)), i.e. f(x+delta). Since the function uses basic operations which have been extended to dual numbers, it will work. The value is a pair which equals (f(x), f'(x)), i.e. f(x) + f'(x)*delta, from which the derivative f'(x) can be extracted.
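
A minimal sketch of that idea in Julia follows. The Dual struct and the derivative helper are written here for illustration only (packages such as ForwardDiff.jl provide a full implementation), and only the three operations the example needs are extended. The loop function is the unmodified one from earlier in the thread.

```julia
# A dual number a + b*delta, stored as a (value, derivative) pair.
struct Dual
    val::Float64
    der::Float64
end

# Extend the basic operations to pairs, dropping delta^2 terms.
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.val * b.der + a.der * b.val)
Base.sin(a::Dual) = Dual(sin(a.val), a.der * cos(a.val))

# The original loop, with no derivative code written into it.
function f(x)
    out = x
    for i in 1:5
        out *= sin(out)
    end
    return out
end

# Evaluate at x + delta and read off the derivative part.
derivative(g, x) = g(Dual(x, 1.0)).der

derivative(x -> sin(f(x)), 1.0)   # ≈ 0.36670, matching dsinfx(1) above
```

Because dispatch picks the Dual methods automatically, this performs the same arithmetic as the hand-written tuple version, without touching the body of f.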