Type-stable matrix factorization corner case with ForwardDiff

I am calculating the log likelihood of multivariate normals with a function like the one below, and I want it to work smoothly with ForwardDiff.jl. The catch is that when the variance matrix is positive semidefinite but not positive definite (e.g. it has zero eigenvalues), I want to return -Inf, but when applicable, wrapped in the appropriate Dual, complete with tag and partials (which can be nonsense, they will be ignored).

"""
Sum of the log pdf for observations ``xᵢ ∼ MultivariateNormal(0, Σ)``,
where ``xᵢ`` are stored in the rows of `X`.
"""
function logpdf_normal(X, Σ)
    n, m = size(X)
    U = chol(Σ)   # v0.6 API; upper triangular U with Σ = U'U
    A = X / U
    # logdet(Σ) = 2 logdet(U), and sum(abs2, A) = ∑ᵢ xᵢ' Σ⁻¹ xᵢ
    -0.5 * (n*(m*log(2*π) + 2*logdet(U)) + sum(abs2, A))
end

The problem is of course that chol(Σ) fails when Σ is not positive definite. Which is fine; I could use try ... catch, but then I am not sure how to build up the appropriate Dual type for the result. Anything could be a Dual: X, Σ, and the actual problem has more terms (e.g. a non-zero mean, or a residual derived from a regression) which may or may not be Duals.

I could write my own version of chol that is less picky about positive definiteness. Apparently LDL' is not implemented for dense matrices in Base.


Add 1e-14 to the diagonal? On a log scale, 14 should be infinity enough :slight_smile:
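
Something like this, say (just a sketch; I is Base’s UniformScaling):

logpdf_normal(X, Σ + 1e-14 * I)   # nudges the PSD case into (barely) PD territory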

I would like to avoid this (for reasons that are not simple to display in an MWE). However, I found a quick and dirty solution:

"""
Sum of the log pdf for observations ``xᵢ ∼ MultivariateNormal(0, Σ)``,
where ``xᵢ`` are stored in the rows of `X`.
"""
function logpdf_normal(X::AbstractMatrix{TX},
                       Σ::AbstractMatrix{TΣ}) where {TX, TΣ}
    n, m = size(X)
    try
        U = chol(Σ)
        A = X / U
        -0.5 * (n*(m*log(2*π) + 2*logdet(U)) + sum(abs2, A))
    catch
        # promote_type picks up the Dual element type from either argument
        convert(promote_type(TX, TΣ), -Inf)
    end
end

which is type stable and does what I want, even with ForwardDiff.Dual.
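
For example, this kind of check works (a sketch; the one-parameter Σ below is made up for illustration):

using ForwardDiff

X = randn(100, 2)
# At s = 0 the matrix below is only PSD, so the catch branch returns
# -Inf converted to the appropriate Dual type.
f(s) = logpdf_normal(X, [1.0 0.0; 0.0 s[1]])
ForwardDiff.gradient(f, [2.0])   # Σ has Dual eltype, X is Float64; the types promote fine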

Although I am wondering whether I should care so much about type stability everywhere, given the upcoming changes in v0.7. Once #23338 and similar issues are fixed, I can use small Unions and simply accept that the return type is Union{Float64, Dual{:some_tag, Float64, ...}}, since the compiler will handle it without a significant performance penalty.
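
To illustrate (a sketch; isposdef redundantly factorizes Σ here, it is just for brevity):

# Returning a plain Float64 -Inf gives the small Union return type
# Union{Float64, Dual{...}}, which the v0.7 compiler should handle cheaply.
logpdf_or_minf(X, Σ) = isposdef(Σ) ? logpdf_normal(X, Σ) : -Inf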

These type calculations for values that are never calculated remind me of similar troubles with Nullable, which also got superseded by Unions. (If I ever write my memoirs, the title shall be “How I Learned to Stop Worrying About Type Stability and Love Unions”.)


Depending on the frequency of non-positive-definite matrices and their size, it may pay off to use an if-else instead of a try-catch:

julia> function safe_chol!(U::AbstractArray{<:Real,2}, Σ::AbstractArray{<:Real,2})
           # In-place upper-triangular Cholesky of Σ into U; returns false
           # instead of throwing when Σ is not positive definite.
           p = size(Σ,1)
           @assert p == size(Σ,2) == size(U,1) == size(U,2)
           @inbounds for i ∈ 1:p
               Uᵢᵢ = Σ[i,i]
               for j ∈ 1:i-1
                   Uⱼᵢ = Σ[j,i]
                   for k ∈ 1:j-1
                       Uⱼᵢ -= U[k,i] * U[k,j]
                   end
                   Uⱼᵢ /= U[j,j]
                   U[j,i] = Uⱼᵢ
                   Uᵢᵢ -= abs2(Uⱼᵢ)
               end
               Uᵢᵢ > 0 ? U[i,i] = √Uᵢᵢ : return false
           end
           true
       end
safe_chol! (generic function with 1 method)

julia> function logdettri(tri::AbstractMatrix{T}) where T
           p = size(tri,1)
           @assert p == size(tri,2)
           out = zero(T) # if you trust numerical stability, you could swap the comments
       #   out = one(T)
           @inbounds for i ∈ 1:p
               out += log(tri[i,i])
       #       out *= tri[i,i]        
           end
           out
       #    log(out)
       end
logdettri (generic function with 1 method)

julia> function logpdf_normal_try(X::AbstractMatrix{TX},
           Σ::AbstractMatrix{TΣ}) where {TX, TΣ}
           n, m = size(X)
           try
               U = chol(Σ)
               A = X / U
               -0.5 * (n*(m*log(2*π) + 2*logdet(U)) + sum(abs2, A))
           catch
               convert(promote_type(TX, TΣ), -Inf)
           end
       end
logpdf_normal_try (generic function with 1 method)

julia> function logpdf_normal_if(X::AbstractMatrix{TX},
           Σ::AbstractMatrix{TΣ}) where {TX, TΣ}
           n, m = size(X)
           U = Matrix{promote_type(TX, TΣ)}(m,m) # uninitialized (v0.6 constructor); safe_chol! fills the upper triangle
           density = if safe_chol!(U, Σ)
               A = X / UpperTriangular(U)
               -0.5 * (n*(m*log(2*π) + 2*logdettri(U)) + sum(abs2, A))
           else
               convert(promote_type(TX, TΣ), -Inf)
           end
           density
       end
logpdf_normal_if (generic function with 1 method)

julia> PD = randn(50,40) |> x -> x' * x ;   

julia> PSD = randn(35,40) |> x -> x' * x ;

julia> X = randn(100,40) * chol(PD) ;

julia> using BenchmarkTools

julia> @btime logpdf_normal_if($X, $PD)
  32.942 μs (3 allocations: 43.95 KiB)
-12346.835627794037

julia> @btime logpdf_normal_try($X, $PD)
  33.493 μs (8 allocations: 44.08 KiB)
-12346.835627794037

julia> @btime logpdf_normal_if($X, $PSD)
  6.518 μs (1 allocation: 12.63 KiB)
-Inf

julia> @btime logpdf_normal_try($X, $PSD)
  28.794 μs (6 allocations: 12.75 KiB)
-Inf

julia> PD = randn(15,10) |> x -> x' * x;

julia> PSD = randn(8,10) |> x -> x' * x;

julia> X = randn(100,10) * chol(PD);

julia> @btime logpdf_normal_if($X, $PD)
  11.622 μs (2 allocations: 8.81 KiB)
-2441.958129016739

julia> @btime logpdf_normal_try($X, $PD)
  12.524 μs (7 allocations: 8.94 KiB)
-2441.958129016739

julia> @btime logpdf_normal_if($X, $PSD)
  215.772 ns (1 allocation: 896 bytes)
-Inf

julia> @btime logpdf_normal_try($X, $PSD)
  22.612 μs (6 allocations: 1.00 KiB)
-Inf

julia> PD = randn(150,100) |> x -> x' * x;

julia> PSD = randn(80,100) |> x -> x' * x;

julia> X = randn(500,100) * chol(PD);

julia> @btime logpdf_normal_if($X, $PD)
  305.014 μs (4 allocations: 468.91 KiB)
-184577.00794031422

julia> @btime logpdf_normal_try($X, $PD)
  266.962 μs (9 allocations: 469.03 KiB)
-184577.00794031422

julia> @btime logpdf_normal_if($X, $PSD)
  67.166 μs (2 allocations: 78.20 KiB)
-Inf

julia> @btime logpdf_normal_try($X, $PSD)
  92.714 μs (7 allocations: 78.33 KiB)
-Inf

Of course, if you invest more time you’re likely to find a better solution.
Alternatively, a better Cholesky implementation probably won’t fall behind LAPACK as quickly as this naive one (assuming your covariance matrix is free of dual numbers).


Thanks for the very thorough analysis and benchmarks. In practice, only the exploratory phase of the algorithm encounters PSD matrices, then it should converge to a region where everything is PD almost surely (for a given value of “almost” :smile:).

Am I correct in assuming that most of the cost of the try ... catch mechanism comes from the exceptions, and that when they do not occur, the cost is small?

Neat trick with |> for not cluttering the namespace.


I was messing with very small P(S)D matrices a few months ago, so I figured I’d share what I had then (it’s also nice to revisit things you’ve done after learning more).

Comparing if-else, try-catch, and abs:

julia> log_abs_if(x) = x > 0 ? log(x) : log(-x)
log_abs_if (generic function with 1 method)

julia> function log_abs_try(x)
           try
               return log(x)
           catch
               return log(-x)
           end
       end
log_abs_try (generic function with 1 method)

julia> log_abs(x) = log(abs(x))
log_abs (generic function with 1 method)

julia> @btime log_abs_if(2.7)
  10.751 ns (0 allocations: 0 bytes)
0.9932517730102834

julia> @btime log_abs_try(2.7)
  21.042 ns (0 allocations: 0 bytes)
0.9932517730102834

julia> @btime log_abs(2.7)
  11.031 ns (0 allocations: 0 bytes)
0.9932517730102834

julia> @btime log_abs_if(-2.7)
  11.573 ns (0 allocations: 0 bytes)
0.9932517730102834

julia> @btime log_abs_try(-2.7)
  19.998 μs (0 allocations: 0 bytes)
0.9932517730102834

julia> @btime log_abs(-2.7)
  10.971 ns (0 allocations: 0 bytes)
0.9932517730102834

In this example try seems to cost about 10 nanoseconds, and catch nearly 20,000 (20 microseconds), which is a lot compared to how long it takes to factor a small PD matrix.

I would guess the LAPACK version and the ? : check recognize a failure at about the same point in the computation.

> Neat trick with |> for not cluttering the namespace.

I could also do Symmetric(Base.LinAlg.BLAS.syrk('U', 'T', 1.0, randn(50,40))), but that’s not winning any cleanliness awards. (There is the slight plus that the alpha argument of syrk lets you scale the result, i.e. choose the variance.)

I remember seeing something about the pipe operator |> getting deprecated, to allow for possibly adding it back in as syntax in 1.x. I couldn’t find any updates on that with a brief search; no deprecation warning on my nine-day-old master.

> it should converge to a region where everything is PD almost surely

I don’t know your algorithm, but can you reparametrize the p × p PD matrix, first into its Cholesky factor U, and then take the logarithm of the diagonal elements? You’d have p*(p-1)/2 unbounded off-diagonals, plus p unbounded log-diagonals. That is, to recover the PD matrix, you exponentiate the diagonal and then calculate PD = U'U.
If your algorithm explores this space, it is incapable of generating non-PD matrices [bar a diagonal entry so negative that exponentiating it underflows to 0, but as long as your initial values aren’t near -746 you should be fine].

This is how Stan handles PD matrices.
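
A minimal sketch of that map (the helper name is made up; θ runs column-major over the upper triangle):

function unconstrained_to_pd(θ, p)          # length(θ) == p * (p + 1) ÷ 2
    U = zeros(eltype(θ), p, p)
    k = 1
    for i in 1:p, j in 1:i                  # column i, rows 1:i
        U[j, i] = j == i ? exp(θ[k]) : θ[k] # exp keeps the diagonal positive
        k += 1
    end
    U' * U                                  # PD by construction (up to underflow)
end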


In v0.7, cholfact will no longer throw for non-PD matrices. One can call issuccess() on the factorization in an if statement for these situations.


Unfortunately not; this is an indirect inference problem, and the variance matrix comes from simulated data. The corner case arises in a regression: for certain regions of the parameter space the wages are too low, no one works, and the intercept predicts employment perfectly. These points have 0 posterior density, but they need to be defined. Thanks for your benchmarks on the cost of try/catch; I can live with the cheap try as long as the catch is occasional. Once the NUTS algorithm is tuned, I don’t run into these regions any more.


Btw, I just encountered it and it made me think of this comment: there is a bkfact, which should do what you want. bkfact and ldltfact seem to be functionally equivalent; maybe they should be merged?
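
E.g. (a sketch on v0.6; Bunch-Kaufman does not require positive definiteness):

A = [2.0 1.0; 1.0 0.0]   # symmetric but indefinite: chol(A) would throw
F = bkfact(A)            # Bunch-Kaufman factorization succeeds
F \ ones(2)              # and can be used for solves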

I think they preserve sparsity differently, but I am not sure.

In any case, as @Ralph_Smith suggested above, on v0.7 I have

julia> C = cholfact(ones(3,3))
Failed factorization of type Base.LinAlg.Cholesky{Float64,Array{Float64,2}}

julia> LinAlg.issuccess(C)
false

julia> C[:L]
3×3 LowerTriangular{Float64,Array{Float64,2}}:
 1.0   ⋅    ⋅ 
 1.0  0.0   ⋅ 
 1.0  1.0  1.0

so my problem is solved.
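
For reference, a sketch of the non-throwing version on v0.7 (same formula as above):

function logpdf_normal(X::AbstractMatrix{TX},
                       Σ::AbstractMatrix{TΣ}) where {TX, TΣ}
    n, m = size(X)
    C = cholfact(Σ)
    LinAlg.issuccess(C) || return convert(promote_type(TX, TΣ), -Inf)
    U = C[:U]
    A = X / U
    -0.5 * (n*(m*log(2*π) + 2*logdet(U)) + sum(abs2, A))
end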