Hello all,
I have been trying to replicate a simple example (Fig. 2) from the paper Black-box stochastic variational inference in five lines of Python (https://www.cs.toronto.edu/~duvenaud/papers/blackbox.pdf). The Python code is available here: https://github.com/HIPS/autograd/blob/master/examples/black_box_svi.py
I translated all of the code from Python to Julia, basically using Zygote in place of Autograd.
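```julia
using Flux, Distributions, Statistics, LinearAlgebra

# Entropy of a diagonal Gaussian with log standard deviations `log_std`
function gaussian_entropy(log_std)
    dim = length(log_std)
    return 0.5 * dim * (1.0 + log(2 * pi)) + sum(log_std)
end

# Split the flat parameter vector into means and log standard deviations
function unpack_params(params, dim)
    μ = params[1:dim]
    logσ = params[(dim+1):end]
    return μ, logσ
end

# Log density of the 2-D target: row 1 of `x` is mu, row 2 is log_sigma,
# and each column is one point
function log_density(x)
    mu = x[1, :]
    log_sigma = x[2, :]
    sigma_density = logpdf.(Normal(0, 1.35), log_sigma)
    mu_density = logpdf.(Normal.(0, exp.(log_sigma)), mu)
    return sigma_density .+ mu_density
end

# Negative Monte Carlo ELBO from reparameterized samples (dim × nsample)
function variational_objective(logprob, params, nsample, dim)
    μ, logσ = unpack_params(params, dim)
    samples = randn(dim, nsample) .* exp.(logσ) .+ μ
    lower_bound = gaussian_entropy(logσ) + mean(logprob(samples))
    return -lower_bound
end
```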
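Written out, the functions above compute a Monte Carlo estimate of the negative ELBO. The notation here is introduced for this post rather than taken from the paper: $\sigma = \exp(\log\sigma)$ elementwise, $D$ = `dim`, $S$ = `nsample`, $z = (z_1, z_2)$ with $z_1$ = `mu` and $z_2$ = `log_sigma` as in `log_density`, and the second argument of $\mathcal{N}$ is a variance.

$$
\begin{aligned}
q(z) &= \mathcal{N}\big(z;\, \mu,\, \operatorname{diag}(\sigma^2)\big), \\
H[q] &= \tfrac{D}{2}\,(1 + \log 2\pi) + \textstyle\sum_{d=1}^{D} \log\sigma_d, \\
\log p(z) &= \log\mathcal{N}\big(z_2;\, 0,\, 1.35^2\big) + \log\mathcal{N}\big(z_1;\, 0,\, e^{2 z_2}\big), \\
\mathcal{L}(\mu, \log\sigma) &= -\Big( H[q] + \frac{1}{S}\sum_{s=1}^{S} \log p\big(\mu + \sigma \odot \varepsilon^{(s)}\big) \Big),
\qquad \varepsilon^{(s)} \sim \mathcal{N}(0, I_D).
\end{aligned}
$$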
```julia
ws = [-1.0, -1.0, -5.0, -5.0]   # [μ; logσ]
mc_ = 5000                      # Monte Carlo samples per objective evaluation
dim_ = 2
loss() = variational_objective(log_density, ws, mc_, dim_)
θ = Flux.params(ws)
grads = gradient(() -> loss(), θ)   # evaluated once, before the loop
opt = ADAM(0.1, (0.9, 0.999))
for t in 1:2000
    Flux.Optimise.update!(opt, ws, grads[ws])
end
```
It seems that `grads[ws]` returns a correctly estimated gradient at the start (t = 1), but later in the loop the gradients are wrong. Replacing `grads[ws]` with `ForwardDiff.gradient` works fine.
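One thing that may be worth checking: in the loop above, `gradient(() -> loss(), θ)` is evaluated once, before `for t in 1:2000`, so every `update!` reuses the same t = 1 gradient (and the same Monte Carlo draw). The stdlib-only sketch below illustrates what ADAM does in that situation on a toy quadratic. `adam!` and `grad_toy` are hypothetical names, and the update is the standard ADAM rule written out by hand, not Flux's `ADAM`.

```julia
# Standard ADAM (Kingma & Ba) update written out by hand for illustration;
# this is not Flux's implementation. With fresh=false the gradient computed
# at the starting point is reused on every step.
function adam!(w, gradfun; iters=500, lr=0.1, β1=0.9, β2=0.999, δ=1e-8, fresh=true)
    m = zero(w)
    v = zero(w)
    g = gradfun(w)                   # gradient at the starting point
    for t in 1:iters
        fresh && (g = gradfun(w))    # re-evaluate at the current w
        @. m = β1 * m + (1 - β1) * g
        @. v = β2 * v + (1 - β2) * g^2
        @. w -= lr * (m / (1 - β1^t)) / (sqrt(v / (1 - β2^t)) + δ)
    end
    return w
end

# Hypothetical toy objective f(w) = ½‖w‖², minimized at w = 0; its gradient is w.
# `copy` matters: returning `w` itself would alias the parameters.
grad_toy(w) = copy(w)

w0 = [-1.0, -1.0, -5.0, -5.0]
w_stale = adam!(copy(w0), grad_toy; fresh=false)
w_fresh = adam!(copy(w0), grad_toy; fresh=true)
println("stale gradient: ", round.(w_stale; digits=3))
println("fresh gradient: ", round.(w_fresh; digits=3))
```

With a frozen gradient the bias-corrected ratio m̂/√v̂ stays at sign(g), so each coordinate moves by about `lr` per step in a fixed direction (the stale run ends near [49, 49, 45, 45] after 500 steps), while the re-evaluated run settles near the minimum at zero.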
I also tried calling `pullback` inside the loop (as in the VAE example), which works fine as well, but it is excruciatingly slow:
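```julia
for t in 1:2000
    # fresh objective value and pullback at the current ws on every iteration
    loss2, back = Flux.pullback(Flux.params(ws)) do
        variational_objective(log_density, ws, 5000, 2)
    end
    grad = back(1f0)
    Flux.Optimise.update!(opt, ws, grad[ws])
end
```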
I am very new to Flux and Zygote, so it is quite likely that I have done something silly in the translation. Any help or suggestions would be really appreciated.
For completeness, the following code was used to plot the target density and the approximating density. The converged parameters are pretty far off.
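```julia
begin
    using Plots, Distributions, LinearAlgebra
    x₁ = range(-2, stop=2, length=151)   # mu axis
    x₂ = range(-4, stop=2, length=151)   # log_sigma axis
    μ_, logσ_ = unpack_params(ws, 2)
    # Target density; a length-2 vector is indexed like a 2×1 matrix, hence the [1]
    plot(x₁, x₂, (x, y) -> exp(log_density([x, y])[1]), st=:contour)
    # Variational density; MvNormal takes a covariance, so the diagonal holds exp(2 logσ)
    plot!(x₁, x₂, (x, y) -> pdf(MvNormal(μ_, Diagonal(exp.(2 .* logσ_))), [x, y]), st=:contour)
end
```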
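As a dependency-free reference point for the fitted parameters, here is a sketch of the same objective with the reparameterization gradient derived by hand rather than by AD. It is checked against central finite differences with the noise held fixed, then optimized with a hand-written ADAM that draws fresh noise and re-evaluates the gradient on every iteration. The helpers `normal_logpdf`, `log_target`, `grad_log_target`, `neg_elbo`, `neg_elbo_grad` and `fit` are hypothetical names for this sketch, not part of Flux, Zygote or Distributions, and the derivation is worth double-checking before relying on it.

```julia
using Random, Statistics

# Log density of Normal(0, σ) at x, with σ a standard deviation (as in Distributions.jl)
normal_logpdf(x, σ) = -0.5 * log(2π) - log(σ) - x^2 / (2 * σ^2)

# Same target as `log_density` in the post, for one point (m, s) = (mu, log_sigma)
log_target(m, s) = normal_logpdf(s, 1.35) + normal_logpdf(m, exp(s))

# Hand-derived gradient of log_target with respect to (m, s)
function grad_log_target(m, s)
    e = exp(-2s)
    return (-m * e, -s / 1.35^2 - 1 + m^2 * e)
end

# Negative ELBO for fixed standard-normal noise ε (2 × S); params = [μ; logσ]
function neg_elbo(params, ε)
    μ, ρ = params[1:2], params[3:4]
    z = μ .+ exp.(ρ) .* ε                    # reparameterized samples, 2 × S
    entropy = (1 + log(2π)) + sum(ρ)         # 0.5 * D * (1 + log 2π) with D = 2
    return -(entropy + mean(log_target.(z[1, :], z[2, :])))
end

# Reparameterization (pathwise) gradient of neg_elbo with respect to params
function neg_elbo_grad(params, ε)
    μ, ρ = params[1:2], params[3:4]
    σ = exp.(ρ)
    S = size(ε, 2)
    gμ = zeros(2)
    gρ = zeros(2)
    for j in 1:S
        z1 = μ[1] + σ[1] * ε[1, j]
        z2 = μ[2] + σ[2] * ε[2, j]
        g = grad_log_target(z1, z2)
        for d in 1:2
            gμ[d] += g[d] / S                    # dz_d/dμ_d = 1
            gρ[d] += g[d] * σ[d] * ε[d, j] / S   # dz_d/dlogσ_d = σ_d ε_d
        end
    end
    return -vcat(gμ, gρ .+ 1)                    # +1 from d(entropy)/dlogσ_d
end

# Cross-check against central finite differences with the noise held fixed
ε_fixed = randn(MersenneTwister(1), 2, 1000)
w0 = [-1.0, -1.0, -5.0, -5.0]
h = 1e-6
e_i(i) = Float64.((1:4) .== i)
g_fd = [(neg_elbo(w0 .+ h .* e_i(i), ε_fixed) - neg_elbo(w0 .- h .* e_i(i), ε_fixed)) / (2h) for i in 1:4]
g_an = neg_elbo_grad(w0, ε_fixed)
println("max |analytic - finite difference| = ", maximum(abs.(g_an .- g_fd)))

# Hand-written ADAM (Kingma & Ba update) with fresh noise and a freshly
# evaluated gradient on every iteration
function fit(w; iters=2000, S=5000, lr=0.1, β1=0.9, β2=0.999, δ=1e-8, rng=MersenneTwister(0))
    w = copy(w)
    m = zero(w)
    v = zero(w)
    for t in 1:iters
        g = neg_elbo_grad(w, randn(rng, 2, S))
        @. m = β1 * m + (1 - β1) * g
        @. v = β2 * v + (1 - β2) * g^2
        @. w -= lr * (m / (1 - β1^t)) / (sqrt(v / (1 - β2^t)) + δ)
    end
    return w
end

w_fit = fit(w0)
println("fitted [μ; logσ] = ", round.(w_fit; digits=3))
```

The printed `[μ; logσ]` can be compared directly with `ws` after the Flux loop, and the finite-difference check gives an AD-independent reference for what `grads[ws]` should be at the initial parameters.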
Best regards,
Lei