Does Flux.jl have an equivalent to rsample in PyTorch that automatically implements these stochastic/policy gradients, so that the reparameterized sample becomes differentiable?
I don’t think so, but @BatyLeo and I have been working on something like that. Maybe it’s time to open source it?
+1 I would also find use for something like this. A package would be nice.
In the meantime maybe StochasticAD.jl can be useful? What do you need this for?
I’m diving into Flux.jl for my research. My objective is to define a loss function that measures, roughly speaking, how close the model is to a uniform data distribution. I generate K random fictitious observations and compare how many of them are smaller than the true data in the training set. In other words, the model generates K simulations and we count the number of simulations in which the generated data is smaller than the actual data. If the model is well trained, the distribution of these counts should converge to a uniform distribution. However, since the resulting histogram is not differentiable, I had to approximate this idea with compositions of continuous, differentiable functions (which I have already done).
The reason I ask is that I believe this introduces a random node into the computation, which is where the need for the reparametrization trick comes from.
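For concreteness, a smooth relaxation of that counting step could look like the sketch below. This is not the poster’s actual code; the helper name soft_rank_count and the sharpness parameter k are assumptions.

using Flux  # for the sigmoid activation

# Hypothetical sketch: replace the non-differentiable indicator
# 1[simulation < observation] with a steep sigmoid, so the count of
# simulations below the true data point becomes differentiable.
soft_rank_count(sims, obs; k = 50) = sum(Flux.sigmoid.(k .* (obs .- sims)))

sims = randn(100)           # K = 100 fictitious observations
obs  = 0.3                  # one true data point
soft_rank_count(sims, obs)  # smooth stand-in for count(sims .< obs)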
Is your model generating a discrete distribution or a continuous one? The trouble here is that not every distribution is easily amenable to reparametrization. There are extensions but they too have limits.
What we are implementing with @BatyLeo is closer to the score function method, which is more generic but also suffers from high variance.
See this paper for a great overview:
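To make the contrast with reparametrization concrete, here is a minimal sketch of the score-function (REINFORCE) estimator for a scalar Gaussian; the objective f and all parameter values are illustrative assumptions, not anything from the package being discussed.

using Statistics

# Score-function (REINFORCE) estimate of d/dμ E_{x ~ N(μ, σ²)}[f(x)]:
# E[f(x) * ∇_μ log p(x)], with ∇_μ log N(x; μ, σ²) = (x - μ) / σ².
f(x) = x^2                        # illustrative objective
μ, σ, K = 1.0, 0.5, 100_000       # assumed parameters and sample count

xs = μ .+ σ .* randn(K)           # ordinary (non-reparameterized) samples
score = (xs .- μ) ./ σ^2          # score function ∇_μ log p(xs)
grad_estimate = mean(f.(xs) .* score)

# For f(x) = x², E[f(x)] = μ² + σ², so the exact gradient is 2μ = 2.0;
# the estimate is unbiased but noisy, which illustrates the variance issue.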
I think the reason Julia lacks rsample is that it is ridiculously simple to implement for the distributions of interest, so there is not a great need for it. The classical Gaussian reparametrization, which covers most uses, is effectively one line of code. I am adding a complete example, but it boils down to m.μ(x) .+ m.σ(x) .* r, which is nicely similar to what is in papers.
using Flux
using Functors

# A small Gaussian "encoder": μ and σ are each neural networks.
struct Model{S,M}
    μ::M
    σ::S
end

@functor Model  # make the struct's parameters visible to Flux/Zygote

# Reparameterized sampling: the noise r ~ N(0, I) does not depend on the
# parameters, so μ(x) .+ σ(x) .* r is differentiable w.r.t. the model.
function (m::Model)(x)
    T = eltype(x)
    r = randn(T, 2, size(x, 2))      # one 2-dimensional noise draw per column of x
    m.μ(x) .+ m.σ(x) .* r
end

m = Model(
    Chain(Dense(2, 2, relu), Dense(2, 2)),           # mean network
    Chain(Dense(2, 2, relu), Dense(2, 2, softplus)), # σ network (softplus keeps it positive)
)

x = randn(Float32, 2, 11)        # a batch of 11 two-dimensional inputs
gradient(m -> sum(m(x)), m)      # gradients flow through the reparameterized sample
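As a hedged usage sketch (assuming a recent Flux version with the explicit optimiser API, and an arbitrary learning rate), the returned gradient can be fed straight into an optimiser:

opt_state = Flux.setup(Adam(1e-3), m)    # optimiser state matching the model
grads = gradient(m -> sum(m(x)), m)      # same gradient call as above
Flux.update!(opt_state, m, grads[1])     # apply the update in place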
For several distributions this is true. Yet PyTorch also supports some distributions where reparametrization is less trivial, and even so, it would be a nice addition that I have missed a few times.
Update: I have put together a little package for differentiating through expectations. It includes both REINFORCE and the reparametrization trick. It is still very experimental and currently being registered. I’d be excited to have your feedback.