Can someone please point me to an implementation of Gumbel softmax in Julia, or to an equivalent of the following function?
https://pytorch.org/docs/master/tensors.html#torch.Tensor.scatter_
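In PyTorch's `torch.nn.functional.gumbel_softmax`, the `hard=True` path uses `scatter_` to write a 1 at the argmax of each sample, giving a hard one-hot tensor. A minimal Julia sketch of that use follows. It assumes samples are stored one per column with classes along the first dimension (Julia is column-major), and `scatter_dim1!` is a hypothetical helper, not part of any package, that uses 1-based indices:

```julia
# Hypothetical helper, not from any library: a Julia analogue of PyTorch's
# `dest.scatter_(0, index, src)` for matrices, using 1-based indices:
#     dest[index[i, j], j] = src[i, j]   (or = src, if src is a scalar)
function scatter_dim1!(dest::AbstractMatrix, index::AbstractMatrix{<:Integer}, src)
    for j in axes(index, 2), i in axes(index, 1)
        dest[index[i, j], j] = src isa AbstractMatrix ? src[i, j] : src
    end
    return dest
end

# Typical Gumbel-softmax use: turn each column (one sample, classes along
# dimension 1) into a hard one-hot vector at its argmax.
y = [0.1 0.7; 0.6 0.2; 0.3 0.1]
idx = permutedims([argmax(view(y, :, j)) for j in axes(y, 2)])  # 1×2 matrix of row indices
y_hard = scatter_dim1!(zero(y), idx, 1.0)
# y_hard == [0.0 1.0; 1.0 0.0; 0.0 0.0]
```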
Isn’t the Gumbel softmax trick just the argmax of the sum of the log-probabilities and samples from the Gumbel distribution?
You should be able to do something like
```julia
julia> using Distributions

julia> p = 0.8
0.8

julia> lps = log.([1.0 - p, p])
2-element Array{Float64,1}:
 -1.6094379124341005
 -0.2231435513142097

julia> ixs = map(_ -> argmax(lps .+ rand(Gumbel(), 2)), 1:10^6);

julia> mean(ixs .- 1)
0.799612
```
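For reference, the same Gumbel-max experiment can be run without Distributions.jl by sampling standard Gumbel noise through its inverse CDF, G = -log(-log(U)) with U uniform on (0, 1). This is a sketch; the helper names `gumbel_noise` and `sample_gumbel_max` are made up for illustration:

```julia
using Random, Statistics

# Standard Gumbel(0, 1) noise via the inverse CDF: G = -log(-log(U)), U ~ Uniform(0, 1).
gumbel_noise(rng, n) = -log.(-log.(rand(rng, n)))

# Gumbel-max trick: argmax(log p + G) is a draw from Categorical(p).
sample_gumbel_max(rng, logp) = argmax(logp .+ gumbel_noise(rng, length(logp)))

rng = MersenneTwister(1)
p = 0.8
lps = log.([1 - p, p])
ixs = [sample_gumbel_max(rng, lps) for _ in 1:10^6]
mean(ixs .- 1)   # should be close to p = 0.8
```

The fixed `MersenneTwister` seed only makes the run reproducible.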
That is the Gumbel-max trick; for Gumbel softmax you need a softmax (with a temperature) in place of the argmax. There are some details here:
There are implementations of softmax in Flux, Knet, and NNlib, and @Joshua_Bowles has a straightforward implementation with a nice description here:
EDIT: Neither of these gives direct access to the log-density; it would be really nice to have this available directly in Distributions.jl.
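A minimal sketch of the Gumbel-softmax (Concrete) relaxation in plain Julia, with no package dependencies: the sampler replaces the argmax with a softmax at temperature τ, and `concrete_logpdf` evaluates the log-density given in the Gumbel-softmax and Concrete distribution papers. All function names here are hypothetical, `logα` holds (possibly unnormalized) log-probabilities, and the density is with respect to Lebesgue measure on the first k - 1 coordinates of the simplex:

```julia
using Random

# Numerically stable softmax: subtracting the maximum avoids overflow in exp.
function softmax(x::AbstractVector{<:Real})
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Numerically stable log(sum(exp.(x))).
function logsumexp(x::AbstractVector{<:Real})
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# One Gumbel-softmax (Concrete) sample: softmax((log α + G) / τ), with G ~ Gumbel(0, 1).
# As τ → 0 the sample approaches the one-hot vector picked by the Gumbel-max trick.
function gumbel_softmax(rng::AbstractRNG, logα::AbstractVector{<:Real}, τ::Real)
    g = -log.(-log.(rand(rng, length(logα))))
    return softmax((logα .+ g) ./ τ)
end

# Log-density of the Concrete / Gumbel-softmax distribution on the simplex:
#   log p(y) = log((k-1)!) + (k-1) log τ - k * log(Σ_i α_i y_i^(-τ))
#              + Σ_i (log α_i - (τ + 1) log y_i)
function concrete_logpdf(y::AbstractVector{<:Real}, logα::AbstractVector{<:Real}, τ::Real)
    k = length(y)
    logy = log.(y)
    return sum(log, 1:k-1; init = 0.0) + (k - 1) * log(τ) -
           k * logsumexp(logα .- τ .* logy) +
           sum(logα .- (τ + 1) .* logy)
end

rng = MersenneTwister(0)
logα = log.([0.2, 0.8])
y = gumbel_softmax(rng, logα, 0.5)   # a point on the simplex, sum(y) ≈ 1
concrete_logpdf(y, logα, 0.5)
```

If you already use Flux or NNlib, their `softmax` can replace the local one; it is defined here only to keep the example self-contained. As a sanity check, with k = 2, α = (1, 1) and τ = 1 the density reduces to the uniform density on (0, 1), so the log-density is 0.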
Thanks! I got it working!
Great! Is this in a public repo? I’d love to check it out.