Gradient Origin Networks with Zygote/Flux?

Hi! I’d like to train an autoencoder following this paper. Briefly, it first takes a gradient step from a zero latent to get an estimate of the latent vector z, then evaluates the decoder output at z. The difference from predictive/sparse coding is that you keep the gradient of the loss wrt the model parameters from the z-estimation step when evaluating the loss after Decoder(z). In the PyTorch implementation, it looks like this:

    # compute the gradients of the inner loss with respect to zeros (gradient origin)
    z = torch.zeros(batch_size, nz, 1, 1).to(device).requires_grad_()
    g = F(z)  # F is the decoder network
    inner_loss = ((g - x)**2).sum(1).mean()
    grad = torch.autograd.grad(inner_loss, [z], create_graph=True, retain_graph=True)[0]
    z = -grad

    # now with z as our new latent points, optimise the data fitting loss
    g = F(z)
    outer_loss = ((g - x)**2).sum(1).mean()
    optim.zero_grad()
    outer_loss.backward()
    optim.step()

Is it possible to do this with Zygote?

Thanks!

For better or worse, the only way to know is to try. In theory it should be possible to do this kind of thing (GANs with gradient penalties are another example that has come up before) by nesting `gradient`/`pullback` with itself or with another AD, but success depends heavily on which operations the model uses (many are not second-derivative friendly under Zygote-over-Zygote).
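
For what it's worth, here is a minimal sketch of what that nesting could look like with explicit-parameter Flux (`setup`/`update!`). The sizes (`nz`, `ndata`, `batch`), the `decoder` architecture, and the random `x` are placeholders, and whether the outer `gradient` call actually succeeds depends on every operation in the decoder supporting second derivatives under Zygote:

    using Flux, Zygote

    nz, ndata, batch = 32, 784, 64                                # placeholder sizes
    decoder = Chain(Dense(nz => 256, relu), Dense(256 => ndata))  # placeholder decoder
    x = randn(Float32, ndata, batch)                              # stand-in data batch

    function gon_loss(decoder, x)
        # inner step: gradient of the reconstruction loss wrt the zero latent;
        # this call itself must stay differentiable wrt the decoder's parameters
        z0 = zeros(Float32, nz, batch)
        inner(z) = sum(abs2, decoder(z) .- x) / batch
        z = -Zygote.gradient(inner, z0)[1]
        # outer loss: refit the data using the gradient-derived latents
        sum(abs2, decoder(z) .- x) / batch
    end

    opt_state = Flux.setup(Adam(1f-3), decoder)
    loss, grads = Zygote.withgradient(m -> gon_loss(m, x), decoder)
    Flux.update!(opt_state, decoder, grads[1])

If the nested `gradient` call hits an op without a second-derivative rule, the "some other autodiff" route applies: take the inner gradient with a different AD (e.g. ForwardDiff) so each system only needs first derivatives, though that combination has its own composability caveats.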