I’m trying to follow this VQ-VAE tutorial in Julia, particularly the VectorQuantizer module:
# Calculate distances
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self._embedding.weight.t()))
# Encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)
# Quantize and unflatten
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
It looks like Pytorch is able to differentiate through torch.argmin
, while Zygote in my Julia implementation can’t:
emb = Flux.Embedding(args[:emb_dim], args[:num_embeddings]; init=Flux.glorot_uniform) |> gpu
ps = Flux.params(emb)
loss, grad = withgradient(ps) do
distances = sum(flat_input.^ 2, dims=2)' .+ sum(emb.weight .^ 2, dims=2) + 2.0f0 * emb.weight * flat_input'
encoding_indices = argmin(distances, dims=1)
encodings = NNlib.scatter(+, zeros(size(z_flat, 1), num_embeddings)' |> gpu, enc_inds; dstsize=(num_embeddings, size(flat_input, 1)))
sum(encodings)
end
where the grads w.r.t. ps returns nothing
I understand that argmin
isn’t AD-friendly. Is there a way I could implement this vector quantization step differentiably?