I have been trying to replicate a simple example (Fig 2) from the paper Black-box stochastic variational inference in five lines of Python (https://www.cs.toronto.edu/~duvenaud/papers/blackbox.pdf). The Python code is available here: autograd/black_box_svi.py at master · HIPS/autograd · GitHub.
I translated the code from Python to Julia, essentially swapping Zygote in for Autograd.
```julia
using Distributions, Statistics

# Entropy of a diagonal Gaussian parameterized by log standard deviations
function gaussian_entropy(log_std)
    dim = length(log_std)
    return 0.5 * dim * (1.0 + log(2π)) + sum(log_std)
end

# Split the flat parameter vector into mean and log-std parts
function unpack_params(params, dim)
    μ = params[1:dim]
    logσ = params[(dim + 1):end]
    return μ, logσ
end

# Unnormalized log density of the target; each column of `x` is a sample
function log_density(x)
    mu = x[1, :]
    log_sigma = x[2, :]
    sigma_density = logpdf.(Normal(0, 1.35), log_sigma)
    mu_density = logpdf.(Normal.(0, exp.(log_sigma)), mu)
    return sigma_density .+ mu_density
end

# Monte Carlo estimate of the negative ELBO via the reparameterization trick
function variational_objective(logprob, params, nsample, dim)
    μ, logσ = unpack_params(params, dim)
    samples = randn(dim, nsample) .* exp.(logσ) .+ μ
    lower_bound = gaussian_entropy(logσ) + mean(logprob(samples))
    return -lower_bound
end
```
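For reference, this is the Python/NumPy version of the same functions that I am translating from. This is a paraphrase of the paper's `black_box_svi.py` rather than a verbatim copy (I use `scipy.stats.norm` for the log-pdfs, and the sample array is laid out as `(dim, nsample)` to match my Julia code):

```python
import numpy as np
from scipy.stats import norm

def gaussian_entropy(log_std):
    # Entropy of a diagonal Gaussian: 0.5*D*(1 + log(2*pi)) + sum(log_std)
    dim = len(log_std)
    return 0.5 * dim * (1.0 + np.log(2 * np.pi)) + np.sum(log_std)

def unpack_params(params, dim):
    # First `dim` entries are the mean, the rest are log standard deviations
    return params[:dim], params[dim:]

def log_density(x):
    # x has shape (2, nsample): row 0 holds mu, row 1 holds log_sigma
    mu, log_sigma = x[0, :], x[1, :]
    sigma_density = norm.logpdf(log_sigma, 0, 1.35)
    mu_density = norm.logpdf(mu, 0, np.exp(log_sigma))
    return sigma_density + mu_density

def variational_objective(logprob, params, nsample, dim, rng=np.random):
    mu, log_sigma = unpack_params(params, dim)
    # Reparameterization trick: samples = eps * sigma + mu
    samples = rng.randn(dim, nsample) * np.exp(log_sigma)[:, None] + mu[:, None]
    return -(gaussian_entropy(log_sigma) + np.mean(logprob(samples)))
```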
```julia
using Flux

ws = [-1.0, -1.0, -5.0, -5.0]   # initial [μ; logσ]
mc_ = 5000                       # number of Monte Carlo samples
dim_ = 2

loss() = variational_objective(log_density, ws, mc_, dim_)
θ = Flux.params(ws)
grads = gradient(() -> loss(), θ)

opt = ADAM(0.1, (0.9, 0.999))
for t in 1:2000
    Flux.Optimise.update!(opt, ws, grads[ws])
end
```
It seems `grads[ws]` returns a correctly estimated gradient for the first iteration (t = 1), but later in the loop the gradients are wrong. When I replace `grads[ws]` with `ForwardDiff.gradient`, everything works fine.
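Since the two AD frameworks disagree here, a framework-independent sanity check can help decide which gradient to trust. Below is a generic central finite-difference sketch (the helper name is mine, not from the paper), applied to the deterministic entropy term, whose gradient with respect to each element of `log_std` is exactly 1:

```python
import numpy as np

def finite_diff_grad(f, x, eps=1e-5):
    """Central finite-difference estimate of the gradient of f at x."""
    g = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return g

# Entropy term: 0.5*D*(1 + log(2*pi)) + sum(log_std)
entropy = lambda log_std: 0.5 * len(log_std) * (1 + np.log(2 * np.pi)) + np.sum(log_std)
g = finite_diff_grad(entropy, np.zeros(3))
# d entropy / d log_std[i] = 1 for every coordinate
```

For the full stochastic objective the same comparison needs a fixed random seed, since the ELBO estimate is itself noisy.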
I have also tried using `pullback` inside the loop (the approach used in the VAE example), and that works fine as well. However, it becomes excruciatingly slow.
```julia
loss2, back = Flux.pullback(Flux.params(ws)) do
    variational_objective(log_density, ws, 5000, 2)
end
grad = back(1f0)
Flux.Optimise.update!(opt, ws, grad[ws])
```
I am very new to Flux and Zygote, so it is quite likely that I have done something silly in the translation. Any help or suggestions would be much appreciated.
To make the code complete, here is the code used to plot the target density and the approximating density. The converged parameters are pretty far off.
```julia
using Plots, LinearAlgebra

begin
    x₁ = range(-2, stop=2, length=151)
    x₂ = range(-4, stop=2, length=151)
    μ_, logσ_ = unpack_params(ws, 2)
    # Target density; log_density returns a length-1 vector here, hence first(...)
    plot(x₁, x₂, (x, y) -> exp(first(log_density([x, y]))), st=:contour)
    # Variational approximation: diagonal Gaussian with variances exp.(2 .* logσ_)
    plot!(x₁, x₂, (x, y) -> pdf(MvNormal(μ_, Diagonal(exp.(2 .* logσ_))), [x, y]), st=:contour)
end
```