The Zygote landing page has the example
@adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)
Can someone explain what the Δ -> (Δ, Δ) means? I would have expected the derivatives to be 1,1.
The @adjoint macro is used to define your own backward pass. That is, given the derivative of any L with respect to add(a, b), the @adjoint above defines how to calculate the derivatives of L with respect to a and b.
If you use Zygote’s gradient to take the gradient of add(1, 2), you would get (1, 1) as you pointed out. Zygote simply uses the backward pass you define and plugs in 1 for Δ.
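To make the "plug in 1 for Δ" step concrete, here is a minimal sketch using Zygote.pullback, which returns the function value together with the backward pass. It assumes add is defined as a + b (my assumption for illustration), and the variable names are made up for this example.

```julia
using Zygote
using Zygote: @adjoint

# Assumed definition of add, for illustration only.
add(a, b) = a + b

# The custom backward pass from the question.
@adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)

# pullback gives the value and the backward pass as a closure.
y, back = Zygote.pullback(add, 1.0, 2.0)

# Seeding the backward pass with 1 gives the gradient of add itself.
seeded_with_one = back(1.0)

# Any other seed is simply passed through to both arguments.
seeded_with_five = back(5.0)

# gradient does the same thing as seeding with 1.
via_gradient = gradient(add, 1.0, 2.0)
```

Here seeded_with_one and via_gradient should agree, while seeded_with_five shows what the same adjoint returns when the incoming derivative is not 1.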
I do not understand this yet.
How does it know to plug in 1 for Δ, rather than some other value?
If it already knows the derivative is 1, there is no need to define it using @adjoint!
This is just to explain what I don’t understand - I know that I do not (understand).
Δ -> (Δ, Δ) is essentially defining $\frac{\partial L}{\partial \text{add(a, b)}} \rightarrow \left( \frac{\partial L}{\partial \text{a}}, \frac{\partial L}{\partial \text{b}} \right)$. If we want to calculate the gradient of add(a, b) wrt a and b, then L would be add(a, b) itself, and Δ is therefore 1. Generally speaking, gradient is just syntactic sugar that will lower down to the adjoint and plug 1 in.
This example is rather trivial, but customizing the adjoint offers incredible flexibility. You can make Zygote work with your own data type, or use mathematical knowledge to aid automatic differentiation.
@lhnguyen-vn’s answer is correct, but let me expand on the last answer a little bit. When you define a custom adjoint, the idea is that you should then be able to plug that custom adjoint into a larger computation. Let’s say for example that we have f(a,b) = add(a,b) and we want to take the gradient of g(f(a,b)), where g is an R->R function. The gradient of g with respect to (a, b) is (dg/df * df/da, dg/df * df/db). When you define the adjoint function for add(a,b), what it needs to compute is (dg/df * df/da, dg/df * df/db), given dg/df as input. So in your example, Δ → (Δ,Δ) is equivalent to Δ → (Δ * 1, Δ * 1), where the 1s are df/da and df/db. In the case that you’re just computing the gradient of f, you just pass Δ=1, because df/df=1. However, if you’re computing the gradient of g, then you need to pass dg/df into your adjoint function for f.
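The chain-rule argument above can be sketched by composing the two backward passes by hand. This is only an illustration: g(x) = x^2 is an outer function I picked arbitrarily, add is assumed to be a + b, and h and the variable names are hypothetical.

```julia
using Zygote
using Zygote: @adjoint

# Assumed definition of add, for illustration only.
add(a, b) = a + b
@adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)

# Hypothetical outer function g: R -> R, and the composition g(f(a, b)).
g(x) = x^2
h(a, b) = g(add(a, b))

a, b = 1.0, 2.0

# Forward pass through f = add, then through g, keeping both backward passes.
f_val, f_back = Zygote.pullback(add, a, b)
g_val, g_back = Zygote.pullback(g, f_val)

# Backward pass through g, seeded with dg/dg = 1, yields dg/df.
(dg_df,) = g_back(1.0)

# dg/df is exactly the Δ handed to the custom adjoint of add.
manual = f_back(dg_df)

# Zygote performs the same composition when differentiating h directly.
automatic = gradient(h, a, b)
```

With these inputs dg/df is 2 * (a + b), so the adjoint of add receives that value as Δ rather than 1, and manual should match automatic.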
Apologies for not prettying up my equations by the way.