Yup that was me.
Automatic differentiation is symbolic differentiation where instead of using substitution you use assignment (=).
It’s not 100% correct, but it gets you close enough that it’s a good rule of thumb. For example, if you have sin(f(x)), then with symbolic differentiation you get sin'(f(x))f'(x) and you evaluate that expression. But with automatic differentiation you generate code that effectively does:
fx = f(x)            # forward value of f
dfx = f'(x)          # derivative of f at x
sinx = sin(fx)       # forward value of sin(f(x))
dsinx = cos(fx)      # derivative of sin evaluated at fx
return dsinx * dfx   # chain rule: cos(f(x)) * f'(x)
If f is what’s known as a primitive, then f'(x) has been defined and you use that. If it hasn’t been defined in the AD system, then you look into its code and do this same process to all of its steps, etc.
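To make that concrete, here is a minimal sketch of the idea (my own toy code, not the API of any particular AD package): a dual number carries (value, derivative), the primitives have hand-written rules, and anything that isn’t a primitive is differentiated by simply running its code on duals.

struct Dual
    val::Float64   # primal value
    der::Float64   # derivative carried along
end

# Primitive rules: these are the f'(x) definitions the AD system already knows.
Base.sin(d::Dual) = Dual(sin(d.val), cos(d.val) * d.der)                        # sin' = cos
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.val * b.der + a.der * b.val)  # product rule

# Not a primitive: no rule is written for g, so we just execute its code on a Dual.
g(x) = x * sin(x)

g(Dual(1.0, 1.0))  # Dual(0.8414709848078965, 1.3817732906760363), i.e. (g(1), g'(1))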
But if you take what AD gives you and, instead of keeping it as separate operations, substitute everything back together into a single final expression, then you get sin'(f(x))f'(x), i.e. the symbolic derivative expression.
Thinking about Differentiation of Languages
One way of describing this then is that symbolic differentiation is limited by the semantics of “standard mathematical expressions”, and AD is simply rewriting it in a language that allows for assignment. AD is symbolic differentiation in the language of SSA IR, i.e. computer code. So in a sense I think it’s fine to say Zygote is doing symbolic differentiation on Julia code.
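To see what that “language” looks like, you can ask Julia for the lowered SSA form of the running example (this is just an illustration of the representation, not anything Zygote-specific; output abbreviated and approximate):

julia> Meta.@lower sin(f(x))
# prints (roughly) a CodeInfo like:
#   %1 = f(x)
#   %2 = sin(%1)
#        return %2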
When we say “symbolic differentiation”, we normally mean differentiating in the language of mathematical expressions, i.e. if you take Symbolics.jl and use @variables x; f(x), what it will do is generate a mathematical expression without any computational aspects and then perform the differentiation in the context of that purely mathematical language:
using Symbolics
@variables x
function f(x)
    out = one(x)
    for i in 1:5
        out *= x^i
    end
    out
end
sin(f(x)) # sin(x^15)
Evaluation with symbolic variables completely removes the “non-mathematical” computational expressions, and then we symbolically differentiate in this language:
Symbolics.derivative(sin(f(x)),x) # 15(x^14)*cos(x^15)
Note the expression blow-up: we take an entire computational expression and squash it down to a single mathematical formula before differentiating it, which means you can get exponential blow-up in the size of the expressions you’re building and differentiating. This is the downside of symbolic differentiation.
function f(x)
    out = x
    for i in 1:5
        out *= sin(out)
    end
    out
end
sin(f(x)) # sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))))
Symbolics.derivative(sin(f(x)),x) # (sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x))) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))) + x*(sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))) + x*(sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*cos(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))) + x*(sin(x)*sin(x*sin(x)) + x*cos(x)*sin(x*sin(x)) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x)))*sin(x)*sin(x*sin(x))*cos(x*sin(x)*sin(x*sin(x))) + x*(x*cos(x) + sin(x))*sin(x)*cos(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))))*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))))*cos(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x)))*sin(x*sin(x)*sin(x*sin(x))*sin(x*sin(x)*sin(x*sin(x))))))
So then a good way to think about AD is that it’s doing differentiation directly on the language of computer programs. When doing this, you want to build expressions that carry forward the derivative calculation, and generate something that is a computation of the derivative, not a mathematical expression of it.
On that same example, this looks like:
function f(x)
    out = x
    for i in 1:5
        # sin(out) => chain rule, sin' = cos
        tmp = (sin(out[1]), out[2] * cos(out[1]))
        # out = out * tmp => product rule
        out = (out[1] * tmp[1], out[1] * tmp[2] + out[2] * tmp[1])
    end
    out
end
function outer(x)
    # sin(f(x)) => chain rule, sin' = cos
    out1, out2 = f(x)
    sin(out1), out2 * cos(out1)
end
dsinfx(x) = outer((x, 1))[2]
f((1, 1))  # (0.01753717849708632, 0.36676042682811677)
dsinfx(1)  # 0.3667040292067162
Compare this with the symbolic version:
julia> substitute(sin(f(x)),x=>1)
0.017536279576682495
julia> substitute(Symbolics.derivative(sin(f(x)),x),x=>1)
0.3667040292067162
You can see the symbolic aspects in there: it uses the analytical derivative of sin being cos, and it uses the product rule in the code it generates. Those are the primitives. But you then use an extra variable to accumulate the derivative, because again you’re working in the language of computer programs, with for loops and all, and you are taking the derivative of a computational expression to get another computational expression.
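As an aside, this hand-rolled bookkeeping is exactly what a forward-mode AD package automates; for example with ForwardDiff.jl (a quick check, assuming the package is installed):

using ForwardDiff

function f(x)
    out = x
    for i in 1:5
        out *= sin(out)
    end
    out
end

ForwardDiff.derivative(x -> sin(f(x)), 1.0)  # ≈ 0.3667040292067162, same as the by-hand version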
The advantage, of course, is that things like control flow, which have no simple representation in the mathematical language, have a concise computational description, and you can avoid ever building the exponentially large mathematical expressions that are being described.
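For instance, here is a toy example of my own, reusing the (value, derivative) tuple trick from above, with a value-dependent while loop: there is no fixed formula for what it computes, but the derivative program is just as concise as the original one.

# Square the value until it exceeds 10; how many times depends on the input.
function g(x)
    out = x
    while out[1] < 10
        # out = out^2 => (out^2)' = 2 * out * out'
        out = (out[1]^2, 2 * out[1] * out[2])
    end
    out
end

g((1.5, 1.0))  # (25.62890625, 136.6875): near x = 1.5 this is x^8, and 8 * 1.5^7 = 136.6875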