Fastest way to sample from MVN, changing parameters

Hey guys!
Does anybody know the most efficient way of sampling from a multivariate normal distribution whose parameters are constantly updated? I’m constructing a Gibbs sampler for an MVN whose parameters depend on matrix multiplication and inversion, so I want to make things as efficient as possible.

For now, I’ve considered two approaches.

  1. Simply declare a new multivariate normal distribution each iteration using the Distributions package, i.e. create a new MvNormal variable per iteration.
  2. Compute the Cholesky decomposition of my covariance matrix and then sample from a standard normal of the appropriate dimension. In any case, this seems to have a much higher cost than it is worth. (A rough sketch of both approaches is below.)
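
For concreteness, here is a minimal sketch of the two approaches (the dimension d and the placeholder mean μ and covariance Σ are made up for illustration):

using Distributions, LinearAlgebra, Random

d = 4
μ = randn(d)
Σ = Matrix(1.0I, d, d)

# Approach 1: build a fresh MvNormal each iteration
x1 = rand(MvNormal(μ, Σ))

# Approach 2: factorize the covariance and transform a standard-normal draw
L = cholesky(Σ).L
x2 = μ + L * randn(d)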

I would really value any insights.

Hi and welcome to the Julia Discourse!

Wouldn’t 1. compute something like a Cholesky decomposition under the hood anyway? I’d guess 2. could be marginally quicker, but it should be easy to benchmark this (edit: no, see below).

If you really want to shave off a bit more, you can probably copy what Distributions.jl does under the hood and avoid allocating intermediate arrays, but I’m not sure the speed-up would be worth the time.
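
Something along these lines might avoid most of the per-iteration allocations (a rough, untested sketch; sample_mvn! and the pre-allocated buffer Z are just illustrative names, and it overwrites the covariance you pass in):

using LinearAlgebra, Random

# Fill the pre-allocated buffer Z with standard-normal draws, then transform
# its columns in place so each one becomes a sample from MVN(μ, Σ).
function sample_mvn!(Z, μ, Σ)
    C = cholesky!(Symmetric(Σ))   # in-place factorization, destroys Σ
    randn!(Z)                     # standard-normal noise into the existing buffer
    lmul!(C.L, Z)                 # Z ← L * Z, so the columns now have covariance Σ
    Z .+= μ                       # shift every column by the mean
    return Z
end

# usage: Z = Matrix{Float64}(undef, d, N); sample_mvn!(Z, μ, copy(Σ))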

Looks like using the package is more efficient:

using Distributions, BenchmarkTools, Random, LinearAlgebra
Random.seed!(42)

# Approach 1: build a fresh MvNormal each iteration and let Distributions handle the sampling
function with_intermediate(μΣ, N, M, d)
    sum_out = 0.0
    for i in 1:M
        μ = rand(d)                   # random mean vector
        Σ = rand(μΣ)                  # random covariance matrix drawn from the Wishart
        X = rand(MvNormal(μ, Σ), N)   # N draws via the Distributions interface
        sum_out += X[1, 1]            # accumulate something so the work isn't optimized away
    end
    return sum_out
end

# Approach 2: factorize the covariance by hand and transform standard-normal draws
function without_intermediate(μΣ, N, M, d)
    sum_out = 0.0
    for i in 1:M
        μ = rand(d)
        S = cholesky(rand(μΣ)).L        # lower Cholesky factor of the covariance
        X = muladd(S, randn(d, N), μ)   # X = S * Z .+ μ, so each column is MVN(μ, Σ)
        sum_out += X[1, 1]
    end
    return sum_out
end


function test(d, N, M)
    S = cholesky(I(d))
    μΣ = Wishart(d, S)                 # Wishart distribution used to draw random covariances
    with_intermediate(μΣ, N, M, d)     # warm up both functions before benchmarking
    without_intermediate(μΣ, N, M, d)
    display(@benchmark with_intermediate($μΣ, $N, $M, $d))
    display(@benchmark without_intermediate($μΣ, $N, $M, $d))
end
test(4, 100, 10)

gives

BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  14.792 μs …  53.984 ms  ┊ GC (min … max):  0.00% … 99.91%
 Time  (median):     17.541 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   23.199 μs ± 539.668 μs  ┊ GC (mean ± σ):  23.25% ±  1.00%

               ▃▇▅█▆▄▁                                          
  ▁▁▁▂▂▃▄▃▃▃▃▄▆████████▆▅▄▃▃▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  14.8 μs         Histogram: frequency by time         24.6 μs <

 Memory estimate: 40.47 KiB, allocs estimate: 60.
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  17.791 μs …  57.846 ms  ┊ GC (min … max):  0.00% … 99.90%
 Time  (median):     24.458 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   33.310 μs ± 586.611 μs  ┊ GC (mean ± σ):  26.21% ±  3.13%

                 ▂▄▆▅▇██▇▆▂                                     
  ▁▂▂▂▂▅▆▅▄▃▃▂▃▅▇███████████▇▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  17.8 μs         Histogram: frequency by time         37.6 μs <

 Memory estimate: 109.84 KiB, allocs estimate: 90.