Autodiff a partial cholesky decomposition

I’m trying to write a reversible autodiff program with NiLang that computes the partial Cholesky decomposition. Written as a regular (non-reversible) function, it looks like the code below. I’m also open to writing this in another autodiff framework.

"""
    partial_cholesky(A::Matrix{Float64})

Compute a partial (rank-revealing, unpivoted) Cholesky factor of `A`.

Columns of `A` are processed left to right. A column contributes a new column to the
factor only when its residual pivot exceeds a tolerance of `1e-9` times the smallest
positive diagonal entry of `A`; otherwise the column is skipped.

# Arguments
- `A`: matrix to factor. Only the lower triangle of its leading `n × n` part is read,
  where `n = size(A, 1)`.

# Returns
An `n × r` matrix `L`, where `r` is the number of accepted pivots. When `A` has no
positive diagonal entry, an `n × 0` matrix is returned.
"""
function partial_cholesky(A::Matrix{Float64})
    n = size(A, 1)
    dA = diag(A)
    positive = dA[dA .> 0]
    # Every residual pivot is at most A[k, k], so without a positive diagonal entry no
    # pivot can be accepted (and `minimum` of an empty collection would throw).
    isempty(positive) && return zeros(n, 0)
    tol = minimum(positive) * 1e-9
    L = zeros(n, n)
    r = 0  # number of columns accepted so far

    for k in 1:n
        r += 1
        # Residual of column k after removing the contribution of accepted columns.
        if r == 1
            L[k:n, r] = A[k:n, k]
        else
            L[k:n, r] = A[k:n, k] - (L[k:n, 1:(r - 1)] * L[k, 1:(r - 1)])
        end

        if L[k, r] > tol
            L[k, r] = sqrt(L[k, r])
            if k < n
                L[(k + 1):n, r] = L[(k + 1):n, r] / L[k, r]
            end
        else
            # Rejected pivot: wipe the residual so it cannot leak into the result when
            # column r is reused by a later k (only rows k+1:n get overwritten then).
            L[k:n, r] .= 0
            r -= 1
        end
    end

    return L[:, 1:r]
end

An example of its use is:

using LinearAlgebra

# Build a random symmetric positive semi-definite 8×8 matrix of rank (at most) 2.
n1 = 3
n2 = 5
target_rank = 2  # not called `rank`, which would collide with LinearAlgebra.rank
u = [rand(n1, target_rank); randn(n2, target_rank)]
A = u * u'

# A is rank-deficient, so the standard factorization is expected to break down.
# `check = false` reports that through `issuccess` instead of throwing, so the
# example keeps running and reaches the partial factorization below.
issuccess(cholesky(A; check = false))  # expected: false
partial_cholesky(A)  # has number of columns == target_rank

When I attempt to write this in NiLang, I get a “not deallocated correctly” error. I tried adding a `~begin ... end` block to undo the temporary computations, but I’m not sure where to go from here. The attempted code is:

using NiLang, LinearAlgebra

# First (non-working) attempt at a reversible NiLang version of `partial_cholesky`.
# The intent is to accumulate the factor into `L` from the input matrix `A`, using
# temporaries (`Laux`, `Lsqrt`, `Laux2`) that are then uncomputed in a `~begin` block.
# The review notes below mark the places that look problematic.
@i function partial_cholesky_reverse(L, A::Matrix{T}) where T
  # n = size(A, 1)
  # dA = diag(A)
  # tol = minimum(dA[findall(dA .> 0)]) * 1e-9; 
  # Column counter of the factor.
  r ← 0
    # NOTE(review): `n` is only defined in the commented-out line above and `k` is the
    # loop variable introduced further down, so neither exists at this point.
    # NOTE(review): Base `zero` takes a single argument; `zeros(T, dims...)` was
    # probably intended, with the dimensions passed directly rather than via `size`
    # (confirm).
    # NOTE(review): no matching `→` deallocation is visible for any of these `←`
    # allocations; presumably this is what triggers the "not deallocated correctly"
    # error (confirm against the NiLang documentation).
    Laux ← zero(T, size(n - k + 1, 1))
    # NOTE(review): `Lk` is never used below.
    Lk ← zero(T, size(1, r - 2))
    Lsqrt ← zero(T, 1)
    Laux2 ← zero(T, size(n - k, 1))
  
  for k = 1:n
    INC(r)

    # Residual of column k, accumulated into the temporary `Laux`.
    if (r == 1, ~)
         #  L[k:n, r] = A[k:n, k]
        Laux += A[k:n, k]
    else 
        # L[k:n, r] = A[k:n, k] - ( L[k:n, 1:(r-1)] * L[k, 1:(r-1)] )
        Laux += A[k:n, k] - L[k:n, 1:(r-1)] * L[k, 1:(r-1)] 
    end
   #L[k:n, r] += Laux
    # if L[k, r] > tol 
    # The hard-coded 1e-9 stands in for the commented-out `tol`.
    if Laux[1, 1] > 1e-9
      Lsqrt += sqrt(Laux[1, 1]) 
      L[k, r] += Lsqrt 
      if k < n
         # L[(k+1):n, r] = L[(k+1):n, r] / L[k, r] 
         # NOTE(review): L[(k+1):n, r] has not received `Laux` yet (that update is
         # commented out above and only happens after this branch), so this divides
         # whatever the caller passed in.
         Laux2 += L[(k + 1):n, r] / Lsqrt 
         L[(k + 1):n, r] += Laux2
        end
    else 
      DEC(r)
    end

    # NOTE(review): `Laux` was built from A[k:n, k] (n-k+1 entries) while the target
    # slice L[k+1:n, r] has n-k entries; the lengths look mismatched (confirm).
    L[k + 1:n, r] += Laux

    # Intended to uncompute the temporaries by running the same updates in reverse.
    # NOTE(review): the second condition reads L[k, r] instead of Laux[1, 1], so this
    # block is not the exact mirror of the forward computation above.
    ~begin
        if (r == 1, ~)
         #  L[k:n, r] = A[k:n, k]
            Laux += A[k:n, k]
        else 
        # L[k:n, r] = A[k:n, k] - ( L[k:n, 1:(r-1)] * L[k, 1:(r-1)] )
            Laux += A[k:n, k] - L[k:n, 1:(r-1)] * L[k, 1:(r-1)] 
        end

        # if L[k, r] > tol 
        if L[k, r] > 1e-9
            Lsqrt += sqrt(L[k, r]) 
        end
        if k < n
            # L[(k+1):n, r] = L[(k+1):n, r] / L[k, r] 
            Laux2 += L[(k + 1):n, r] / Lsqrt 
        end
    end

    end
end

Solved on Slack. Posting here in case it helps anyone in the future (using Julia 1.6)

using NiLang, LinearAlgebra

# Reversible (NiLang `@i`) partial Cholesky factorization.
#
# Arguments
# - L:             output factor; columns are accumulated with `+=`, so it presumably
#                  has to be all-zero on entry (TODO confirm).
# - A:             input matrix; only entries A[i, k] with i >= k are read.
# - r:             column counter, incremented for every k and decremented again when
#                  the pivot is rejected, so it ends up as the number of accepted
#                  columns when started at 0.
# - cache:         3-d scratch array indexed as cache[i, r, k]; holds the residual of
#                  column k so that nothing has to be overwritten.
# - branch_keeper: one Bool per k, flipped when the pivot of column k is accepted; it
#                  serves as the postcondition of the accept/reject branch.
@i function i_partial_cholesky!(L::Matrix{T}, A::Matrix{T}, r::Int, cache::Array{T,3}, branch_keeper::AbstractVector{Bool}) where T
    # Set-up values; this routine is undone by the `~@routine` at the end.
    @invcheckoff @routine begin
        n ← size(A, 1)
        # NOTE(review): `NiLang.value` presumably unwraps gradient-carrying elements to
        # plain numbers so that `tol` is not differentiated (confirm).
        dA ← NiLang.value.(diag(A))
        # Tolerance: 1e-9 times the smallest positive diagonal entry.
        tol ← minimum(dA[findall(dA .> 0)]) * 1e-9; 
    end
  
    @invcheckoff @inbounds for k = 1:n
        r += 1
        # change cache: cache[i, r, k] becomes A[i, k] minus the contribution of the
        # r-1 columns of L accepted so far.
        if r == 1
            for i=k:n
                cache[i, r, k] += A[i, k]
            end
        else 
            for i=k:n
                if i==k
                    for j=1:r-1
                        cache[i,r,k] -= L[k,j] ^ 2  # to avoid shared read
                    end
                else
                    for j=1:r-1
                        cache[i,r,k] -= L[i,j] * L[k,j]
                    end
                end
                cache[i,r,k] += A[i, k]
            end
        end

        # change L: accept the column when its pivot exceeds the tolerance.
        # The tuple is (precondition, postcondition); branch_keeper[k] records which
        # branch was taken.
        if (cache[k,r,k] > tol, branch_keeper[k])
            L[k, r] += sqrt(cache[k,r,k]) 
            for i=k+1:n
                L[i, r] += cache[i, r, k] / L[k, r] 
            end
            branch_keeper[k] ⊻= true
        else 
            # Pivot rejected: give the column index back.
            r -= 1
        end
    end
    # Uncompute n, dA and tol.
    ~@routine
end


using Test
# Random symmetric 10×10 test matrix, built as the product of an upper-triangular
# random matrix with its transpose.
x = triu(randn(10, 10))
x = x * x'

# NOTE(review): `check_inv` presumably verifies that running the function followed by
# its inverse restores the arguments (confirm in the NiLang documentation).
@test NiLang.check_inv(i_partial_cholesky!, (zero(x), x, 0, zeros(10, 10, 10), zeros(Bool, 10)))
# Forward pass: the first r columns of the reversible result must match the plain
# reference implementation.
c1 = partial_cholesky(x)
c2, x, r, cache, branch_keeper = i_partial_cholesky!(zero(x), x, 0, zeros(10, 10, 10), zeros(Bool, 10))
@test c1 ≈ c2[:,1:r]
# Backward pass: attach random gradients to the output c2, wrap the remaining results
# in GVar, run the inverse program (~i_partial_cholesky!) and extract the gradients.
# Only the first four of the five returned gradients are bound to names.
gc2, gx, gr, gcache = NiLang.AD.grad.((~i_partial_cholesky!)(NiLang.AD.GVar(c2, randn(10, 10)), NiLang.AD.GVar.((x, r, cache, branch_keeper))...))

1 Like