Reparametrization trick in Flux.jl

Does Flux.jl have an equivalent to rsample in PyTorch, which automatically implements these stochastic/policy gradients so that the reparameterized sample becomes differentiable?

1 Like

I don’t think so, but @BatyLeo and I have been working on something like that. Maybe it’s time to open source it?

3 Likes

+1 I would also find use for something like this. A package would be nice.

1 Like

In the meantime, maybe StochasticAD.jl could be useful? What do you need this for?
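
If I remember the StochasticAD.jl README correctly, basic usage looks roughly like this (the Bernoulli example is purely illustrative, and the exact function name is from memory):

using StochasticAD, Distributions

# A function of the parameter p that involves discrete randomness
f(p) = rand(Bernoulli(p))

# One unbiased estimate of d/dp E[f(p)] (= 1 here); average many of these to reduce variance
derivative_estimate(f, 0.5)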

2 Likes

I’m diving into Flux.jl for my research. My objective is to define a loss function that measures the convergence of the model towards a uniform distribution, so to speak. I generate K random fictitious observations and count how many of them are smaller than the true observation in the training set. In other words, the model generates K simulations, and we determine the number of simulations in which the generated data is smaller than the actual data. If the model is well trained, this distribution should converge to a uniform distribution. However, since the resulting histogram is not differentiable, I needed to approximate this idea using compositions of continuous, differentiable functions (which I have done).
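
To illustrate the kind of smoothing I mean, here is a rough sketch (not my exact code; the sigmoid and the temperature τ are just one possible choice):

using Flux

# Soft rank: replace the indicator 1[ŷ < y] with a sigmoid so that counting how many
# of the K simulated values fall below the observation y becomes differentiable.
# τ controls how sharp the approximation is.
softrank(ŷ, y; τ = 0.1f0) = sum(sigmoid.((y .- ŷ) ./ τ))

ŷ = randn(Float32, 100)    # K = 100 simulated values
softrank(ŷ, 0.0f0)         # ≈ number of simulations smaller than y = 0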

The reason I ask is that I believe this creates a random node in the computation, which is where the need for the reparametrization trick arises.

1 Like

Is your model generating a discrete distribution or a continuous one? The trouble here is that not every distribution is easily amenable to reparametrization. There are extensions, but they too have their limits.
What @BatyLeo and I are implementing is closer to the score-function method, which is more generic but also suffers from high variance.
See this paper for a great overview:
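
To give a flavour of the score-function estimator, here is a minimal sketch for the mean of a Gaussian (the objective f and the constants are purely illustrative):

using Distributions, Statistics

f(x) = (x - 3)^2                      # objective whose expectation we differentiate

# ∇_μ E[f(x)] = E[f(x) * ∂/∂μ log p(x; μ, σ)], and for a Gaussian ∂/∂μ log p = (x - μ) / σ^2
function score_grad(μ, σ; K = 100_000)
	xs = rand(Normal(μ, σ), K)        # the samples themselves need not be differentiable
	mean(f.(xs) .* (xs .- μ) ./ σ^2)
end

score_grad(0.0, 1.0)                  # ≈ d/dμ E[(x - 3)^2] = 2(μ - 3) = -6, up to Monte-Carlo noise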

I think the reason Julia lacks rsample is that it is ridiculously simple to implement for the distributions of interest, so there is not a great need for it. The classical Gaussian reparametrization, which covers most use cases, is effectively one line of code. I am adding a complete example below, but the core of it is m.μ(x) .+ m.σ(x) .* r, which maps nicely onto what you see in papers.

using Flux
using Functors


struct Model{S,M}
	μ::M 
	σ::S
end

@functor Model

function (m::Model)(x)
	T = eltype(x)
	r = randn(T, 2, size(x,2))   # noise drawn independently of the parameters
	m.μ(x) .+ m.σ(x) .* r        # reparametrized sample, differentiable w.r.t. μ and σ
end

m = Model( 
	Chain(Dense(2,2,relu), Dense(2,2)),
	Chain(Dense(2,2,relu), Dense(2,2,softplus)),
	)

x = randn(Float32, 2, 11)

gradient(m -> sum(m(x)), m)
4 Likes

For several distributions this is true. Yet torch also supports rsample for some distributions that are less trivial, and even so, it would be a nice addition that I have missed a few times.
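
For what it's worth, when the quantile function is available in closed form you can still reparametrize by hand via inverse-transform sampling. A rough sketch for an Exponential(λ) (the rate and sizes are just examples):

using Flux

# Inverse-CDF reparametrization of Exponential(λ): u ~ Uniform(0,1) is parameter-free noise,
# and x = -log(1 - u) / λ is differentiable with respect to λ.
rsample_exponential(λ, n) = -log.(1 .- rand(Float32, n)) ./ λ

gradient(λ -> sum(rsample_exponential(λ, 10)), 2.0f0)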

1 Like