Differentiable argmin? Trying VQ-VAE in Flux.jl

I’m trying to follow this VQ-VAE tutorial in Julia, particularly the VectorQuantizer module:

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
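
(For reference, the distances line is the all-pairs expansion of the squared Euclidean distance, ‖x - e‖² = ‖x‖² + ‖e‖² - 2⟨x, e⟩, between every input vector x and every codebook vector e; note the minus sign on the cross term.)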

It looks like PyTorch is able to differentiate through torch.argmin, while Zygote can't in my Julia implementation:

emb = Flux.Embedding(args[:emb_dim], args[:num_embeddings]; init=Flux.glorot_uniform) |> gpu
ps = Flux.params(emb)

loss, grad = withgradient(ps) do
    distances = sum(flat_input .^ 2, dims=2)' .+ sum(emb.weight .^ 2, dims=2) .-
                2.0f0 * emb.weight * flat_input'

    encoding_indices = argmin(distances, dims=1)

    encodings = NNlib.scatter(+, zeros(size(flat_input, 1), num_embeddings)' |> gpu, encoding_indices;
                              dstsize=(num_embeddings, size(flat_input, 1)))

    sum(encodings)
end

where the gradients w.r.t. ps come back as nothing.

I understand that argmin isn’t AD-friendly. Is there a way I could implement this vector quantization step differentiably?
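
For reference, here is a stripped-down CPU version of what I'm seeing, with made-up sizes and no encoder. I hide the discrete index computation from AD explicitly, since it isn't differentiable anyway:

using Flux, NNlib, Zygote

D, K, N = 4, 8, 16                  # embedding dim, codebook size, batch size (made up)
emb = Flux.Embedding(D, K)          # emb.weight is K×D: one codebook vector per row
flat_input = randn(Float32, N, D)   # stand-in for the flattened encoder output
ps = Flux.params(emb)

loss, grad = withgradient(ps) do
    distances = sum(flat_input .^ 2, dims=2)' .+ sum(emb.weight .^ 2, dims=2) .-
                2.0f0 .* (emb.weight * flat_input')
    # (codebook row, sample index) pairs of the nearest code, hidden from AD
    idx = Zygote.@ignore map(c -> (c[1], c[2]), vec(argmin(distances, dims=1)))
    # one-hot encodings, K×N
    encodings = NNlib.scatter(+, ones(Float32, N), idx; dstsize=(K, N))
    sum(encodings)                  # the loss never touches emb.weight again
end

grad[emb.weight]                    # nothing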

argmin is not the problem here. scatter is non-differentiable w.r.t. encoding_indices (see scatter.jl at commit c0b4b8b6e969422ff4af18b473d02192b27c9cf4 in FluxML/NNlib.jl on GitHub): the gradient below comes back for src but is nothing for idx.

julia> using Zygote, NNlib

julia> gradient((src, idx) -> sum(NNlib.scatter(+, src, idx)), [10, 100], [1, 3])
([1.0, 1.0], nothing)

Isn’t the issue rather that the PyTorch example does a further matmul with self._embedding.weight (which propagates a gradient signal back to the codebook), whereas the Julia one does not? I wouldn’t think PyTorch’s scatter is differentiable w.r.t. indices either.
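
As a quick check, here is a sketch reusing the stripped-down example from the question, routing the loss through that extra matmul (the tutorial's codebook/commitment terms are reduced to a plain squared error for illustration):

loss, grad = withgradient(ps) do
    distances = sum(flat_input .^ 2, dims=2)' .+ sum(emb.weight .^ 2, dims=2) .-
                2.0f0 .* (emb.weight * flat_input')
    idx = Zygote.@ignore map(c -> (c[1], c[2]), vec(argmin(distances, dims=1)))
    encodings = NNlib.scatter(+, ones(Float32, N), idx; dstsize=(K, N))
    # the step the Julia version was missing: multiply with the codebook again,
    # like torch.matmul(encodings, self._embedding.weight) in the tutorial
    quantized = encodings' * emb.weight     # N×D
    sum(abs2, quantized .- flat_input)      # stand-in for the codebook loss term
end

grad[emb.weight]                            # now a K×D array instead of nothing

This only trains the codebook, of course. To also push gradients back into an encoder, the tutorial uses the straight-through line quantized = inputs + (quantized - inputs).detach(), which in Zygote would be something like quantized = flat_input + Zygote.@ignore(quantized - flat_input).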

Yes, I think so too. I did not mean to say the scatter implementation is incorrect, just that argmin was not causing the problem.

Thank you, I missed that!