Zygote: meaning of @adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)

The Zygote landing page has the example

```julia
@adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)
```

Can someone explain what the Δ -> (Δ, Δ) means? I would have expected the derivatives to be 1,1.

The @adjoint macro is used to define your own backward pass. That is, given the derivative of any L with respect to add(a, b), the @adjoint above defines how to calculate the derivatives of L with respect to a and b.
If you use Zygote’s gradient to take the gradient of add at (1, 2), you get (1, 1), as you pointed out. Zygote simply uses the backward pass you define and plugs in 1 for Δ.
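A minimal hand-rolled sketch in plain Julia (no Zygote required) of what that means. The names `add_adjoint` and `toy_gradient` are hypothetical, chosen for illustration; the sketch only assumes Zygote's convention that an adjoint returns the forward value together with a backward-pass closure, and that the gradient of a scalar output seeds that closure with 1.

```julia
# Hand-rolled sketch of the pattern behind `@adjoint` (no Zygote needed).
# `add_adjoint` and `toy_gradient` are illustrative names, not Zygote API.
add(a, b) = a + b

# Forward value plus a pullback closure, mirroring
# `@adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)`.
# Δ is ∂L/∂add(a, b); the closure returns (∂L/∂a, ∂L/∂b).
add_adjoint(a, b) = (add(a, b), Δ -> (Δ, Δ))

# Toy gradient for a scalar output: seed the pullback with one(y),
# because ∂L/∂L = 1 when L is the output itself.
function toy_gradient(rule, args...)
    y, back = rule(args...)
    return back(one(y))
end

println(toy_gradient(add_adjoint, 1, 2))   # (1, 1)

# Seeding with a different Δ scales both partials by that value.
_, back = add_adjoint(1, 2)
println(back(5))                           # (5, 5)
```

The closure never hard-codes the 1s; it just passes Δ through, so the same rule works whether Δ is 1 or something else supplied by a larger computation.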


I do not understand this yet.

  1. How does it know to plug in 1 for Δ, rather than some other value?

  2. If it already knows the derivative is 1, there is no need to define it using @adjoint!

This is just to explain what I don’t understand - I know that I do not (understand).

  1. Δ -> (Δ, Δ) essentially defines the map $\frac{\partial L}{\partial \text{add}(a, b)} \rightarrow \left( \frac{\partial L}{\partial a}, \frac{\partial L}{\partial b} \right)$. If we want the gradient of add(a, b) with respect to a and b, then L is add(a, b) itself, so Δ is 1. Generally speaking, gradient is just syntactic sugar that calls the adjoint and plugs in 1 for Δ.
  2. Defining the adjoint is how Zygote knows the derivative! Substituting 1 for Δ gives you the derivatives (1, 1).
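Writing out the chain-rule step behind point 1, with y = add(a, b) = a + b and Δ = ∂L/∂y:

$$
\frac{\partial L}{\partial a} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial a} = \Delta \cdot 1 = \Delta,
\qquad
\frac{\partial L}{\partial b} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial b} = \Delta \cdot 1 = \Delta .
$$

When L is the output y itself, Δ = ∂y/∂y = 1, which recovers (1, 1).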

This example is rather trivial, but customizing the adjoint offers incredible flexibility. You can make Zygote work with your own data type, or use mathematical knowledge to aid automatic differentiation.

@lhnguyen-vn’s answer is correct, but let me expand on the last point a little. When you define a custom adjoint, the idea is that you should then be able to plug it into a larger computation. Say we have f(a, b) = add(a, b) and we want to take the gradient of g(f(a, b)), where g is an R -> R function. The gradient of g(f(a, b)) with respect to (a, b) is (dg/df * df/da, dg/df * df/db). So when you define the adjoint for add(a, b), what it needs to compute is (dg/df * df/da, dg/df * df/db), given dg/df as input. In your example, Δ -> (Δ, Δ) is equivalent to Δ -> (Δ * 1, Δ * 1), where the 1s are df/da and df/db. If you’re just computing the gradient of f, you pass Δ = 1, because df/df = 1. If you’re computing the gradient of g, however, you need to pass dg/df into the adjoint for f.
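A hand-rolled Julia sketch of that chaining, assuming the hypothetical choice g(y) = y^2 and using illustrative names (`add_adjoint`, `g_adjoint`, `h_gradient`) rather than Zygote API. The pullback of g turns the seed 1 into dg/df = 2f, and that value, not 1, is what gets passed as Δ into the adjoint for add.

```julia
# Illustrative names only; mimics a (value, pullback) convention by hand.
add(a, b) = a + b
add_adjoint(a, b) = (add(a, b), Δ -> (Δ, Δ))   # Δ = dg/df; returns (dg/da, dg/db)

g(y) = y^2                                       # an example R -> R function
g_adjoint(y) = (g(y), Δ -> (2y * Δ,))            # returns a 1-tuple (dg/dy * Δ,)

# Reverse-mode composition for h(a, b) = g(add(a, b)).
function h_gradient(a, b)
    f, back_add = add_adjoint(a, b)   # forward through add
    _, back_g = g_adjoint(f)          # forward through g
    (dg_df,) = back_g(1)              # seed 1 at the final output: dg/dg = 1
    return back_add(dg_df)            # feed dg/df into add's pullback
end

println(h_gradient(1, 2))   # (6, 6): f = 3, dg/df = 2 * 3 = 6
```

With a = 1 and b = 2, the gradient is (6, 6): the same (1, 1) from add, scaled by the dg/df that flowed in as Δ.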

Apologies for not prettying up my equations by the way.
