Hi! I’d like to train an autoencoder according to this paper. Briefly, it requires first taking a gradient step to get an estimate of the latent vector z, and then evaluating the output of the decoder at that z. The difference between this and predictive/sparse coding is that you keep the gradient of the loss with respect to the model parameters from the z-estimation step when evaluating the loss after Decoder(z). In the PyTorch implementation, it looks like this:
import torch

# F is the decoder network; x, batch_size, nz, device and optim come from the surrounding training script
# compute the gradient of the inner loss with respect to zeros (the gradient origin)
z = torch.zeros(batch_size, nz, 1, 1).to(device).requires_grad_()
g = F(z)
inner_loss = ((g - x)**2).sum(1).mean()
grad = torch.autograd.grad(inner_loss, [z], create_graph=True, retain_graph=True)[0]
z = -grad

# now, with z as our new latent points, optimise the data-fitting loss
g = F(z)
outer_loss = ((g - x)**2).sum(1).mean()
optim.zero_grad()
outer_loss.backward()
optim.step()
Is it possible to do this with Zygote?
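For concreteness, here is the rough shape of what I imagine the Flux/Zygote version would look like. This is an untested sketch for a dense decoder; decoder, x, nz and opt are placeholders for my own setup, and the nested Zygote.gradient call is exactly the part I'm unsure about (i.e. whether it plays the role of create_graph=True and keeps the dependence on the model parameters):

using Flux, Zygote

# decoder: a Flux model mapping nz-dimensional latents to reconstructions
# x: a batch of data (features × batch), opt: a Flux optimiser, nz: latent size
function train_step!(decoder, x, opt, nz)
    ps = Flux.params(decoder)
    gs = Zygote.gradient(ps) do
        # gradient origin: start the latents at zero
        z0 = zeros(Float32, nz, size(x, 2))
        # inner reconstruction loss as a function of the latents
        inner(z) = sum(abs2, decoder(z) .- x) / size(x, 2)
        # one gradient step from the origin gives the latent estimate;
        # this inner gradient would need to stay differentiable w.r.t. ps
        z = -Zygote.gradient(inner, z0)[1]
        # outer data-fitting loss with the estimated latents
        sum(abs2, decoder(z) .- x) / size(x, 2)
    end
    Flux.Optimise.update!(opt, ps, gs)
end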
Thanks!