Differentiable argmin? Trying VQ-VAE in Flux.jl

Isn’t the issue rather that the PyTorch example does a further matmul with `self._embedding.weight` (which propagates a gradient signal back to the codebook), whereas the Julia one does not? I wouldn’t think PyTorch’s `scatter` is differentiable with respect to indices either.
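
To make the point concrete, here is a minimal PyTorch sketch (not the original example’s code; the codebook size, dimensions, and variable names are hypothetical) showing that the argmin indices carry no gradient, and that the codebook only receives gradients through the subsequent matmul with the embedding weights:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embedding = torch.nn.Embedding(8, 4)   # hypothetical codebook: 8 codes of dim 4
z = torch.randn(2, 4)                  # stand-in for encoder outputs

# Nearest-code lookup: argmin is a discrete choice, so the indices
# are integer-valued and carry no gradient.
dists = torch.cdist(z, embedding.weight)
indices = dists.argmin(dim=1)
assert not indices.requires_grad       # no gradient path through the indices

# Building one-hot encodings from the indices is likewise non-differentiable
# with respect to the indices themselves.
encodings = F.one_hot(indices, num_classes=8).float()

# The matmul with embedding.weight is what lets gradients flow back
# into the codebook (this is the step the Julia version was missing).
quantized = encodings @ embedding.weight

loss = quantized.pow(2).sum()
loss.backward()
assert embedding.weight.grad is not None   # codebook gets gradients via the matmul
```

Only the rows of the codebook that were actually selected by `argmin` receive nonzero gradient, which is exactly the behavior the matmul-with-one-hot formulation is there to produce.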
