Hey guys!
Does anybody know the most efficient way to sample from a multivariate normal distribution whose parameters are constantly updated? I’m constructing a Gibbs sampler for an MVN whose parameters depend on matrix multiplication and inversion, so I want to make things as efficient as possible.
For now, I’ve considered two approaches.
1. Simply declare a new multivariate normal distribution on every iteration using the Distributions package, i.e. create a new variable per iteration.
2. Compute the Cholesky decomposition of my covariance matrix and then sample from a standard normal of appropriate dimensions. However, this seems to cost much more than it is worth.
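For reference, the second approach can be sketched as follows. This is only a minimal illustration: `sample_mvn_chol` is a made-up helper name, not something from Distributions.jl, and it assumes Σ is symmetric positive definite. It factorises Σ = L L' and returns μ + L z with z drawn from a standard normal.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not from any package): draw one sample from N(μ, Σ)
# by factorising Σ = L * L' and returning μ + L * z with z ~ N(0, I).
# Assumes Σ is symmetric positive definite.
function sample_mvn_chol(rng::AbstractRNG, μ::AbstractVector, Σ::AbstractMatrix)
    L = cholesky(Symmetric(Σ)).L   # lower-triangular Cholesky factor
    z = randn(rng, length(μ))      # standard normal draw of matching dimension
    return μ + L * z
end

rng = MersenneTwister(1)
μ = [1.0, -2.0]
Σ = [2.0 0.5; 0.5 1.0]
x = sample_mvn_chol(rng, μ, Σ)     # one draw, a vector of length 2
```

The factorisation is the expensive step, so in a Gibbs sampler it has to be redone whenever Σ changes, whichever approach is used.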
I would really value any insights.
Hi and welcome to the Julia Discourse!
Wouldn’t 1. compute something like a Cholesky decomposition under the hood anyway? I’d guess 2. can be marginally quicker, but it should be easy to benchmark this (edit: no, see below).
If you really want to shave off a bit more, you can probably copy what Distributions.jl does under the hood and avoid allocating intermediate arrays, but I’m not sure the speed-up would be worth the time.
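A rough sketch of what "avoid allocating intermediate arrays" could look like, using only LinearAlgebra and Random. The name `sample_mvn!` and the scratch buffer `Σbuf` are my own assumptions for illustration; this is not how Distributions.jl is actually implemented. The idea is to reuse an output matrix and a scratch matrix across iterations, so the per-iteration work is an in-place factorisation plus an in-place triangular multiply.

```julia
using LinearAlgebra, Random

# Hypothetical helper: overwrite X (d × N) with N draws from N(μ, Σ), reusing
# the caller's buffers. Σbuf is scratch space that is overwritten by the
# Cholesky factor; Σ itself is left untouched.
function sample_mvn!(rng::AbstractRNG, X::AbstractMatrix, Σbuf::Matrix,
                     μ::AbstractVector, Σ::AbstractMatrix)
    copyto!(Σbuf, Σ)
    L = cholesky!(Symmetric(Σbuf, :L)).L   # factorise in place, lower factor
    randn!(rng, X)                          # fill X with N(0, 1) draws
    lmul!(L, X)                             # X ← L * X, in place
    X .+= μ                                 # add the mean to every column
    return X
end

d, N = 4, 100
rng = MersenneTwister(42)
X = Matrix{Float64}(undef, d, N)       # output buffer, allocated once
Σbuf = Matrix{Float64}(undef, d, d)    # scratch buffer, allocated once
μ = rand(rng, d)
A = randn(rng, d, d)
Σ = A * A' + d * I                     # a symmetric positive definite test matrix
sample_mvn!(rng, X, Σbuf, μ, Σ)
```

Inside a sampler loop, `X` and `Σbuf` would be created once outside the loop and passed in on every iteration. Whether this beats the package in practice is something to benchmark rather than assume.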
Looks like using the package is more efficient:
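using Distributions, BenchmarkTools, Random, LinearAlgebra
Random.seed!(42)

# Approach 1: construct a fresh MvNormal on every iteration and sample from it.
function with_intermediate(μΣ, N, M, d)
    sum_out = 0.
    for i in 1:M
        μ = rand(d)
        Σ = rand(μΣ)                 # random covariance drawn from the Wishart
        X = rand(MvNormal(μ, Σ), N)
        sum_out += X[1,1]            # use the result so the work is not optimised away
    end
    return sum_out
end

# Approach 2: take the Cholesky factor by hand and transform standard normal draws.
function without_intermediate(μΣ, N, M, d)
    sum_out = 0.
    for i in 1:M
        μ = rand(d)
        S = cholesky(rand(μΣ)).L
        X = muladd(S, randn(d,N), μ) # S * Z .+ μ
        sum_out += X[1,1]
    end
    return sum_out
end

function test(d, N, M)
    S = cholesky(I(d))
    μΣ = Wishart(d, S)
    # Run each function once first so compilation is not part of the timing.
    with_intermediate(μΣ, N, M, d)
    without_intermediate(μΣ, N, M, d)
    display(@benchmark with_intermediate($μΣ, $N, $M, $d))
    display(@benchmark without_intermediate($μΣ, $N, $M, $d))
end

test(4, 100, 10)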
gives (first `with_intermediate`, then `without_intermediate`):
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 14.792 μs … 53.984 ms ┊ GC (min … max): 0.00% … 99.91%
Time (median): 17.541 μs ┊ GC (median): 0.00%
Time (mean ± σ): 23.199 μs ± 539.668 μs ┊ GC (mean ± σ): 23.25% ± 1.00%
▃▇▅█▆▄▁
▁▁▁▂▂▃▄▃▃▃▃▄▆████████▆▅▄▃▃▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
14.8 μs Histogram: frequency by time 24.6 μs <
Memory estimate: 40.47 KiB, allocs estimate: 60.
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 17.791 μs … 57.846 ms ┊ GC (min … max): 0.00% … 99.90%
Time (median): 24.458 μs ┊ GC (median): 0.00%
Time (mean ± σ): 33.310 μs ± 586.611 μs ┊ GC (mean ± σ): 26.21% ± 3.13%
▂▄▆▅▇██▇▆▂
▁▂▂▂▂▅▆▅▄▃▃▂▃▅▇███████████▇▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
17.8 μs Histogram: frequency by time 37.6 μs <
Memory estimate: 109.84 KiB, allocs estimate: 90.