Can someone please point me to an implementation of Gumbel softmax in Julia, or an equivalent of the following function?
https://pytorch.org/docs/master/tensors.html#torch.Tensor.scatter_
Isn’t the Gumbel softmax trick just the argmax of the sum of the log-probabilities and samples from the Gumbel distribution?
You should be able to do something like
julia> using Distributions
julia> p = 0.8
0.8
julia> lps = log.([1.0 - p, p])
2-element Array{Float64,1}:
-1.6094379124341005
-0.2231435513142097
julia> ixs = map(_ -> argmax(lps .+ rand(Gumbel(), 2)), 1:10^6);
julia> mean(ixs .- 1)
0.799612
You need softmax in place of argmax. There are some details here:
There are implementations of softmax in Flux, Knet, or NNlib, or @Joshua_Bowles has a straightforward implementation with a nice description here:
EDIT: Neither of these give direct access to the log-density; it would be really nice to have this directly available in Distributions.jl
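Putting the two replies together, a minimal sketch of the relaxed (softmax) version might look like the following. This is my own illustration of the standard Gumbel-softmax formulation with a temperature parameter τ, not code from any of the packages mentioned; the function names are hypothetical, and softmax is defined inline to keep the snippet self-contained.

```julia
using Distributions

# Numerically stable softmax (inline so we don't depend on NNlib here).
softmax(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

# Gumbel-softmax sample: add i.i.d. Gumbel(0, 1) noise to the logits,
# then take a temperature-scaled softmax. As τ → 0 this approaches a
# one-hot sample; larger τ gives a smoother (more uniform) vector.
function gumbel_softmax(logits, τ = 1.0)
    g = rand(Gumbel(), length(logits))   # standard Gumbel noise
    softmax((logits .+ g) ./ τ)
end

p = 0.8
lps = log.([1.0 - p, p])
y = gumbel_softmax(lps, 0.5)   # a relaxed one-hot vector summing to 1
```

Replacing `softmax` with `argmax` in `gumbel_softmax` recovers the hard (Gumbel-max) sampler from the REPL example above.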
Thanks! I got it working!
Great! Is this in a public repo? I’d love to check it out