Taking Complex Autodiff Seriously in ChainRules

The pullback of a function in Zygote is a function \mathcal{B}_f(v) = J_f^\dagger * v, where v is a vector, J_f is the Jacobian of the function f, and ^\dagger is the Hermitian conjugate.

Zygote will treat a struct with n fields as if it were a vector of length n, so the complex number 1 + 0im is treated like the vector [1, 0] and 0 + 1im is treated like the vector [0, 1] in the product J_f^\dagger * v above. The vector returned from this product is then converted back into a struct.

That is, in this case \mathcal{B}_f(1) = J_f^\dagger * [1, 0] and \mathcal{B}_f(i) = J_f^\dagger * [0, 1]. The Jacobian here is 2x2 (real and imaginary parts of the output differentiated with respect to real and imaginary parts of the input), so the pullback returns a vector with two elements, which are then interpreted as the real and imaginary parts of a complex number.
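To make the two products above concrete, here is a minimal sketch in plain Julia that does not call Zygote at all. It assumes the example function f(z) = z^2 (my choice, not from the post) and writes its 2x2 real Jacobian out by hand. The helper names `to_vec`, `from_vec`, `jacobian_sq`, and `pullback_sq` are hypothetical and only illustrate the struct-to-vector round trip described here.

```julia
using LinearAlgebra  # stdlib; provides the adjoint (') used below

# Hypothetical helpers (not Zygote API): complex number <-> length-2 real vector.
to_vec(z::Complex) = [real(z), imag(z)]
from_vec(w::AbstractVector) = complex(w[1], w[2])

# Assumed example: f(z) = z^2, i.e. (x, y) -> (x^2 - y^2, 2xy).
# Its 2x2 real Jacobian, written out by hand.
function jacobian_sq(z::Complex)
    x, y = real(z), imag(z)
    return [2x -2y; 2y 2x]
end

# Pullback as described in the post: J_f^† * v, then back to a complex number.
pullback_sq(z::Complex, v::Complex) = from_vec(jacobian_sq(z)' * to_vec(v))

z = 1.0 + 2.0im
b1 = pullback_sq(z, 1.0 + 0.0im)  # seed [1, 0]
bi = pullback_sq(z, 0.0 + 1.0im)  # seed [0, 1]

println(b1)  # 2.0 - 4.0im
println(bi)  # 4.0 + 2.0im
```

For this holomorphic example both results equal conj(f'(z)) * v with f'(z) = 2z, which is what J_f^\dagger * v reduces to when the Cauchy-Riemann equations hold. I would expect `Zygote.pullback(z -> z^2, z)` to give the same numbers, but that is an assumption worth checking against your installed version.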

This is how Zygote deals with all structs, and it’s quite beautiful imo.
