Zygote @adjoint with matrices


I would like to use Zygote because it offers the amazing possibility of tagging the parameters you want to optimize and letting Zygote do the rest.
However, in my case I build a matrix (K) from a (potentially highly nested) list of parameters (theta) and pass it to a function (f) that returns a scalar.
I have derived the gradient df/dtheta = g(dK/dtheta) analytically; it is non-linear (and contains a number of optimization tricks), and Zygote works perfectly for differentiating K. My initial solution was to compute dK/dtheta_i for each parameter theta_i and pass the result to df/dtheta, but this is very inefficient and impractical.
Now I want to write an @adjoint that contains g(K), but I have no idea how to go about it, since it is not a Jacobian-vector product anymore…

Here is a simple example with the derivations

How can I write an appropriate adjoint for this?

In Zygote, the pullback, which maps the gradient with respect to a function's output to the gradients with respect to its inputs, is just an arbitrary function, so it does not have to be written as an explicit Jacobian-vector product. It should be as easy as:

```julia
@adjoint f(K) = f(K), J -> (g(J),)
```
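
To make the shape of that pullback concrete: for a scalar-valued f, the pullback receives a scalar Δ (the sensitivity of the final loss with respect to f, which is 1 when f is the final output) and must return Δ times the analytic gradient ∂f/∂K, a matrix of the same size as K. Zygote then propagates that matrix through the code that builds K from the parameters, so implicit Params keep working. The sketch below is a minimal illustration under assumptions: f is a stand-in (a Gaussian negative log-likelihood, whose gradient 0.5(K⁻¹ - ααᵀ) with α = K⁻¹y holds for symmetric K), and the names buildK, nll, xs, y and the parameter values are hypothetical, not taken from the thread.

```julia
using LinearAlgebra, Zygote

# Hypothetical setup: 3 inputs, 3 targets, and a small jitter for numerical stability.
const xs = [0.0, 1.0, 2.0]
const y = [1.0, 2.0, 3.0]
const D2 = (xs .- xs') .^ 2                    # pairwise squared distances
const jitter = 1e-6 * Matrix{Float64}(I, 3, 3)

# Hypothetical K(θ): squared-exponential kernel, θ = [log-lengthscale, log-variance].
buildK(θ) = exp(θ[2]) .* exp.(-D2 ./ (2 * exp(2 * θ[1]))) .+ jitter

# Hypothetical stand-in for f(K): Gaussian negative log-likelihood up to a constant.
nll(K) = 0.5 * dot(y, K \ y) + 0.5 * logdet(K)

# Custom adjoint: the pullback receives the scalar Δ and returns Δ * ∂f/∂K,
# a matrix shaped like K. For symmetric K: ∂f/∂K = 0.5 * (K⁻¹ - α α'), α = K⁻¹ y.
Zygote.@adjoint function nll(K)
    C = cholesky(Symmetric(K))
    α = C \ y
    val = 0.5 * dot(y, α) + 0.5 * logdet(C)
    nll_pullback(Δ) = (Δ .* 0.5 .* (inv(C) .- α * α'),)
    return val, nll_pullback
end

# Implicit parameters: Zygote chains the matrix returned by the pullback
# through buildK to obtain ∂f/∂θ.
θ = [0.5, 0.0]
ps = Zygote.Params([θ])
gs = Zygote.gradient(() -> nll(buildK(θ)), ps)
gs[θ]
```

Recent Zygote versions also use rules defined with ChainRulesCore.rrule, which is the more portable way to express the same pullback.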

The problem is that when I do this, J is a scalar.

You're right: Zygote does reverse-mode differentiation, so the argument to the pullback of f is actually df/df, which is just 1. In your case, I would suggest looking into forward-mode AD with ForwardDiff instead, because it should be much more efficient for differentiating K, and it will make this custom adjoint for f easier to implement.
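
A minimal sketch of the forward-mode route, in the spirit of computing dK/dtheta and then combining it with the analytic gradient: ForwardDiff.jacobian gives the Jacobian of vec(K) with respect to θ in one call, and the chain rule ∂f/∂θᵢ = Σⱼₖ (∂f/∂Kⱼₖ)(∂Kⱼₖ/∂θᵢ) is then a single matrix-vector product. The stand-in f (a Gaussian negative log-likelihood), its analytic gradient dfdK, and the names buildK, xs, y are hypothetical assumptions for illustration.

```julia
using LinearAlgebra, ForwardDiff

# Hypothetical setup: 3 inputs, 3 targets, and a small jitter.
const xs = [0.0, 1.0, 2.0]
const y = [1.0, 2.0, 3.0]
const D2 = (xs .- xs') .^ 2
const jitter = 1e-6 * Matrix{Float64}(I, 3, 3)

# Hypothetical K(θ), θ = [log-lengthscale, log-variance].
buildK(θ) = exp(θ[2]) .* exp.(-D2 ./ (2 * exp(2 * θ[1]))) .+ jitter

# Analytic ∂f/∂K for the stand-in f(K) = 0.5 y'K⁻¹y + 0.5 logdet(K), K symmetric PD.
function dfdK(K)
    C = cholesky(Symmetric(K))
    α = C \ y
    return 0.5 .* (inv(C) .- α * α')
end

# Forward mode: Jacobian of vec(K) w.r.t. θ, of size (length(K), length(θ)),
# contracted with vec(∂f/∂K) via the chain rule.
function dfdθ(θ)
    JK = ForwardDiff.jacobian(t -> vec(buildK(t)), θ)
    G = dfdK(buildK(θ))
    return JK' * vec(G)
end

θ0 = [0.5, 0.0]
dfdθ(θ0)
```

Forward mode costs roughly one pass per parameter, so it pays off when there are few parameters relative to the entries of K.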

The only problem is that I need Zygote's implicit-parameter style of differentiation 😅
I need to rely on Zygote.params.

You can use ForwardDiff within Zygote via the function forwarddiff. See also the Zygote docs.
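
A minimal sketch of mixing the two modes: Zygote.forwarddiff(f, x) returns f(x), but Zygote differentiates that call with ForwardDiff. Its documentation notes that gradients of values closed over by f are dropped, so the parameters are passed explicitly as x here, which is presumably why it does not combine easily with parameters collected via Params. The stand-ins buildK, nll, xs and y are hypothetical.

```julia
using LinearAlgebra, Zygote

# Hypothetical setup: 3 inputs, 3 targets, and a small jitter.
const xs = [0.0, 1.0, 2.0]
const y = [1.0, 2.0, 3.0]
const D2 = (xs .- xs') .^ 2
const jitter = 1e-6 * Matrix{Float64}(I, 3, 3)

# Hypothetical K(θ) and stand-in scalar f(K).
buildK(θ) = exp(θ[2]) .* exp.(-D2 ./ (2 * exp(2 * θ[1]))) .+ jitter
nll(K) = 0.5 * dot(y, K \ y) + 0.5 * logdet(K)

# K(θ) is differentiated in forward mode (ForwardDiff), f(K) in reverse mode.
# θ is forwarddiff's explicit argument; closed-over values would get no gradient.
loss(θ) = nll(Zygote.forwarddiff(buildK, θ))

θ0 = [0.5, 0.0]
g = Zygote.gradient(loss, θ0)[1]
```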

Thanks, but this is not compatible with the Params approach 🙂

I finally found a workaround!
Since my concern was optimizing the gradient computations (avoiding precomputed inverses, etc.), I simply wrote a new function whose gradient is equivalent to the one I want!
It is slightly less efficient, but it works pretty well for now!
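
The thread does not show the actual replacement function. One generic way to build a function whose gradient equals a known analytic one is a linear surrogate: with G = ∂f/∂K computed outside the AD call and held constant, s(θ) = Σⱼₖ Gⱼₖ Kⱼₖ(θ) has ∂s/∂θ = ∂f/∂θ at the current θ, even though s(θ) ≠ f(K(θ)). The sketch below assumes a hypothetical stand-in f (a Gaussian negative log-likelihood) with its analytic gradient dfdK, and hypothetical names buildK, xs, y; it keeps the implicit Params style.

```julia
using LinearAlgebra, Zygote

# Hypothetical setup: 3 inputs, 3 targets, and a small jitter.
const xs = [0.0, 1.0, 2.0]
const y = [1.0, 2.0, 3.0]
const D2 = (xs .- xs') .^ 2
const jitter = 1e-6 * Matrix{Float64}(I, 3, 3)

# Hypothetical K(θ), θ = [log-lengthscale, log-variance].
buildK(θ) = exp(θ[2]) .* exp.(-D2 ./ (2 * exp(2 * θ[1]))) .+ jitter

# Analytic ∂f/∂K for the stand-in f(K) = 0.5 y'K⁻¹y + 0.5 logdet(K), K symmetric PD.
function dfdK(K)
    C = cholesky(Symmetric(K))
    α = C \ y
    return 0.5 .* (inv(C) .- α * α')
end

θ = [0.5, 0.0]
ps = Zygote.Params([θ])

# G is evaluated outside the AD call, so Zygote treats it as a constant.
G = dfdK(buildK(θ))

# Surrogate Σⱼₖ Gⱼₖ Kⱼₖ(θ): different value from f, same gradient w.r.t. θ.
gs = Zygote.gradient(() -> sum(G .* buildK(θ)), ps)
gs[θ]
```

G has to be recomputed whenever θ changes, which costs one extra evaluation of K per step.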