Gumbel Softmax in Julia

Can someone point me to an implementation of Gumbel-softmax in Julia, or to an equivalent of the following PyTorch function?
https://pytorch.org/docs/master/tensors.html#torch.Tensor.scatter_
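In PyTorch Gumbel-softmax code, `scatter_` is commonly used to turn argmax indices into a one-hot tensor (the "hard" sample). Below is a minimal plain-Julia sketch of that one use, not of `scatter_` in general. `onehot_argmax` is a hypothetical name, and the sketch assumes each column of the input holds the scores for one sample, so dimension 1 is the category axis:

```julia
# Hypothetical helper, not from any package: a plain-Julia analogue of
# PyTorch's `zeros_like(x).scatter_(dim, argmax_index, 1.0)`, with the
# category axis along dimension 1 (one sample per column).
function onehot_argmax(x::AbstractMatrix)
    y = zero(x)                                # zeros with the same shape and eltype as x
    for (j, col) in enumerate(eachcol(x))
        y[argmax(col), j] = one(eltype(x))     # put a 1 at the largest entry of column j
    end
    return y
end

x = [0.1 2.0;
     0.5 1.0;
     0.3 0.0]
onehot_argmax(x)   # 1.0 at positions (2, 1) and (1, 2), zeros elsewhere
```

Putting categories along dimension 1 fits Julia's column-major layout, since each sample's scores are then contiguous in memory.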

Isn’t the Gumbel softmax trick just the argmax of the sum of the log-probabilities and samples from the Gumbel distribution?
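Written out, the procedure in that question is usually called the Gumbel-max trick. For category probabilities p_1, …, p_k and independent uniform draws u_i, it gives exact samples from the categorical distribution:

```latex
g_i = -\log(-\log u_i), \quad u_i \sim \mathrm{Uniform}(0, 1), \qquad
z = \arg\max_{i \in \{1, \dots, k\}} \left( \log p_i + g_i \right)
\quad\Longrightarrow\quad \Pr(z = i) = p_i .
```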

You should be able to do something like this:

julia> using Distributions

julia> p = 0.8
0.8

julia> lps = log.([1.0 - p, p])
2-element Array{Float64,1}:
 -1.6094379124341005
 -0.2231435513142097

julia> ixs = map(_ -> argmax(lps .+ rand(Gumbel(), 2)), 1:10^6);

julia> mean(ixs .- 1)
0.799612

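You need softmax in place of argmax. The argmax version is the Gumbel-max trick, which gives exact but discrete samples; Gumbel-softmax replaces the argmax with a temperature-scaled softmax to get a differentiable relaxation. There are some details here: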
https://casmls.github.io/general/2017/02/01/GumbelSoftmax.html
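A minimal sketch of that relaxed sampler, assuming `τ > 0` is the temperature and computing y_i = exp((log p_i + g_i)/τ) / Σ_j exp((log p_j + g_j)/τ). The function names are hypothetical, and the sketch draws Gumbel noise with the standard library instead of Distributions.jl:

```julia
using Random

# Numerically stable softmax of a vector (subtract the max before exponentiating).
function softmax_vec(x::AbstractVector)
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Standard Gumbel(0, 1) draws via the inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
gumbel_noise(rng::AbstractRNG, n::Integer) = -log.(-log.(rand(rng, n)))

# Hypothetical helper: one relaxed (continuous) sample on the probability simplex.
# `logp` are (possibly unnormalized) log-probabilities, `τ > 0` is the temperature.
function gumbel_softmax(rng::AbstractRNG, logp::AbstractVector, τ::Real)
    g = gumbel_noise(rng, length(logp))
    return softmax_vec((logp .+ g) ./ τ)
end

rng = MersenneTwister(0)
lps = log.([0.2, 0.8])
y = gumbel_softmax(rng, lps, 0.5)   # a 2-element vector whose entries sum to 1
```

As `τ` shrinks, the samples approach one-hot vectors. For the same noise `g`, `argmax(y)` coincides with `argmax(lps .+ g)` from the Gumbel-max snippet, because dividing by a positive `τ` and applying softmax both preserve order.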

There are implementations of softmax in Flux, Knet, and NNlib; @Joshua_Bowles also has a straightforward implementation with a nice description here:
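If you would rather not pull in Flux, Knet, or NNlib, a column-wise softmax is only a few lines. The sketch below uses hypothetical names and only the standard library. It applies Gumbel-softmax to a whole matrix of log-probabilities at once and treats each column as one categorical distribution, which as far as I know matches the `dims = 1` default of NNlib's `softmax`:

```julia
using Random

# Column-wise, numerically stable softmax: each column of `x` is normalized separately.
function softmax_cols(x::AbstractMatrix)
    e = exp.(x .- maximum(x; dims = 1))
    return e ./ sum(e; dims = 1)
end

# Hypothetical batched Gumbel-softmax: `logits` is k × n (k categories, n samples),
# `τ > 0` is the temperature. Gumbel noise is drawn independently for every entry.
function gumbel_softmax_cols(rng::AbstractRNG, logits::AbstractMatrix, τ::Real)
    g = -log.(-log.(rand(rng, size(logits)...)))
    return softmax_cols((logits .+ g) ./ τ)
end

rng = MersenneTwister(42)
logits = log.(repeat([0.1, 0.3, 0.6], 1, 4))   # 3 categories, 4 samples
Y = gumbel_softmax_cols(rng, logits, 0.7)      # each column sums to 1
```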

EDIT: Neither of these gives direct access to the log-density; it would be really nice to have it available directly in Distributions.jl.
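In the meantime, the log-density can be written by hand. A minimal sketch, assuming the standard Gumbel-softmax (Concrete) density for a point y in the interior of the simplex with class probabilities p_1, …, p_k and temperature τ, taken with respect to the first k − 1 coordinates: log p(y) = log Γ(k) + (k − 1) log τ − k log Σ_i p_i y_i^(−τ) + Σ_i (log p_i − (τ + 1) log y_i). The function name is hypothetical:

```julia
# Hypothetical helper: log-density of the Gumbel-softmax (Concrete) distribution
# at `y` (interior of the simplex), w.r.t. the first k - 1 coordinates.
# `p` are class probabilities (the formula is unchanged if they are rescaled),
# `τ > 0` is the temperature. Requires Julia ≥ 1.6 for `sum(...; init)`.
function gumbel_softmax_logpdf(y::AbstractVector, p::AbstractVector, τ::Real)
    k = length(y)
    length(p) == k || throw(DimensionMismatch("y and p must have the same length"))
    logp = log.(p)
    logy = log.(y)
    # log Γ(k) = log((k - 1)!) for integer k, summed in log space to avoid overflow
    loggamma_k = sum(log, 1:(k - 1); init = 0.0)
    # log Σ_i p_i y_i^(-τ), computed with a stable log-sum-exp
    a = logp .- τ .* logy
    m = maximum(a)
    lse = m + log(sum(exp.(a .- m)))
    return loggamma_k + (k - 1) * log(τ) - k * lse + sum(logp .- (τ + 1) .* logy)
end

# Sanity check for k = 2, τ = 1: integrate over y₁ ∈ (0, 1) with a midpoint rule.
N = 100_000
ys = ((1:N) .- 0.5) ./ N
total = sum(y1 -> exp(gumbel_softmax_logpdf([y1, 1 - y1], [0.3, 0.7], 1.0)), ys) / N
```

For k = 2 and τ = 1 the density reduces to 0.21 / (0.3 + 0.4 y₁)², which integrates to 1 on (0, 1). For τ < 1 the density diverges at the corners of the simplex, so a naive quadrature check like this one converges much more slowly there.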


Thanks! I got it working!


Great! Is this in a public repo? I’d love to check it out.