Does `Flux.jl` have an equivalent to `rsample` in `PyTorch` that automatically implements these stochastic/policy gradients, so that the reparameterized sample becomes differentiable?

I don’t think so, but @BatyLeo and I have been working on something like that. Maybe it’s time to open source it?

+1, I would also find something like this useful. A package would be nice.

I’m diving into Flux.jl for my research. My objective is to define a loss function that measures, so to speak, how close the model is to producing a uniform distribution. For each data point in the training set, the model generates K random fictitious observations (simulations), and I count how many of them are smaller than the true value. If the model is well trained, the distribution of this count should converge to a uniform distribution. However, since the resulting histogram is not differentiable, I had to approximate the idea with compositions of continuous, differentiable functions, which I have done.

I am asking because I believe the sampling step creates a random (stochastic) node in the computation graph, which is why the reparametrization trick is needed.
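As an illustration of the kind of smoothing described above (not the poster's actual construction, which is not shown), a minimal sketch: each indicator 1{s < y} is replaced by a sigmoid with an assumed temperature `τ`. The helper `soft_rank` and the toy values are hypothetical.

```julia
using Flux  # provides sigmoid and gradient (via Zygote)

# Hypothetical helper: smooth surrogate for the number of simulated values
# in `sims` that are smaller than the observation `y`. Each indicator 1{s < y}
# is replaced by sigmoid((y - s) / τ); τ is an assumed temperature, and as
# τ → 0 the sigmoid approaches the step function.
soft_rank(sims, y; τ = 0.1f0) = sum(sigmoid.((y .- sims) ./ τ))

sims = Float32[-1.0, 0.5, 2.0, 3.0]   # K = 4 simulated values (made up)
y = 1.0f0                              # one observed data point (made up)

hard = count(<(y), sims)               # exact, piecewise-constant rank
soft = soft_rank(sims, y)              # differentiable approximation of `hard`
g = gradient(s -> soft_rank(s, y), sims)[1]  # nonzero w.r.t. the simulations
```

For gradients to reach the model parameters, `sims` must itself be a differentiable function of those parameters, which is exactly what a reparameterized sampler provides. How the soft counts are then compared with a uniform distribution is not described in the post, so that part is left out.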

Is your model generating a discrete or a continuous distribution? The trouble is that not every distribution is easily amenable to reparametrization. There are extensions, but they too have limits.
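For discrete (categorical) outputs, one well-known extension is the Gumbel-softmax (Concrete) relaxation. A minimal sketch, where `gumbel_softmax` is a hypothetical helper and the temperature `τ` is an assumed value; its limit is that the sample is only a relaxed one-hot vector, so the gradient is biased with respect to the true discrete distribution.

```julia
using Flux  # softmax and gradient (via Zygote)

# Hypothetical helper: Gumbel-softmax relaxation of a categorical sample.
# `logits` are unnormalized log-probabilities; τ is an assumed temperature.
function gumbel_softmax(logits; τ = 0.5f0)
    T = eltype(logits)
    # Uniform noise kept away from 0 and 1 so both logarithms stay finite.
    u = clamp.(rand(T, size(logits)), eps(T), 1 - eps(T))
    g = -log.(-log.(u))              # Gumbel(0, 1) noise, independent of logits
    softmax((logits .+ g) ./ τ)      # relaxed, differentiable "one-hot" vector
end

logits = Float32[1.0, 0.0, -1.0]
y = gumbel_softmax(logits)
grad = gradient(l -> gumbel_softmax(l)[1], logits)[1]
```

Lowering `τ` makes samples closer to one-hot but increases gradient variance, which is the usual trade-off with this relaxation.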

What @BatyLeo and I are implementing is closer to the score function method, which is more generic but also suffers from high variance.
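To make the variance remark concrete, a minimal sketch (not the package mentioned above) comparing the score-function (REINFORCE) estimator with the reparameterized one for d/dμ E[f(z)], z ~ N(μ, 1). The choice f(z) = z², for which the true gradient is 2μ, and all function names are illustrative assumptions.

```julia
using Random, Statistics

# Per-sample score-function terms for d/dμ E_{z ~ N(μ, 1)}[f(z)].
# Uses ∇_μ E[f(z)] = E[f(z) ∇_μ log p(z; μ)], with score z - μ for N(μ, 1).
# f only needs to be evaluated, not differentiated.
function score_terms(f, μ, n, rng)
    z = μ .+ randn(rng, n)
    f.(z) .* (z .- μ)
end

# Per-sample reparameterized terms: z = μ + ε, so d f(z)/dμ = f'(z).
function reparam_terms(fprime, μ, n, rng)
    z = μ .+ randn(rng, n)
    fprime.(z)
end

rng = MersenneTwister(42)
μ, n = 1.5, 100_000
f(z) = z^2          # E[f(z)] = μ^2 + 1, so the true gradient is 2μ = 3
fprime(z) = 2z

s = score_terms(f, μ, n, rng)
r = reparam_terms(fprime, μ, n, rng)
println("score:   mean = ", mean(s), ", variance = ", var(s))
println("reparam: mean = ", mean(r), ", variance = ", var(r))
```

Both means estimate the same gradient, but the per-sample variance of the score-function terms is roughly an order of magnitude larger here, which is why that method usually needs baselines or many samples.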

See this paper for a great overview:

I think the reason Julia lacks `rsample` is that it is ridiculously simple to implement for the distributions of interest, so there is no great need for it. The classical Gaussian reparametrization, which covers most uses, is effectively one line of code. I am adding a complete example, but the core is just `m.μ(x) .+ m.σ(x) .* r`, which nicely mirrors what is written in papers.
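In symbols, writing θ for the network parameters and f for any differentiable loss (notation introduced here for illustration), the one-liner is the standard Gaussian reparametrization. Because the noise ε does not depend on θ, the gradient can be moved inside the expectation:

$$
z = \mu_\theta(x) + \sigma_\theta(x) \odot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I),
$$

$$
\nabla_\theta \, \mathbb{E}_{z \sim \mathcal{N}(\mu_\theta(x),\, \operatorname{diag}(\sigma_\theta(x)^2))}\big[f(z)\big]
= \mathbb{E}_{\varepsilon \sim \mathcal{N}(0, I)}\big[\nabla_\theta f\big(\mu_\theta(x) + \sigma_\theta(x) \odot \varepsilon\big)\big].
$$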

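```julia
using Flux
using Functors

# Gaussian reparametrization: z = μ(x) + σ(x) .* ε with ε ~ N(0, I).
# The noise ε carries no parameters, so gradients flow through μ and σ.
struct Model{M,S}
    μ::M   # network producing the mean
    σ::S   # network producing the (positive) standard deviation
end
@functor Model  # expose the fields to Flux/Zygote as trainable parameters

function (m::Model)(x)
    μ = m.μ(x)
    σ = m.σ(x)
    ε = randn(eltype(μ), size(μ))  # standard normal noise, same shape as μ
    μ .+ σ .* ε
end

m = Model(
    Chain(Dense(2 => 2, relu), Dense(2 => 2)),
    Chain(Dense(2 => 2, relu), Dense(2 => 2, softplus)),  # softplus keeps σ > 0
)
x = randn(Float32, 2, 11)
gradient(m -> sum(m(x)), m)
```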
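The same pattern carries over to any distribution whose inverse CDF has a closed form. A minimal sketch for the exponential distribution, assuming the rate parametrization; `rsample_exponential` is a hypothetical name, not a Flux or Distributions.jl function.

```julia
using Flux        # gradient (via Zygote)
using Random, Statistics

# Hypothetical helper: inverse-CDF reparametrization of Exponential(rate λ).
# If u ~ U(0, 1], then x = -log(u) / λ ~ Exponential(rate λ),
# and x is a differentiable function of λ for fixed u.
function rsample_exponential(rng, λ, n)
    u = 1 .- rand(rng, typeof(λ), n)   # in (0, 1], so log(u) is finite
    -log.(u) ./ λ
end

rng = MersenneTwister(1)
λ = 2.0
xs = rsample_exponential(rng, λ, 100_000)   # sample mean should be close to 1/λ
# Gradient of the sample mean w.r.t. λ; should be close to d(1/λ)/dλ = -1/λ^2.
g = gradient(l -> mean(rsample_exponential(rng, l, 100_000)), λ)[1]
```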

For several distributions this is true. Yet torch also supports some distributions that are less trivial, and even so, it would be a nice addition that I have missed a few times.