Best/simplest way to calculate sparse Hessian

I’ve got a log-probability density function of multiple variables, and would like to integrate it using a Laplace approximation. To do this I need to find the function’s maximum and calculate the Hessian at that point.

That’s easy enough using Optim and AD, but when the Hessian is sparse, I’d like to be able to detect the sparsity automatically and exploit it. That should be possible with e.g. SparseDiffTools.jl and/or ModelingToolkit.jl, but from the documentation it’s not clear to me what the best approach is. The method below works, but seems hacky. Is there a better way?

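```julia
using Distributions, ForwardDiff

# Correlated bivariate normal in x[1:2], independent standard normal in x[3]
d = MvNormal([1.0, 2.0], [1 0.2; 0.2 1])
f(x) = logpdf(d, x[1:2]) + logpdf(Normal(), x[3])
x0 = [1.0, 2.0, 0.0]
f(x0)
ForwardDiff.hessian(f, x0)   # dense 3×3; the (1,3), (2,3) entries and their transposes are zero
```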


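```julia
using SparseDiffTools, LinearAlgebra, SparseArrays

# Lazy Hessian-vector-product operator for f at x0
Hv = HesVec(f, x0)
# Build H one column at a time (H * e_i for each unit vector e_i), then sparsify.
# This still costs one Hessian-vector product per variable, so it does not exploit the sparsity.
H = hcat([sparse(Hv * i) for i in eachcol(I(3))]...)
```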
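For reference, the Laplace approximation only needs the mode x̂ and the Hessian H there: log ∫ exp(f(x)) dx ≈ f(x̂) + (n/2) log(2π) - ½ log det(-H). Below is a minimal sketch using ForwardDiff alone; the helper `laplace_logZ` is hypothetical, and it swaps in a hand-rolled Newton iteration where one would normally call Optim. Because `f` here is a normalized Gaussian log-density, the approximation is exact and the result should be close to 0.

```julia
using Distributions, ForwardDiff, LinearAlgebra

d = MvNormal([1.0, 2.0], [1 0.2; 0.2 1])
f(x) = logpdf(d, x[1:2]) + logpdf(Normal(), x[3])

# Hypothetical helper: Laplace approximation of log ∫ exp(f(x)) dx.
# Finds the mode with plain Newton steps (Optim.jl could be used instead),
# then evaluates the Gaussian integral around that mode.
function laplace_logZ(f, x0; tol = 1e-10, maxiter = 50)
    x = float.(x0)
    for _ in 1:maxiter
        g = ForwardDiff.gradient(f, x)
        norm(g) < tol && break
        H = ForwardDiff.hessian(f, x)
        x -= H \ g                      # Newton step toward the stationary point
    end
    H = ForwardDiff.hessian(f, x)
    # -H is positive definite at a strict maximum, so take its log-determinant via Cholesky
    logZ = f(x) + length(x) / 2 * log(2π) - logdet(cholesky(Symmetric(-H))) / 2
    return logZ, x, H
end

logZ, xhat, H = laplace_logZ(f, zeros(3))
```

In the real problem the `ForwardDiff.hessian` call at the mode is the piece one would replace with a sparse or symbolic Hessian.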

Giving this a bump… does anyone have any advice? I’m going to use this Hessian and the Laplace-approximation integral inside a larger optimization problem, so I’d like it to be as efficient as possible.

@ChrisRackauckas

We almost have this, but not quite. We have automated Hessian sparsity detection in SparsityDetection.jl and Symbolics.jl, and we have acyclic coloring of Hessian sparsity patterns. What’s missing is the sparse Hessian differentiation step itself, which combines those two and implements the symmetric compression/decompression. That will go into SparseDiffTools.jl, and I believe there’s already an issue tracking its development.
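To make the compression/decompression idea concrete, here is a minimal sketch that assumes the sparsity pattern is already known (e.g. from the sparsity detection mentioned above). It uses plain greedy column coloring, where columns with disjoint nonzero rows share one Hessian-vector product; this is simpler than, and not the same as, the acyclic symmetric scheme. The helpers `hvp`, `color_columns` and `sparse_hessian` are hypothetical names, and the HVP is computed forward-over-forward with ForwardDiff.

```julia
using Distributions, ForwardDiff, SparseArrays, LinearAlgebra

# Hessian-vector product, forward-over-forward: d/dt ∇f(x + t v) at t = 0
hvp(f, x, v) = ForwardDiff.derivative(t -> ForwardDiff.gradient(f, x .+ t .* v), 0.0)

# Greedy column coloring: two columns may share a color only if their
# nonzero rows are disjoint (structurally orthogonal columns).
function color_columns(P::SparseMatrixCSC)
    rows(k) = rowvals(P)[nzrange(P, k)]   # nonzero rows of column k
    colors = zeros(Int, size(P, 2))
    for j in axes(P, 2)
        c = 1
        # bump c while an earlier column with color c shares a row with column j
        while any(k -> colors[k] == c && !isdisjoint(rows(j), rows(k)), 1:j-1)
            c += 1
        end
        colors[j] = c
    end
    return colors
end

# Recover every structural nonzero of H using one HVP per color
function sparse_hessian(f, x, P::SparseMatrixCSC)
    colors = color_columns(P)
    ncolors = maximum(colors)
    Is, Js, Vs = Int[], Int[], Float64[]
    for c in 1:ncolors
        v = Float64.(colors .== c)   # seed: sum of unit vectors for this color's columns
        Hv = hvp(f, x, v)            # = sum of those columns of H
        for j in findall(==(c), colors), i in rowvals(P)[nzrange(P, j)]
            # within one color, row i is nonzero only in column j, so Hv[i] == H[i, j]
            push!(Is, i); push!(Js, j); push!(Vs, Hv[i])
        end
    end
    return sparse(Is, Js, Vs, size(P)...), ncolors
end

# Example with the f from the question and its (assumed known) block pattern
d = MvNormal([1.0, 2.0], [1 0.2; 0.2 1])
f(x) = logpdf(d, x[1:2]) + logpdf(Normal(), x[3])
P = sparse([1 1 0; 1 1 0; 0 0 1])
H, ncolors = sparse_hessian(f, [1.0, 2.0, 0.0], P)
```

Here columns 1 and 3 share a color, so two Hessian-vector products replace three; for large banded or block-structured Hessians the number of colors stays small while n grows.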

That said, the better approach here is to trace the function with Symbolics.jl and generate an analytical Hessian. If that’s viable, it’ll be a lot faster than even sparse automatic differentiation.


It looks like Symbolics.jl will work. I adapted the Symbolics.jl tutorial to get the following code:

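```julia
using Symbolics
@variables u[1:3]
# Trace f symbolically and simplify the resulting expression
u_logpdf = simplify.(f(u))

# Compile the log-density itself
fastf = eval(build_function(u_logpdf, u))
# Symbolic Hessian with an automatically detected sparsity pattern
hess = Symbolics.sparsehessian(u_logpdf, vec(u))
# build_function returns (out-of-place, in-place); [2] is the in-place version, called as fasthess(out, x)
fasthess = eval(build_function(hess, u)[2])
```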
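One caveat with the last line: since `build_function(hess, u)[2]` is the in-place variant, called as `fasthess(out, x)`, passing `hess` itself as `out` appears to overwrite the symbolic Hessian with numbers. A minimal sketch that writes into a preallocated `Float64` sparse matrix with the same pattern instead, using a hypothetical test function whose Hessian is not constant (so the generated code really depends on `x`). It assumes a recent Symbolics.jl, where `collect` scalarizes `u` and `expression = Val{false}` returns a callable directly instead of an expression to `eval`.

```julia
using Symbolics, SparseArrays

@variables u[1:3]
uv = collect(u)   # scalarize the symbolic array into a Vector{Num}

# Hypothetical test function: coupled quadratic in u[1], u[2] plus a quartic term in u[3]
ex = -uv[1]^2 / 2 - uv[1] * uv[2] / 2 - uv[2]^2 / 2 - uv[3]^4 / 4

hess = Symbolics.sparsehessian(ex, uv)                              # SparseMatrixCSC{Num}
fasthess! = build_function(hess, uv; expression = Val{false})[2]   # in-place: fasthess!(out, x)

H = similar(hess, Float64)   # same sparsity pattern, Float64 storage
x0 = [1.0, 2.0, 3.0]
fasthess!(H, x0)             # fills H's stored entries; hess stays symbolic
```

Reusing one preallocated `H` across calls is also what makes this cheap inside an outer optimization loop.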

Benchmarking against my hack:

```julia
using BenchmarkTools
@btime hcat([sparse($Hv * i) for i in eachcol(I(3))]...)
#   10.400 μs (135 allocations: 10.67 KiB)
@btime fasthess($hess, $x0)
#   24.523 ns (1 allocation: 48 bytes)
```

I’d call a 424× speedup respectable :smiley: (though it’s probably overly optimistic for this MWE: the Hessian is constant, so the generated function is essentially returning a lookup table). Still extremely impressive, thanks for the great tool!
