The Zygote documentation on custom adjoints states that dy/dx (in the notation from the docs) is the Jacobian when x and y are vectors. I am struggling to relate this to an implementation.
For a function f : R^N -> R^M modeled as f = Chain(Dense(N, H, tanh), Dense(H, M)) with hidden dimension H, I intuitively thought dy/dx would be an MxN matrix. Intuition often fails, though, so I experimented a bit. Running
using Flux, Zygote
N, M, H = 2, 3, 16
g = Chain(Dense(N, H, tanh), Dense(H, M))
x = randn(N, 1)
y, back = Zygote.pullback(g, x)
back([])
gives DimensionMismatch("matrix A has dimensions (16,3), matrix B has dimensions (0,1)"), so an HxM (16x3) matrix is involved, which made me think dy/dx is an MxH matrix. Okay, cool, I guess this is the first matrix multiplication in the first back call. But do I even get an MxN matrix if I run the full backward pass? No: running back(ones(size(y)...)) returns an Nx1 matrix. So how would I compute the “MxN” Jacobian?
I am clearly not thinking about this the right way. Could someone please give me some pointers on this problem? Reading the documentation has unfortunately not made it “click” for me.
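For reference, here is a minimal sketch of one way the MxN Jacobian could be assembled from the pullback. It relies on the fact that back(ȳ) computes the vector-Jacobian product Jᵀȳ rather than J itself, which is why a seed of size M comes back as something of size N. Seeding with the i-th unit vector should therefore yield row i of J. The names J and seed, and the choice of a plain Float32 vector for x instead of an Nx1 matrix, are my own assumptions for illustration and not something taken from the docs.

```julia
using Flux, Zygote

N, M, H = 2, 3, 16
g = Chain(Dense(N, H, tanh), Dense(H, M))

# Assumption: a plain length-N Float32 vector instead of an Nx1 matrix,
# which keeps the shapes simple and matches Flux's default Float32 weights.
x = randn(Float32, N)

y, back = Zygote.pullback(g, x)

# back(seed) returns a 1-tuple holding Jᵀ * seed (a length-N vector).
# With seed equal to the i-th unit vector, that is the i-th row of J.
J = zeros(Float32, M, N)
for i in 1:M
    seed = zeros(Float32, M)
    seed[i] = 1
    J[i, :] = back(seed)[1]
end

size(J)  # (M, N)
```

This needs M backward passes, one per output component. Recent Zygote versions also provide Zygote.jacobian, so Zygote.jacobian(g, x)[1] should give the same matrix there; I have not checked which version introduced it.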