How to implement a Gaussian NN with covariance matrix?


I'm editing this question to make it clearer, since I still have no solution.

I need to implement a NN that outputs a covariance matrix. Say we have N variables: I need to be able to compute a symmetric N×N positive semi-definite matrix from the output of the NN.

The authors of the algorithm I'm trying to implement use a vector output of length N(N+1)/2.
When they sample, this vector is transformed into a lower triangular matrix, which is multiplied by its transpose to obtain a symmetric N×N positive semi-definite matrix that is treated as the covariance matrix.
That is, the output vector holds the elements of the Cholesky factor of the covariance matrix.
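In formulas (my own notation, 1-based, assuming the row-by-row ordering of the lower triangle that `torch.tril_indices` produces), the vector $v$ of length $N(N+1)/2$ fills a lower triangular $L$ and the covariance is:

$$
\Sigma = L L^\top, \qquad
L_{ij} = \begin{cases}
\operatorname{softplus}\!\left(v_{k(i,i)}\right) & i = j \\
v_{k(i,j)} & i > j \\
0 & i < j
\end{cases}, \qquad
k(i,j) = \frac{i(i-1)}{2} + j \quad (j \le i).
$$

The diagonal entries therefore sit at $k(i,i) = i(i+1)/2$, i.e. 1, 3, 6, 10 for $N = 4$; subtracting 1 gives the 0-based `cholesky_diag_index` computed in the PyTorch snippet. Since $L$ has a strictly positive diagonal, $\Sigma$ is in fact positive definite, not just semi-definite.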
I found a PyTorch example of this; here's the relevant part:

```python
import torch
import torch.nn.functional as F

# da is what I call N; B (batch size), device and self.cholesky_layer are defined elsewhere in the model
cholesky_vector = self.cholesky_layer(x)  # vector of length N*(N+1)/2
cholesky_diag_index = torch.arange(da, dtype=torch.long) + 1  # julia equivalent is [1:da;]
cholesky_diag_index = (cholesky_diag_index * (cholesky_diag_index + 1)) // 2 - 1  # indices of the future diagonal elements of the matrix
cholesky_vector[:, cholesky_diag_index] = F.softplus(cholesky_vector[:, cholesky_diag_index])  # softplus makes the diagonal > 0
tril_indices = torch.tril_indices(row=da, col=da, offset=0)  # indices of the non-zero elements of a lower triangular matrix
cholesky = torch.zeros(size=(B, da, da), dtype=torch.float32).to(device)  # initialize a batch of square matrices to zeros
cholesky[:, tril_indices[0], tril_indices[1]] = cholesky_vector  # place the vector elements at their positions in the lower triangular matrix
```

where da is the number of dimensions of the distribution, i.e. what I call N (so we deal with a da × da matrix).
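The ordering argument behind the Julia attempt can be checked in isolation: a filtered comprehension over `i in 1:N, j in 1:N` varies `i` fastest, so it walks the upper triangle column by column, which is the transpose of PyTorch's row-by-row walk over the lower triangle. A small sketch (the variable names `triu_positions`, `diag_positions` and `tril_rowmajor` are mine, not from either implementation):

```julia
N = 4

# Upper-triangle positions in the order a filtered comprehension visits them:
# `i` varies fastest, so the walk goes column by column.
triu_positions = [CartesianIndex(i, j) for i in 1:N, j in 1:N if i <= j]

# Where the diagonal entries land in that list.
diag_positions = findall(c -> c[1] == c[2], triu_positions)

# Closed form used in the question: k(k+1)/2 for k = 1..N.
closed_form = [1:N;] .* [2:(N+1);] .÷ 2

# Transposing each position gives the lower triangle walked row by row,
# i.e. the order of torch.tril_indices shifted to 1-based indexing.
tril_rowmajor = [CartesianIndex(c[2], c[1]) for c in triu_positions]

println(diag_positions)
println(closed_form == diag_positions)
```

So filling an upper triangular `U` in this order with the same vector gives `U = L'`, where `L` is the PyTorch lower factor.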
Here's how I tried to reimplement this in Julia, accounting for the fact that Julia arrays are column-major (so I build an upper triangular matrix instead of a lower one) and one-based. I used an example with da = N = 4:

```julia
using Flux, LinearAlgebra

N = 4
cholesky_vector = rand(N*(N+1)÷2, 1)  # stand-in for the output of the network

function sigma(v)
    # indices of the elements of v that become the diagonal of the upper triangular matrix
    cholesky_diag_index = [1:N;] .* [2:(N+1);] .÷ 2
    softplus(x) = log(1 + exp(x))

    v[cholesky_diag_index] .= softplus.(v[cholesky_diag_index])  # make sure the standard deviations are > 0
    triu_indices = [CartesianIndex(i, j) for i in 1:N, j in 1:N if i <= j]  # upper triangular positions, column by column

    chol = zeros(N, N)
    for (lidx, cidx) in enumerate(triu_indices)
        chol[cidx] = v[lidx]
    end
    Sigma = chol' * chol  # chol is the transpose of PyTorch's lower factor L, so this equals L*L'
    sum(Sigma)  # sum the elements to get a scalar function, just to check that the gradient works
end

g = gradient(sigma, cholesky_vector)  # fails: Zygote does not support the setindex! mutations above
```

And I get an error telling me that mutating arrays is not supported, triggered by the setindex!. I don't really know how to deal with this; PyTorch seems to handle it. Does anyone know a trick to do this?

Thanks a lot

Could one alternative be to let the output of the network be a matrix (instead of a vector) and then use tril to get the lower triangular part to construct a covariance matrix?
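A minimal sketch of that suggestion, assuming the last layer produces N² numbers reshaped into an N×N matrix `M` (the helper name `covariance_from_matrix` is hypothetical). It uses only non-mutating operations (`tril`, `diag`, `Diagonal`), which is the property Zygote needs; softplus is written out so the block has no Flux dependency:

```julia
using LinearAlgebra

# softplus, written out here; Flux/NNlib export an equivalent `softplus`.
_softplus(x) = log1p(exp(x))

# Hypothetical helper: raw N×N network output -> covariance matrix.
function covariance_from_matrix(M::AbstractMatrix)
    # strictly lower part of M plus a strictly positive diagonal
    L = tril(M, -1) + Diagonal(_softplus.(diag(M)))
    return L * L'  # symmetric positive definite by construction
end

M = randn(4, 4)
Σ = covariance_from_matrix(M)

# Changing an entry above the diagonal leaves Σ unchanged,
# so those outputs receive zero gradient.
M2 = copy(M)
M2[1, 4] += 1.0
```

The price is N(N-1)/2 wasted outputs, in exchange for avoiding any index bookkeeping.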

Yes, it's probably what I'll have to do. It is a bit less efficient because the output neurons for the upper part are useless, but at least with tril their gradient is 0, so nothing is backpropagated through them.
I was surprised that PyTorch can handle this but Zygote can't, so I was hoping I was missing something.

Do you know whether PyTorch uses the index approach? I would not be surprised if the matrix approach turns out to be faster. The code is also simpler.

The Python code in my post is not mine, and honestly I don't even know how performant that implementation is. It's just the only one I found; maybe I should look for others. In the publication that presents the algorithm, the authors specifically say that their NN outputs the Cholesky vector, but it is entirely possible that they did not consider the full-matrix approach.

You could probably generate the matrix without mutation (mutation is what Zygote is unhappy about) by using an array comprehension or similar. I didn't really take the time to understand everything your code does, so this is not an exact equivalent, just a quick sketch of how it could be done. No clue whether this is better than the other suggestion, though.

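```julia
using Flux

N = 4
v = rand(N*(N+1)÷2, 1)

function sigma(v)
    softplus(x) = log(1 + exp(x))

    c2idx(i, j) = (i*(i-1))÷2 + j  # map matrix position (i, j) to an index into v

    f(i, j) = if i == j
        softplus(v[c2idx(i, j)])  # do something special with the diagonal
    elseif i < j
        v[c2idx(i, j)]            # keep the off-diagonal upper entries as they are
    else
        zero(eltype(v))           # zeros below the diagonal
    end

    tmp = [f(i, j) for i in 1:N, j in 1:N]  # generate the whole matrix at once to avoid mutation

    Sigma = tmp * tmp'
    sum(Sigma)  # scalar output so that gradient can be applied
end

g = gradient(sigma, v)
```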

Yeah, this could work! I had thought about using a comprehension, but I ran into another problem.
I'll try this and see.
Thanks a lot.

That's probably the way to go. I just wanted to chime in to correct a minor mistake in c2idx.
Looking at the vector indices, it translates to:

```julia
julia> c2idx(i, j) = (i*(i-1))÷2+j
c2idx (generic function with 1 method)

julia> [ifelse(j<i, 0, c2idx(i,j)) for i in 1:N, j in 1:N]
4×4 Matrix{Int64}:
 1  2  3   4
 0  3  4   5
 0  0  6   7
 0  0  0  10
```

I think you missed the side length N in the formula:

```julia
julia> new_c2idx(i, j) = ((2N-i)*(i-1))÷2+j
new_c2idx (generic function with 1 method)

julia> [ifelse(j<i, 0, new_c2idx(i,j)) for i in 1:N, j in 1:N]
4×4 Matrix{Int64}:
 1  2  3   4
 0  5  6   7
 0  0  8   9
 0  0  0  10
```
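A quick way to validate an index formula like this is to check that the upper triangle hits every position of the length-N(N+1)/2 vector exactly once. Below is a sketch that does that check and plugs the corrected index into the non-mutating comprehension from above; `covers_vector`, `sigma_fixed` and `_softplus` are hypothetical names, and N is passed explicitly instead of being read from a global:

```julia
using LinearAlgebra  # for Symmetric / isposdef when checking the result

# Corrected map from upper-triangle position (i, j), i <= j, to a vector index,
# walking the upper triangle row by row.
new_c2idx(i, j, N) = ((2N - i) * (i - 1)) ÷ 2 + j

# Every upper-triangle position should map to a distinct index in 1:N(N+1)/2.
function covers_vector(N)
    idx = [new_c2idx(i, j, N) for i in 1:N, j in 1:N if i <= j]
    return sort(idx) == collect(1:N*(N+1)÷2)
end

_softplus(x) = log1p(exp(x))  # Flux/NNlib export an equivalent `softplus`

# Non-mutating construction: the whole matrix comes from one comprehension.
function sigma_fixed(v, N)
    f(i, j) = i == j ? _softplus(v[new_c2idx(i, j, N)]) :
              i < j  ? v[new_c2idx(i, j, N)] : zero(eltype(v))
    U = [f(i, j) for i in 1:N, j in 1:N]  # upper triangular with positive diagonal
    return U * U'
end

N = 4
v = rand(N*(N+1)÷2)
Σ = sigma_fixed(v, N)
```

With the original c2idx, the same coverage check fails for N = 4, because indices 3 and 4 are each used twice (visible in the first matrix above).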