Can someone please point me to an implementation of Gumbel softmax in Julia, or to an equivalent of the following PyTorch function?

https://pytorch.org/docs/master/tensors.html#torch.Tensor.scatter_
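In PyTorch Gumbel softmax code, `scatter_` is typically used to turn the argmax into a hard one-hot vector. I don't know of a Julia built-in with exactly that signature, so here is a minimal sketch of a hypothetical helper, `scatter_value!`. It covers only the matrix-with-a-scalar-value case and uses Julia's 1-based `dim` and indices:

```julia
# Hypothetical helper (not in Base): rough analogue of PyTorch's
# `Tensor.scatter_(dim, index, value)` for matrices and a scalar value.
# `dim` is 1 (rows) or 2 (columns); `index` holds 1-based indices.
function scatter_value!(dest::AbstractMatrix, dim::Integer,
                        index::AbstractMatrix{<:Integer}, value)
    dim in (1, 2) || throw(ArgumentError("dim must be 1 or 2"))
    for j in axes(index, 2), i in axes(index, 1)
        if dim == 1
            dest[index[i, j], j] = value   # like self[index[i][j]][j] = value
        else
            dest[i, index[i, j]] = value   # like self[i][index[i][j]] = value
        end
    end
    return dest
end

# Typical Gumbel softmax use: one-hot encode the argmax of each row.
y = [0.1 0.7 0.2;
     0.5 0.3 0.2]
idx = reshape([argmax(y[i, :]) for i in axes(y, 1)], :, 1)  # 2×1 matrix of column indices
onehot = scatter_value!(zeros(size(y)), 2, idx, 1.0)
```

For the one-hot case alone, broadcasting is often enough: `y .== maximum(y; dims=2)` gives a `Bool` matrix with `true` at each row's maximum (ties would give more than one `true` per row).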

# Gumbel Softmax in Julia

Isn’t the Gumbel softmax trick just taking the argmax of the log-probabilities plus independent samples from the standard Gumbel distribution?

You should be able to do something like this:

```
julia> using Distributions, Statistics
julia> p = 0.8
0.8
julia> lps = log.([1.0 - p, p])
2-element Array{Float64,1}:
-1.6094379124341005
-0.2231435513142097
julia> ixs = map(_ -> argmax(lps .+ rand(Gumbel(), 2)), 1:10^6);
julia> mean(ixs .- 1)
0.799612
```
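If you'd rather not depend on Distributions.jl, a standard Gumbel draw can be produced by inverse-CDF sampling, `-log(-log(u))` with `u` uniform. A minimal sketch of the same experiment using only the standard library (the names `gumbel` and `gumbel_max` are just illustrative):

```julia
using Random, Statistics

# Standard Gumbel(0, 1) draw by inverse-CDF sampling: G = -log(-log(U)), U = rand(rng).
gumbel(rng::AbstractRNG) = -log(-log(rand(rng)))

# Gumbel-max trick: argmax over (log p_i + G_i) is a draw from Categorical(p).
gumbel_max(rng::AbstractRNG, lps::AbstractVector) = argmax(lps .+ [gumbel(rng) for _ in lps])

rng = MersenneTwister(1)
p = 0.8
lps = log.([1.0 - p, p])
ixs = [gumbel_max(rng, lps) for _ in 1:10^6]
freq = mean(ixs .- 1)   # fraction of draws that picked index 2; should be close to p
```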

You need a *softmax* in place of the *argmax* (the argmax version is the Gumbel-max trick). There are some details here:

https://casmls.github.io/general/2017/02/01/GumbelSoftmax.html
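Concretely, the relaxed sample replaces the argmax with a softmax at temperature τ: `y = softmax((log p + g) / τ)`, with the `g_i` i.i.d. standard Gumbel. A minimal stdlib-only sketch, assuming `logp` holds log-probabilities (the function names are hypothetical):

```julia
using Random, Statistics

# Numerically stable softmax: subtract the maximum before exponentiating.
function softmax_stable(x::AbstractVector{<:Real})
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# One relaxed Gumbel softmax sample: y = softmax((logp + g) / τ), g_i ~ Gumbel(0, 1).
# Smaller τ pushes y closer to a one-hot vector.
function gumbel_softmax(rng::AbstractRNG, logp::AbstractVector{<:Real}, τ::Real)
    g = [-log(-log(rand(rng))) for _ in logp]
    return softmax_stable((logp .+ g) ./ τ)
end

rng = MersenneTwister(42)
logp = log.([0.2, 0.8])
y = gumbel_softmax(rng, logp, 0.5)   # a point on the probability simplex

# Sanity check: the argmax of each relaxed sample is a Gumbel-max draw,
# so index 2 should be picked roughly 80% of the time.
frac2 = mean(argmax(gumbel_softmax(rng, logp, 0.5)) == 2 for _ in 1:10^5)
```

Because softmax preserves the ordering of its inputs, `argmax(y)` is still an exact categorical draw; the temperature only controls how close `y` is to one-hot.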

There are implementations of softmax in Flux, Knet, and NNlib; @Joshua_Bowles also has a straightforward implementation with a nice description here:

EDIT: Neither of these gives direct access to the log-density; it would be really nice to have this available directly in Distributions.jl.
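In the meantime the log-density is not hard to write down by hand. This sketch follows the density given for the Gumbel-Softmax (Concrete) distribution by Jang et al. and Maddison et al., `log p(y) = log Γ(k) + (k-1) log τ - k log Σ_i π_i y_i^(-τ) + Σ_i (log π_i - (τ+1) log y_i)`. `gumbel_softmax_logpdf` is a hypothetical standalone helper, not a Distributions.jl API:

```julia
# Hypothetical helper: log-density of the Gumbel-Softmax / Concrete distribution
# with class weights π (passed as logp = log.(π)) and temperature τ, evaluated at
# a point y in the interior of the probability simplex.
function gumbel_softmax_logpdf(y::AbstractVector{<:Real}, logp::AbstractVector{<:Real}, τ::Real)
    k = length(y)
    length(logp) == k || throw(DimensionMismatch("y and logp must have the same length"))
    logy = log.(y)
    # log Γ(k) = log((k-1)!), computed as a sum of logs to avoid extra packages
    loggamma_k = sum(log, 1:k-1; init=0.0)
    # log Σ_i π_i y_i^(-τ), evaluated with the log-sum-exp trick
    a = logp .- τ .* logy
    m = maximum(a)
    lse = m + log(sum(exp.(a .- m)))
    return loggamma_k + (k - 1) * log(τ) - k * lse + sum(logp .- (τ + 1) .* logy)
end
```

A quick check: for k = 2, π = (1, 1) and τ = 1 the density is uniform on the simplex, so `gumbel_softmax_logpdf([0.3, 0.7], [0.0, 0.0], 1.0)` should be 0 up to rounding.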


Thanks! I got it working!


Great! Is this in a public repo? I’d love to check it out.