Gumbel Softmax in Julia

Can someone please point me to an implementation of the Gumbel softmax in Julia, or to an equivalent of the following function?

Isn’t the Gumbel softmax trick just the argmax of the sum of the log-probabilities and samples from the Gumbel distribution?

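You should be able to do something like this:

```julia
julia> using Distributions, Statistics

julia> p = 0.8;

julia> lps = log.([1.0 - p, p]);  # log-probabilities of outcomes 1 and 2

julia> ixs = map(_ -> argmax(lps .+ rand(Gumbel(), 2)), 1:10^6);  # 10^6 Gumbel-max draws

julia> mean(ixs .- 1)  # fraction of draws that chose outcome 2; should be close to p = 0.8
```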
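For reference, here is a sketch of the two constructions in the usual notation (category probabilities $\pi_1, \dots, \pi_k$, i.i.d. standard Gumbel noise $g_i$, temperature $\tau > 0$; the symbols are chosen here for illustration). The session above samples $z$, which is the Gumbel-max trick and satisfies $\Pr(z = i) = \pi_i$ exactly. The Gumbel softmax swaps the argmax for a temperature-scaled softmax, giving a point $y$ on the probability simplex that approaches the one-hot encoding of $z$ as $\tau \to 0$:

$$
\begin{aligned}
g_i &= -\log(-\log u_i), \qquad u_i \sim \mathrm{Uniform}(0, 1) \\
z &= \arg\max_{i} \, (\log \pi_i + g_i) \\
y_i &= \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j=1}^{k} \exp\big((\log \pi_j + g_j)/\tau\big)}, \qquad i = 1, \dots, k
\end{aligned}
$$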

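That argmax version is the Gumbel-max trick. For the Gumbel softmax you need a softmax with a temperature parameter in place of the argmax, which turns the discrete sample into a continuous relaxation. There are some details here: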

There are implementations of softmax in Flux, Knet, and NNlib, or @Joshua_Bowles has a straightforward implementation with a nice description here:

EDIT: Neither of these gives direct access to the log-density; it would be really nice to have this available directly in Distributions.jl.
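A minimal, dependency-free sketch of both pieces discussed above, assuming plain vectors of log-probabilities. `gumbel_softmax` draws one relaxed sample, and `gumbel_softmax_logpdf` evaluates the closed-form density of the Gumbel-softmax (Concrete) distribution as given by Jang et al. and Maddison et al. (2017). All function names here are hypothetical helpers, not part of Distributions.jl, Flux, or NNlib.

```julia
using Random

# Numerically stable softmax: shifting by the maximum avoids overflow in exp
# and does not change the result.
function stable_softmax(x::AbstractVector{<:Real})
    z = exp.(x .- maximum(x))
    return z ./ sum(z)
end

# Numerically stable log(sum(exp.(x))).
function stable_logsumexp(x::AbstractVector{<:Real})
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# One relaxed (Gumbel-softmax) sample for log-probabilities `logp` at
# temperature `τ`. If E ~ Exponential(1), then -log(E) is a standard Gumbel
# draw (equivalent to -log(-log(u)) with u ~ Uniform(0, 1)).
function gumbel_softmax(rng::AbstractRNG, logp::AbstractVector{<:Real}, τ::Real)
    τ > 0 || throw(ArgumentError("temperature τ must be positive"))
    g = -log.(randexp(rng, length(logp)))
    return stable_softmax((logp .+ g) ./ τ)
end
gumbel_softmax(logp::AbstractVector{<:Real}, τ::Real) =
    gumbel_softmax(Random.default_rng(), logp, τ)

# Log-density of a point `y` in the interior of the probability simplex,
# with respect to Lebesgue measure on its first k - 1 coordinates:
#   log Γ(k) + (k - 1) log τ - k log Σᵢ πᵢ yᵢ^(-τ) + Σᵢ (log πᵢ - (τ + 1) log yᵢ)
# `logp` may be unnormalized: a common shift of all entries cancels out.
function gumbel_softmax_logpdf(y::AbstractVector{<:Real},
                               logp::AbstractVector{<:Real}, τ::Real)
    k = length(y)
    length(logp) == k || throw(DimensionMismatch("y and logp must have equal length"))
    logy = log.(y)
    loggamma_k = sum(log.(1:(k - 1)))  # log Γ(k) = log((k - 1)!) for integer k
    return loggamma_k + (k - 1) * log(τ) -
           k * stable_logsumexp(logp .- τ .* logy) +
           sum(logp .- (τ + 1) .* logy)
end
```

Because the argmax of a relaxed sample is itself a Gumbel-max draw, lowering `τ` pushes samples toward one-hot vectors while larger `τ` makes them smoother. Only the `Random` standard library is used here; inside a Flux model one could swap `stable_softmax` for NNlib's `softmax`.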


Thanks! I got it working!


Great! Is this in a public repo? I’d love to check it out.