Automatic/Black-box variational inference replication vs Autograd

Hello all,

I have been trying to replicate a simple example (Fig. 2) from the paper Black-box stochastic variational inference in five lines of Python (https://www.cs.toronto.edu/~duvenaud/papers/blackbox.pdf). The Python code is available here: https://github.com/HIPS/autograd/blob/master/examples/black_box_svi.py.

I have translated the code from Python to Julia, essentially replacing Autograd with Zygote.

using Flux, Zygote, Distributions, Statistics

# Entropy of a diagonal Gaussian parameterized by its log standard deviations
function gaussian_entropy(log_std)
	dim = length(log_std)
	return 0.5 * dim * (1.0 + log(2 * pi)) + sum(log_std)
end

function unpack_params(params, dim)
	μ = params[1:dim]
	logσ = params[(dim+1):end]
	return μ, logσ
end

# unnormalized target density: log_sigma ~ Normal(0, 1.35), mu ~ Normal(0, exp(log_sigma))
function log_density(x)
	mu = x[1,:]
	log_sigma = x[2,:]
	sigma_density = logpdf.(Normal(0, 1.35), log_sigma)
	mu_density = logpdf.(Normal.(0, exp.(log_sigma)), mu)
	return sigma_density .+ mu_density
end

function variational_objective(logprob, params, nsample, dim)
	μ, logσ = unpack_params(params, dim)
	# reparameterization trick: samples = μ + exp(logσ) .* ε with ε ~ N(0, I)
	samples = randn(dim, nsample) .* exp.(logσ) .+ μ
	lower_bound = gaussian_entropy(logσ) + mean(logprob(samples))
	return -lower_bound
end
ws = [-1.0, -1.0, -5.0, -5.0]
mc_ = 5000
dim_ = 2
loss() = variational_objective(log_density, ws, mc_, dim_)
θ = Flux.params(ws)
grads = gradient(() -> loss(), θ)
opt = ADAM(0.1, (0.9, 0.999))
for t in 1:2000
	Flux.Optimise.update!(opt, ws, grads[ws])
end

It seems grads[ws] initially (for t = 1) returns the correctly estimated gradient, but later in the loop the gradients are wrong. I have tried replacing grads[ws] with a gradient computed by ForwardDiff.gradient, and that works fine.

I have also tried using pullback inside the loop (as in the VAE example), which works correctly but is excruciatingly slow.

loss2, back = Flux.pullback(Flux.params(ws)) do
	variational_objective(log_density, ws, 5000, 2)
end
grad = back(1.0)
Flux.Optimise.update!(opt, ws, grad[ws])

I am very new to Flux and Zygote, so it is quite likely that I have done something silly in the replication. Any help or suggestions would be much appreciated.

To make the code complete, the following code was used to plot the target density and the approximating density. The converged parameters are pretty far off.

using Plots, LinearAlgebra

begin
	x₁ = range(-2, stop=2, length=151)
	x₂ = range(-4, stop=2, length=151)
	μ_, logσ_ = unpack_params(ws, 2)
	plot(x₁, x₂, (x, y) -> exp(log_density([x, y])[1]), st=:contour)
	plot!(x₁, x₂, (x, y) -> pdf(MvNormal(μ_, Diagonal(exp.(2 .* logσ_))), [x, y]), st=:contour)
end

Best regards,
Lei


Note that, in contrast to autograd.grad(), which returns a function representing the gradient, Zygote.gradient(() -> loss(), θ) returns the gradient evaluated at θ (i.e. a concrete number/array of numbers)!

So what's wrong here is that you're not updating your gradient: you calculate it once at the beginning and then use that same gradient every time you call Flux.Optimise.update!. Instead, you have to move the gradient calculation inside your training loop.

for t in 1:2000
    grads = gradient(() -> loss(), θ)
    Flux.Optimise.update!(opt, ws, grads[ws])
end

How slow? Do you have any benchmarks?


That makes a lot of sense! Thanks a million.

I have changed the code accordingly (and also wrapped it in a function); it now converges to the correct approximating distribution. But for 2000 iterations it takes around 300 s to finish, compared with 1.2 s with ForwardDiff.gradient (which returns the gradient as a function).

Here is the benchmark result (I have set the number of iterations to 3; otherwise it might take forever):

ForwardDiff.gradient returns the gradient value, not a function (see the ForwardDiff Differentiation API docs); were you perhaps thinking of something else?

That aside, if ForwardDiff works for you and runs quickly, there’s no reason to stop using it in favour of Zygote. Flux.Optimise.update! doesn’t care where the gradients come from, as long as they match the shape of the parameters.

Thanks for your reply. Hmm… I basically created a gradient function using ForwardDiff; I assume ReverseDiff would work as well, i.e. something like:

loss = (x) -> variational_objective(log_density, x, 5000, 2)
gradfun = (x) -> ForwardDiff.gradient(loss, x)

for t in 1:2000
    Flux.Optimise.update!(opt, ws, gradfun(ws))
end

I am not sure whether Zygote can be used in a similar fashion?

Yes, I could choose a specific automatic differentiation package. But ultimately I would like to do variational inference on some reasonably complicated models, and I want something more uniform; Flux/Zygote seem to have better constructs/APIs for that.

I am not sure which part of variational_objective makes Zygote slow. I am going to try auto-differentiation on log_density only and add the Monte Carlo gradient estimation back afterwards, to see whether the problem lies in the Gaussian sampling step.
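For concreteness, something along these lines is what I have in mind. This is only a sketch (manual_elbo_grad and the per-sample loop are my own illustration, not code from the paper), relying on the log_density defined above:

using Zygote

function manual_elbo_grad(μ, logσ, nsample)
    ε = randn(length(μ), nsample)
    z = μ .+ exp.(logσ) .* ε                 # reparameterized samples, dim × nsample
    ∇μ, ∇logσ = zero(μ), zero(logσ)
    for i in 1:nsample
        # per-sample ∇_z log p(z); z[:, i:i] keeps the matrix shape log_density expects
        g = vec(Zygote.gradient(zi -> log_density(zi)[1], z[:, i:i])[1])
        ∇μ .+= g
        ∇logσ .+= g .* ε[:, i] .* exp.(logσ)
    end
    ∇μ ./= nsample
    ∇logσ ./= nsample
    ∇logσ .+= 1                              # ∂/∂logσ of the Gaussian entropy term
    return ∇μ, ∇logσ                         # gradient of the ELBO (negate for the loss)
end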

gradfun(x) = first(Zygote.gradient(loss, x)) :wink:

(Zygote.gradient returns a tuple of gradients, hence the first.)

It's unclear what the bottlenecks might be for Zygote, but here is one candidate: the slicing of arrays in unpack_params and log_density. Since x is never used as a whole, μ and logσ could be passed in as separate parameters.

function log_density(mu, log_sigma)
	sigma_density = logpdf.(Normal(0, 1.35), log_sigma)
	mu_density = logpdf.(Normal.(0, exp.(log_sigma)), mu)
	return sigma_density .+ mu_density
end

function variational_objective(logprob, μ, logσ, nsample)
	# sample each latent coordinate separately instead of building a 2 × nsample matrix
	sample_μs = randn(nsample) .* exp(logσ[1]) .+ μ[1]
	sample_logσs = randn(nsample) .* exp(logσ[2]) .+ μ[2]
	lower_bound = gaussian_entropy(logσ) + mean(logprob(sample_μs, sample_logσs))
	return -lower_bound
end

loss(μ, logσ) = variational_objective(log_density, μ, logσ, 5000)
gradfun(μ, logσ) = Zygote.gradient(loss, μ, logσ)

# split ws into μ, logσ
for t in 1:2000
    ∇μ, ∇logσ = gradfun(μ, logσ)
    Flux.Optimise.update!(opt, μ, ∇μ)
    Flux.Optimise.update!(opt, logσ, ∇logσ)
end

Ahhh, thanks a lot! I should have tried that.

I have tried your method. Removing the packing/unpacking improves efficiency a bit, but it still takes around 280-290 s.

I have tried doing the Monte Carlo gradient estimation manually. That makes it even worse … now it takes forever. I think the problem is simply that reverse-mode differentiation carries a large per-call overhead, which dominates when the number of parameters is this small.

Edit: this is slower than the original, see post below for a better solution.

What do the per-iteration benchmarks look like now? I wonder if most of those 290s is actually just AD/compiler overhead on the first iteration.
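For example, something along these lines (assuming the loss/gradfun definitions from the post above) would separate the first-call compilation cost from the steady-state per-iteration cost:

using BenchmarkTools
@time gradfun(μ, logσ)        # first call: includes Zygote compilation
@time gradfun(μ, logσ)        # second call: steady-state cost
@btime gradfun($μ, $logσ)     # per-iteration benchmark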

Another thing to try would be using one big scalar function for log_density, just like you’ve done for gaussian_entropy:

function log_density(mu, log_sigma)
	sigma_density = logpdf.(Normal(0, 1.35), log_sigma)
	mu_density = logpdf.(Normal.(0, exp.(log_sigma)), mu)
	return sigma_density .+ mu_density
end

# ... mean(logprob(sample_μs, sample_logσs))

Would become

function log_density(mu, log_sigma)
	sigma_density = logpdf(Normal(0, 1.35), log_sigma)
	mu_density = logpdf(Normal(0, exp(log_sigma)), mu)
	return sigma_density + mu_density
end

# ... mean(logprob.(sample_μs, sample_logσs))

This will run a single broadcast and (hopefully) allocate one output instead of multiple intermediate values.
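Putting it together, the objective would become something like this (a sketch, reusing the split-parameter variational_objective from the earlier post):

function variational_objective(logprob, μ, logσ, nsample)
	# logprob is now a scalar function, broadcast over the per-dimension samples
	sample_μs = randn(nsample) .* exp(logσ[1]) .+ μ[1]
	sample_logσs = randn(nsample) .* exp(logσ[2]) .+ μ[2]
	return -(gaussian_entropy(logσ) + mean(logprob.(sample_μs, sample_logσs)))
end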

Here's a version that takes advantage of the optimized routines in Distributions.jl and DistributionsAD.jl (https://github.com/TuringLang/DistributionsAD.jl).

using Zygote, Distributions, DistributionsAD, LinearAlgebra, FillArrays
using BenchmarkTools

function gaussian_entropy(log_std)
	dim = length(log_std)
	return 0.5 * dim * (1 + log(2pi)) + sum(log_std)
end

function unpack_params(params, dim)
	μ = params[1:dim]
	logσ = params[(dim+1):end]
	return μ, logσ
end

function log_density_old(x)
	mu = x[1,:]
	log_sigma = x[2,:]
	sigma_density = logpdf.(Normal(0, 1.35), log_sigma)
	mu_density = logpdf.(Normal.(0, exp.(log_sigma)), mu)
	return sigma_density .+ mu_density
end

function variational_objective_old(logprob, μ, logσ, nsample)
	samples = randn(2, nsample) .* exp.(logσ) .+ μ
	lower_bound = gaussian_entropy(logσ) + mean(logprob(samples))
	return -lower_bound
end

function log_density_new(x)
	mu = x[1,:]
	log_sigma = x[2,:]
	sigma_density = logpdf(MvNormal(Diagonal(Fill(1.35^2, length(mu)))), log_sigma)
	mu_density = logpdf(MvNormal(Diagonal(abs2.(exp.(log_sigma)))), mu)
	return sigma_density + mu_density
end

function variational_objective_new(logprob, μ, logσ, nsample)
	samples = rand(MvNormal(μ, Diagonal(abs2.(exp.(logσ)))), nsample)
	lower_bound = gaussian_entropy(logσ) + mean(logprob(samples))
	return -lower_bound
end

ws = [-1., -1.0, -5.0, -5.0]
mc_ = 5000
dim_ = 2
μ, logσ = unpack_params(ws, dim_)
loss_old(μ, logσ) = variational_objective_old(log_density_old, μ, logσ, mc_)
loss_new(μ, logσ) = variational_objective_new(log_density_new, μ, logσ, mc_)

And the results:

julia> @benchmark gradient(loss_old, $μ, $logσ)
BenchmarkTools.Trial: 34 samples with 1 evaluation.
 Range (min … max):  144.132 ms … 156.172 ms  ┊ GC (min … max): 7.08% … 11.76%
 Time  (median):     149.142 ms               ┊ GC (median):    9.07%
 Time  (mean ± σ):   149.351 ms ±   3.406 ms  ┊ GC (mean ± σ):  8.82% ±  1.61%

   █           ▃         ██ ▃ ▃       ▃     ▃          ▃         
  ▇█▇▁▇▁▁▁▇▁▁▁▁█▇▁▇▁▁▁▁▁▁██▇█▁█▁▁▁▁▁▁▇█▇▁▁▇▁█▁▁▁▇▁▁▁▁▁▁█▁▁▁▁▇▁▇ ▁
  144 ms           Histogram: frequency by time          156 ms <

 Memory estimate: 50.87 MiB, allocs estimate: 1319987.

julia> @benchmark gradient(loss_new, $μ, $logσ)
BenchmarkTools.Trial: 7053 samples with 1 evaluation.
 Range (min … max):  553.498 μs …   3.567 ms  ┊ GC (min … max):  0.00% … 66.38%
 Time  (median):     586.556 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   703.969 μs ± 415.897 μs  ┊ GC (mean ± σ):  12.82% ± 15.94%

  █▇▅▃▂           ▁▁                                  ▁▁▁       ▁
  █████████▇▇▆▆▅▃▇██▆▄▃▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄████▇▇▆▆▆▆ █
  553 μs        Histogram: log(frequency) by time       2.58 ms <

 Memory estimate: 2.24 MiB, allocs estimate: 547.

The difference is, unexpectedly, quite drastic! More could probably be done about the allocations (I haven’t tested with StaticArrays, for example), but this is probably closer to what you’re seeing with ForwardDiff.
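For completeness, here is a sketch of how loss_new could be dropped into the training loop from earlier in the thread (same ADAM settings as before):

using Flux
opt = ADAM(0.1, (0.9, 0.999))
for t in 1:2000
    ∇μ, ∇logσ = gradient(loss_new, μ, logσ)
    Flux.Optimise.update!(opt, μ, ∇μ)
    Flux.Optimise.update!(opt, logσ, ∇logσ)
end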


This is so interesting to know. I definitely have learnt a lot. Thanks a lot!

It seems the two log_density functions are not the same: MvNormal's logpdf returns the total log likelihood summed over all samples. I have changed variational_objective_new accordingly (i.e. dividing by nsample). It is probably worth noting that using the joint log likelihood speeds things up a lot, even though it makes the code harder to understand.
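For reference, this is the small correction I mean (a sketch reusing the definitions from the post above):

# MvNormal's logpdf in log_density_new sums over all nsample columns,
# so divide by nsample to recover the Monte Carlo average.
function variational_objective_new(logprob, μ, logσ, nsample)
	samples = rand(MvNormal(μ, Diagonal(abs2.(exp.(logσ)))), nsample)
	lower_bound = gaussian_entropy(logσ) + logprob(samples) / nsample
	return -lower_bound
end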