Gibbs sampler works in R but not in Julia (full Julia code provided), please help

Hi Everyone,

I’m a director of data science and I’m checking out Julia as a tool for my group to use. I chose an example on population mixtures with gene data from the book Computer Age Statistical Inference. The particular example can be found on page 256 of the linked PDF (https://web.stanford.edu/~hastie/CASI_files/PDF/casi.pdf). The data is freely available, and the Julia code I provide below should let you easily get it into Julia.

The problem is to find which populations the individuals come from without using the labels. So it’s an unsupervised learning problem (though the actual labels are there to check how you’re doing). There are 100 SNPs, each coded as 0, 1 or 2. The value is determined by two allele copies X1 and X2: if X1 and X2 are both 0 the value is 0, if both are 1 it is 2, and if only one of them is 1 the value is 1. The estimation procedure is described in detail in the book and involves Gibbs sampling. My test was to see how my unoptimized R code compares to my unoptimized Julia code.
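A minimal sketch of the recoding just described, assuming the genotypes have already been pulled into an integer matrix `G` (entries 0, 1, 2, with missing values already filled in). The name `split_alleles` is hypothetical and not part of any package; the sketch uses only the standard `Random` library.

```julia
using Random

# Split a genotype matrix G (entries 0, 1, 2) into two binary allele matrices:
# 0 -> (0, 0), 2 -> (1, 1), 1 -> (1, 0) or (0, 1) with equal probability.
function split_alleles(G::AbstractMatrix{<:Integer}; rng::AbstractRNG = Random.default_rng())
    X1 = zeros(Int, size(G))
    X2 = zeros(Int, size(G))
    for idx in eachindex(G)
        g = G[idx]
        if g == 2
            X1[idx] = 1
            X2[idx] = 1
        elseif g == 1
            first_copy = rand(rng, Bool)   # which copy carries the allele
            X1[idx] = first_copy
            X2[idx] = !first_copy
        elseif g != 0
            error("genotype values must be 0, 1 or 2, got $g")
        end
    end
    return X1, X2
end

G = [0 1 2; 1 1 0]
X1, X2 = split_alleles(G; rng = MersenneTwister(1))
```

The identity `X1 + X2 == G` is a cheap sanity check to run after any implementation of this split.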

I’m very comfortable with R and was able to basically reproduce what they have in the book (first image below). I’ve translated the code to Julia and the output is quite bad. I’m not sure where the issue is and I’m hoping some of you can help out. On the plus side, the Julia version runs ~30x faster.

Here’s the mixture I got from R: [image not available]

My attempt in Julia results in: [image not available]

The full Julia code:

using CSVFiles, DataFrames, DataFramesMeta, CategoricalArrays, StatsBase, Distributions, Random, FreqTables, BenchmarkTools
data = DataFrame(load("https://web.stanford.edu/~hastie/CASI_files/DATA/haplotype.csv"))

for col in names(data)[3:102]
    data[ismissing.(data[col]), col] = 0
end

data_snp = data[names(data)[3:102]]

Random.seed!(2342)

function prob_recode(x)
    bin = Binomial()
    local where_one, where_x, x1
    where_x = findall(y -> y == 1, x)
    where_two = findall(y -> y == 2, x)
    x1 = copy(x)
    x2 = copy(x)
    x1[where_x] = rand(bin, length(where_x))
    where_x1 = findall(y -> y == 1, x1)
    other_ones = setdiff(where_x, where_x1)
    x2[where_x1] = fill(0, length(where_x1))
    x2[other_ones] = fill(1, length(other_ones))
    
    x1[where_two] = fill(1, length(where_two))
    x2[where_two] = fill(1, length(where_two))
    
    x, x1, x2
end

X = colwise(prob_recode, data_snp);

col_names = names(data_snp)
a = hcat(X[1][:]...)[:, 2]
b = hcat(X[1][:]...)[:, 3]
for i in 2:100
    a = hcat(a, hcat(X[i][:]...)[:, 2])
    b = hcat(b, hcat(X[i][:]...)[:, 3])
end 
    
X1 = DataFrame(a)
X2 = DataFrame(b)
rename!(X1, [f => t for (f, t) = zip(names(X1), col_names)])
rename!(X2, [f => t for (f, t) = zip(names(X2), col_names)]);

N = size(data_snp)[1] # num. of obs
M = size(data_snp)[2]  # num. of variables
J = 3 # num. of latent populations

Q_prior = Dirichlet(ones(J)) 
Q = transpose(rand(Q_prior, N))

P_prior = Beta() # equivalent to Beta(1, 1)
P = hcat(rand(P_prior, N), rand(P_prior, N), rand(P_prior, N))
P = P ./ sum(P, dims = 2);

Z1 = Array{Float64}(undef, N, M)
Z2 = Array{Float64}(undef, N, M)

mult = Multinomial(1, Q[1, :])

z1 = rand(mult, M)
z2 = rand(mult, M)

for n in 1:N
    for i in 1:M
        Z1[n, i] = findall(x -> x == 1, z1[:, i])[1]
        Z2[n, i] = findall(x -> x == 1, z2[:, i])[1]
    end
end

Z = Z1, Z2

function Z_cond(Z1, Z2, X1, X2, P, Q)
    N = nrow(X1)
    M = size(Z1, 2)
    
    z1 = Vector{Float64}(undef, 3)
    z2 = Vector{Float64}(undef, 3)
   
    for n in 1:N
        for m in 1:M
            for j in 1:3
                z1[j] = Q[n, j] * (abs(X1[n, m] - 1) + P[m, j])
                z2[j] = Q[n, j] * (abs(X2[n, m] - 1) + P[m, j])
            end
            z1 = z1 ./ sum(z1)
            z2 = z2 ./ sum(z2)
            mult1 = Multinomial(1, z1)
            mult2 = Multinomial(1, z2)
            Z1[n, m] = findall(x -> x == 1, rand(mult1, 1))[1][1]
            Z2[n, m] = findall(x -> x == 1, rand(mult2, 1))[1][1]
        end
    end
    return Z1, Z2
end          

function P_cond(P, X1, X2, Z1, Z2)
    p = Vector{Float64}(undef, 3)
    for i in 1:size(P, 2) 
        n11 = freqtable(Z1[findall(x -> x == 1, X1[1]), i])            
        n01 = freqtable(Z1[findall(x -> x == 0, X1[1]), i])
        n12 = freqtable(Z2[findall(x -> x == 1, X2[1]), i])    
        n02 = freqtable(Z2[findall(x -> x == 0, X2[1]), i])
        
        n1 = Vector{Int64}(undef, 3)
        n0 =  Vector{Int64}(undef, 3)
        for j in 1:3
            n1[j] = get(n11, j, 0) + get(n12, j, 0)
            n0[j] = get(n01, j, 0) + get(n02, j, 0)
        end 
        
        p1 = Beta(1 + n1[1], 1 + n0[1])
        p2 = Beta(1 + n1[2], 1 + n0[2])
        p3 = Beta(1 + n1[3], 1 + n0[3])
        
        p[1] = rand(p1, 1)[1]
        p[2] = rand(p2, 1)[1]
        p[3] = rand(p3, 1)[1]
        p = p / sum(p)
    
        P[i, :] = p
    end
    return P
end

function Q_cond(Q, X1, X2, Z1, Z2)
    X1 = convert(Array, X1)
    X2 = convert(Array, X2)

    for i in 1:size(Q)[1]
        m = Vector{Int64}(undef, 3)
        m1 = freqtable(Z1[i, findall(x -> x == 1, X1[i, :])])
        m2 = freqtable(Z2[i, findall(x -> x == 1, X2[i, :])])
        for j in 1:3
           m[j] = get(m1, j, 0) + get(m2, j, 0)
        end    
        q = Dirichlet(1 .+ m) 
        Q[i, :] = transpose(rand(q, 1))
    end
   return Q 
end


function gibb(Z, P, Q, n)
    for i in 1:n
        Z = Z_cond(Z[1], Z[2], X1, X2, P, Q)
        P = P_cond(P, X1, X2, Z[1], Z[2])
        Q = Q_cond(Q, X1, X2, Z[1], Z[2])
    end
    return Z, P, Q
end


burnin = gibb(Z, P, Q, 1000)
function gibb_acc(Z, P, Q, n)
    Q_acc = copy(Q)
    
    for i in 1:n
        Z = Z_cond(Z[1], Z[2], X1, X2, P, Q)
        P = P_cond(P, X1, X2, Z[1], Z[2])
        Q = Q_cond(Q, X1, X2, Z[1], Z[2])
        Q_acc += Q
    end
    return Q_acc/n, Q
end
out = gibb_acc(burning[1], burning[2], burning[3], 500)
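Two details in the initialisation above are worth flagging, although neither is likely to explain the poor separation on its own, since a Gibbs sampler should forget its starting point after burn-in. First, `z1 = rand(mult, M)` is drawn once outside the `for n` loop, so every individual starts with identical labels. Second, `P` is drawn with `N` rows where `M` are needed (harmless, because only the first `M` rows are ever indexed). A sketch of a per-individual initialisation, assuming Distributions.jl's `Categorical`; the name `init_Z` is hypothetical:

```julia
using Distributions, Random

# Draw initial population labels: row n of Z gets M independent draws
# from individual n's own admixture vector Q[n, :].
function init_Z(Q::AbstractMatrix{<:Real}, M::Integer; rng::AbstractRNG = Random.default_rng())
    N = size(Q, 1)
    Z = Matrix{Int}(undef, N, M)
    for n in 1:N
        Z[n, :] .= rand(rng, Categorical(Q[n, :]), M)
    end
    return Z
end

# Usage: one label matrix per allele copy.
Q = permutedims(rand(MersenneTwister(1), Dirichlet(ones(3)), 5))   # 5 individuals x 3 populations
Z1 = init_Z(Q, 10)
Z2 = init_Z(Q, 10)
```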

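Two things in `P_cond` differ from the R version. With the DataFrames API this code targets, `X1[1]` is the first column of the data frame, so `findall(x -> x == 1, X1[1])` selects the rows where SNP 1 carries the allele for every column `i`, whereas R uses `X1[, i]`, the current SNP. (The rewritten `P_cond` later in the thread has the same `X1[1]` indexing.) In addition, the loop `for i in 1:size(P, 2)` runs over the 3 populations rather than over the SNPs (R loops over `nrow(P)`), so only the first three rows of `P` are ever updated; the rewritten version's `for i in 1:M` fixes that part. A minimal sketch of the counts as the R code computes them, assuming `X1`, `X2`, `Z1`, `Z2` are plain integer matrices (`Matrix(X1)` converts a DataFrame); the name `allele_counts` is hypothetical:

```julia
# For SNP column m, count the allele copies assigned to each population j
# that carry allele 1 (n1) or allele 0 (n0), pooling both copies.
# Mirrors tabulate(Z1[which(X1[, i] == 1), i], nbins = 3) etc. in the R code.
function allele_counts(X1::AbstractMatrix, X2::AbstractMatrix,
                       Z1::AbstractMatrix{<:Integer}, Z2::AbstractMatrix{<:Integer},
                       m::Integer, J::Integer)
    n1 = zeros(Int, J)
    n0 = zeros(Int, J)
    for (X, Z) in ((X1, Z1), (X2, Z2)), n in axes(X, 1)
        if X[n, m] == 1          # column m, the current SNP
            n1[Z[n, m]] += 1
        else
            n0[Z[n, m]] += 1
        end
    end
    return n1, n0
end

X1 = [1 0; 0 1; 1 1]
X2 = [0 0; 1 1; 1 0]
Z1 = [1 2; 2 3; 3 1]
Z2 = [1 1; 2 2; 3 3]
n1, n0 = allele_counts(X1, X2, Z1, Z2, 1, 3)
```

The conditional draw for row `m` of `P` can then use the broadcast form suggested further down the thread, `rand.(Beta.(1 .+ n1, 1 .+ n0))`.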

Can you share the R code? This looks super interesting and I would like to spend some time playing with it. I’ll let you know if I find anything, but I’m not an expert in Julia (yet…).


There is a typo: burning instead of burnin. Maybe out contains old stuff and not the simulation result? Sometimes it is a good idea to restart the REPL to be sure you are not accessing any outdated variables.

Unfortunately, that’s just a typo I introduced before pasting here. I had renamed “burning” to “burnin” by removing the “g”, but not everywhere. I did try restarting the REPL a few times.

Will post on Monday once I get back to work where the code is saved.

Not sure if this helps, but as a comment on the style:

bin = Binomial()
where_x = findall(y -> y == 1, x)
x1[where_x] = rand(bin, length(where_x))
where_x1 = findall(y -> y == 1, x1)
x2[where_x1] = fill(0, length(where_x1))

Could be replaced by the more idiomatic:

x1[x .== 1] .= rand(bin, count(==(1), x))
x2[x1 .== 1] .= 0

which IMO makes the code more readable (see the “Dot Syntax for Vectorizing Functions” section of the Julia manual for an explanation of this syntax).
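For independent per-site draws, the number of draws has to be passed explicitly: as far as I know, Distributions.jl treats a single distribution object as a scalar under broadcasting, so `rand.(bin)` would yield just one value shared by every selected position. A whole-vector sketch of the core of `prob_recode` using only the standard library, where `rand(rng, Bool, k)` stands in for `k` draws from `Binomial()`; `recode_alleles` is a hypothetical name:

```julia
using Random

# Split a genotype vector x (entries 0, 1, 2) into two allele vectors.
# Every heterozygous site (x == 1) gets its own independent coin flip.
function recode_alleles(x::AbstractVector{<:Integer}; rng::AbstractRNG = Random.default_rng())
    het = x .== 1
    x1 = Int.(x .== 2)                       # homozygous 2 -> allele on both copies
    x2 = copy(x1)
    x1[het] .= rand(rng, Bool, count(het))   # one fresh draw per heterozygous site
    x2[het] .= 1 .- x1[het]                  # the other copy gets the complementary allele
    return x1, x2
end

x = [0, 1, 2, 1, 1, 0, 2]
x1, x2 = recode_alleles(x; rng = MersenneTwister(7))
```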

Similarly:

p1 = Beta(1 + n1[1], 1 + n0[1])
p2 = Beta(1 + n1[2], 1 + n0[2])
p3 = Beta(1 + n1[3], 1 + n0[3])
    
p[1] = rand(p1, 1)[1]
p[2] = rand(p2, 1)[1]
p[3] = rand(p3, 1)[1]

could be written as

p .= rand.(Beta.(1 .+ n1, 1 .+ n0))

or, if all those dots look a bit ugly:

@. p = rand(Beta(1 + n1, 1 + n0))

Thanks! I welcome any and all help with style or content on the code. This was my first attempt at coding in Julia and I kept thinking, “There has to be a better way to write this.”

I started rewriting this code in a more Julian style for fun, and here’s what I got for the first chunk:

snps = hcat(getfield(data, :columns)[3:end]...)

a = similar(snps, Bool)
b = similar(snps, Bool)

for (index, snp) in enumerate(snps)
    a[index], b[index] = 
        if snp === missing
            false, false
        elseif snp == 0
            false, false
        elseif snp == 1
            x = rand(Bool)
            x, !x
        elseif snp == 2
            true, true
        else
            error("All snps must be 0, 1, 2, or missing")
        end
end

I think this is what you were trying to do with the first part? I stopped at the word Dirichlet because I don’t know what it means.

The first two branches of that if-chain,

if snp === missing
    false, false
elseif snp == 0
    false, false

can be written as

if coalesce(snp, 0) == 0
    false, false
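Put together, the per-genotype branch reads as below; `coalesce(snp, 0)` maps `missing` to 0 and leaves other values alone, so missing and homozygous-0 sites share one branch. `genotype_alleles` is a hypothetical helper name wrapping the loop body from the post above.

```julia
# Map one genotype (0, 1, 2 or missing) to a pair of allele indicators.
function genotype_alleles(snp)
    if coalesce(snp, 0) == 0          # missing or 0: no allele on either copy
        return false, false
    elseif snp == 1                   # heterozygous: allele on exactly one copy
        x = rand(Bool)
        return x, !x
    elseif snp == 2                   # homozygous: allele on both copies
        return true, true
    else
        error("All snps must be 0, 1, 2, or missing")
    end
end
```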

I’m also playing with this code and I can confirm that what you did seems to do the same as the OP’s. I swapped in your code chunk, ran the OP’s code starting at X1 = DataFrame(a), and everything ran accordingly. That didn’t change the results, though. And Dirichlet is just a probability distribution that is used to model one of the variables in the Bayesian model.

The Dirichlet distribution is a distribution on probability vectors; here it is the prior distribution for each individual’s population proportions Q.
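A small illustration with Distributions.jl: draws from `Dirichlet(ones(3))` are 3-component probability vectors (non-negative, summing to 1), and because the Dirichlet is conjugate to categorical counts, observing counts `m` turns the flat prior into `Dirichlet(1 .+ m)`, the form used in `Q_cond`. The counts below are made up for illustration.

```julia
using Distributions, Random

rng = MersenneTwister(42)

prior = Dirichlet(ones(3))        # flat prior over 3-component probability vectors
q = rand(rng, prior)              # e.g. one individual's admixture proportions Q[i, :]

m = [12, 3, 0]                    # made-up counts of allele copies per population
posterior = Dirichlet(1.0 .+ m)   # conjugate update, same form as Dirichlet(1 .+ m) in Q_cond
q_post = rand(rng, posterior)
```

The posterior mean is `(1 .+ m) ./ sum(1 .+ m)`, so the counts pull each individual's proportions towards the populations its allele copies were assigned to.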

As requested, the R code is below. The more I play with Julia, the more I think there’s a big issue with the scoping of the variables, but I’m not sure how to fix that.

library(MCMCpack)
library(data.table)
library(softImpute)
library(progress)
library(compiler)
library(klaR)
library(ggtern)

haplotype <- fread("https://web.stanford.edu/~hastie/CASI_files/DATA/haplotype.csv")
names(haplotype)[1] <- "subject"

# J = 3 parent populations
# Qi = (q1, q2, q3) probability vector of each pop. for individual i
# M = 100 variables
# Gim = genotype at SNP m for individual i. A 3 level factor (0, 1, 2)
# Pj = unknown M-vector of allele population proportions for population j
# goal is to estimate Q and P

# generative model
# create a pair of variables Xim = (X1im, X2im) corresponding to which we allocate the two alleles. 
# for example, if Gim = 1 (corresponding to Aa), then we might set X1im = 0 and X2im = 1 (or vice versa).
# if Gim = 0 they are both 0, and if Gim = 2, they are both 1.
# Let Zim represent the ancestral origin for individual i of each of these allele copies.

G <- haplotype[, 3:102]
fit <- softImpute(as.matrix(G), rank.max = 99, lambda = 75)

G_com <- complete(G, fit)
G_com[45, 75] <- 0
G_com[197, 75] <- 0

# recode G into a pair of binary X matrices
ones <- which(G_com == 1, arr.ind = T)
X1_ones <- rbinom(ones, 1, 0.5)
X2_ones <- 1 - X1_ones

G <- G_com[, lapply(.SD, function(x) ifelse(x == 2, 1, x))]
X1 <- G
X2 <- G
for(i in 1:nrow(ones)){
  X1[ones[[i, 1]], ones[[i, 2]]] <- X1_ones[[i]]
  X2[ones[[i, 1]], ones[[i, 2]]] <- X2_ones[[i]]
}
X1 <- X1[, lapply(.SD, as.integer)]
X2 <- X2[, lapply(.SD, as.integer)]

N <- 197 # num. of obs
M <- 100 # num. of variables
J <- 3 # num. of latent populations

# priors
# not updating at each step, though we could
Q <- rdirichlet(N, c(1, 1, 1))
P <- cbind(rbeta(M, 1, 1), 
           rbeta(M, 1, 1), 
           rbeta(M, 1, 1))
P <- P/apply(P, 1, function(x) sum(x, na.rm = T))

# initialize Z
Z1 <- matrix(nr = N, nc = M)
Z2 <- matrix(nr = N, nc = M)
for (n in 1:N){
  z1 <- which(rmultinom(M, 1, Q[1, ]) == 1, arr.ind = T)
  z2 <- which(rmultinom(M, 1, Q[1, ]) == 1, arr.ind = T)
  Z1[n, ] <- z1[, 1]
  Z2[n, ] <- z2[, 1]
}

# probability that Zcim = j given X, P, Q for each copy c = 1, 2
# step 1
# need to do this for X1, X2 and for every row and column of X1/2  
# 
Z_cond <- function(Z1, Z2, X1, X2, P, Q){
  
  N <- nrow(X1)
  M <- ncol(Z1)
  
  z1 <- vector(mode = "integer", length = 3)
  z2 <- vector(mode = "integer", length = 3)
  
  X1 <- as.matrix(X1)
  X2 <- as.matrix(X2)
  
  for(n in 1:N){
    for(m in 1:M){
      for(j in 1:3){
        z1[j] <- Q[n, j] * dbinom(X1[n, m], 1, P[m, j][[1]])
        z2[j] <- Q[n, j] * dbinom(X2[n, m], 1, P[m, j][[1]])
      }
      Z1[n, m] <- which(rmultinom(1, 1, z1) == 1)
      Z2[n, m] <- which(rmultinom(1, 1, z2) == 1)
    }
  }
  return(list(Z1, Z2))
}

# step 2
# update P
P_cond <- function(P, X1, X2, Z1, Z2){

  if(any(class(X1) != "matrix") | any(class(X2) != "matrix")) {
    X1 <- as.matrix(X1)
    X2 <- as.matrix(X2)
  }
  
  p <- vector(mode = "integer", length = 3)
  
  for(i in 1:nrow(P)){
  
    n11 <- tabulate(Z1[which(X1[, i] == 1), i], nbins = 3)
    n12 <- tabulate(Z2[which(X2[, i] == 1), i], nbins = 3)
    
    n01 <- tabulate(Z1[which(X1[, i] == 0), i], nbins = 3)
    n02 <- tabulate(Z2[which(X2[, i] == 0), i], nbins = 3)
    
    n1 <- n11 + n12
    n0 <- n01 + n02
    
    p[1] <- rbeta(1, 1 + n1[1], 1 + n0[1])
    p[2] <- rbeta(1, 1 + n1[2], 1 + n0[2])
    p[3] <- rbeta(1, 1 + n1[3], 1 + n0[3])
    
    P[i, ] <- p
  }
  return(P/apply(P, 1, sum))
}
# update Q

Q_cond <- function(Q, X1, X2, Z1, Z2){
  
  if(any(class(X1) != "matrix") | any(class(X2) != "matrix")) {
    X1 <- as.matrix(X1)
    X2 <- as.matrix(X2)
  }
  for(i in 1:nrow(Q)){
    n11 <- tabulate(Z1[i, which(X1[i, ] == 1)], nbins = 3)
    n12 <- tabulate(Z2[i, which(X2[i, ] == 1)], nbins = 3)
    m <- n11 + n12
    # m <- sapply(m, function(x) max(1, x))
    Q[i, ] <- rdirichlet(1, 1 + m)
  }
  return(Q) 
}

Z_condc <- cmpfun(Z_cond)
P_condc <- cmpfun(P_cond)
Q_condc <- cmpfun(Q_cond)

# simulation
iters <- 2000
Z <- list(Z1, Z2)

pb <- progress_bar$new(total = iters)
Q_avg <- list()
i <- 1
for(b in 1:iters){
  pb$tick()
  Z <- Z_condc(Z[[1]], Z[[2]], X1, X2, P, Q)
  P <- P_condc(P, X1, X2, Z1 = Z[[1]], Z2 = Z[[2]])
  Q <- Q_condc(Q, X1, X2, Z1 = Z[[1]], Z2 = Z[[2]])
  
  i <- i + 1
  if(i > 1000){
  Q_avg[[i]] <- data.table(id = 1:197, Q) 
  if(i %% 100 == 0){
    print(i)
    Q_med <- rbindlist(Q_avg)[, lapply(.SD, mean), by = id]
    Q_med <- Q_med[, -1]/apply(Q_med[, -1], 1, sum)
    Q_med[, race := haplotype[, 2]]
    plot <- ggtern(data = Q_med, aes(x = V1, y = V2, z = V3)) + 
      geom_point(aes(fill = race),
                 size = 2, 
                 shape = 21, 
                 color = "black") + 
      xlab("") + ylab("") + zlab("") +
      theme_legend_position('topleft') +
      labs(fill = "Race") 
    print(plot)
    }
  }
}

You are writing most of the code in global scope, which is not really how Julia code is written, because it inhibits most of the optimizations that make Julia fast. Using global scope so much also means you are hitting https://github.com/JuliaLang/julia/issues/28789 more often than usual.
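A minimal sketch of the pattern the Julia manual’s Performance Tips recommend (“Avoid global variables”), with hypothetical data and function name: the same reduction written as a top-level loop over a non-constant global, and as a function that takes the data as an argument.

```julia
# Hypothetical data standing in for a genotype matrix.
data = rand(1_000, 100)

# Pattern to avoid in hot code: a top-level loop updating a non-constant
# global, whose type the compiler cannot rely on.
total = 0.0
for x in data
    global total += x
end

# Preferred pattern: pass the data in as an argument, so the loop is compiled
# for the concrete argument type (here Matrix{Float64}).
function column_sums(X::AbstractMatrix)
    s = zeros(eltype(X), size(X, 2))
    for j in axes(X, 2), i in axes(X, 1)
        s[j] += X[i, j]
    end
    return s
end

s = column_sums(data)
```

Applied to the sampler, this means passing `X1`, `X2`, `P`, `Q` and `Z` into the update functions (as the later rewrite does) instead of reading them from global scope.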


Maybe @bicycle1885 or someone else from BioJulia can help.

Also @trappmartin, who works on the Turing PPL.


Just to be clear, since that discussion you linked to did seem to get a bit heated, I’m not suggesting that this behavior is wrong or bad in Julia. I’m new to the language and I’m excited about it and would love to have Julia as a tool my team uses. The code in this toy problem is stuff we may write and I’d like to know how to coach and fix my team’s attempts at incorporating Julia in addition to R.

With that, I’ve updated the functions a bit to resolve some of the scoping issues, but I’m still not getting separation into 3 clusters. Here are the main functions rewritten:

function Z_cond(X1, X2, P, Q)
      
    N = size(X1, 1)
    M = size(X1, 2)
    
    z1 = Vector{Float64}(undef, 3)
    z2 = Vector{Float64}(undef, 3)
    Z1 = Matrix(undef, N, M)
    Z2 = Matrix(undef, N, M)
    
    for n in 1:N
        for m in 1:M
            for j in 1:3
                z1[j] = Q[n, j] * (abs(X1[n, m] - 1) + P[m, j])
                z2[j] = Q[n, j] * (abs(X2[n, m] - 1) + P[m, j])
            end
            z1 = z1 ./ sum(z1)
            z2 = z2 ./ sum(z2)
            mult1 = Multinomial(1, z1)
            mult2 = Multinomial(1, z2)
            Z1[n, m] = findall(x -> x == 1, rand(mult1, 1))[1][1]
            Z2[n, m] = findall(x -> x == 1, rand(mult2, 1))[1][1]
        end
    end
    return Z1, Z2
end       

function P_cond(X1, X2, Z1, Z2)
    p = Vector{Float64}(undef, 3)
    M = size(X1, 2)
    P = Matrix(undef, M, 3)
    for i in 1:M
        n11 = freqtable(Z1[findall(x -> x == 1, X1[1]), i])            
        n01 = freqtable(Z1[findall(x -> x == 0, X1[1]), i])
        n12 = freqtable(Z2[findall(x -> x == 1, X2[1]), i])    
        n02 = freqtable(Z2[findall(x -> x == 0, X2[1]), i])
        
        n1 = Vector{Int64}(undef, 3)
        n0 =  Vector{Int64}(undef, 3)
        for j in 1:3
            n1[j] = get(n11, j, 0) + get(n12, j, 0)
            n0[j] = get(n01, j, 0) + get(n02, j, 0)
        end 
        
        @. p = rand(Beta(1 + n1, 1 + n0))
        p = p / sum(p)
        P[i, :] = p
    end
    return P
end

function Q_cond(X1, X2, Z1, Z2)
    X1 = convert(Array, X1)
    X2 = convert(Array, X2)
    N = size(X1, 1)
    Q = Matrix(undef, N, 3)
    
    for i in 1:N
        m = Vector{Int64}(undef, 3)
        m1 = freqtable(Z1[i, findall(x -> x == 1, X1[i, :])])
        m2 = freqtable(Z2[i, findall(x -> x == 1, X2[i, :])])
        for j in 1:3
           m[j] = get(m1, j, 0) + get(m2, j, 0)
        end    
        q = Dirichlet(1 .+ m) 
        Q[i, :] = transpose(rand(q, 1))
    end
   return Q
end

function gibb(X1, X2, n, m)
    N = size(X1, 1) # num. of obs
    M = size(X1, 2)  # num. of variables
    J = 3 # num. of latent populations
    
    iter = n + m
    m = 0
    
    dir = Dirichlet(ones(J)) 
    Q = transpose(rand(dir, N))
    Q_acc = deepcopy(Q)
    β = Beta() # equivalent to Beta(1, 1)
    P = hcat(rand(β, M), rand(β, M), rand(β, M))
    P = P ./ sum(P, dims = 2);
    
    # burn-in n iterations
    # draw and calculate mean on m iterations (post burn-in)
    for i in 1:iter
        Z = Z_cond(X1, X2, P, Q)
        P = P_cond(X1, X2, Z[1], Z[2])
        Q = Q_cond(X1, X2, Z[1], Z[2])
        if i > n
            m += 1
            Q_acc += Q
        end
    end
    return Z, P, Q, Q_acc/m
end
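Comparing `Z_cond` with the R version turns up a difference that would plausibly explain the poor separation. R weights population `j` by `Q[n, j] * dbinom(X1[n, m], 1, P[m, j])`, which is `Q[n, j] * P[m, j]` when the allele is 1 and `Q[n, j] * (1 - P[m, j])` when it is 0. The Julia expression `Q[n, j] * (abs(X1[n, m] - 1) + P[m, j])` gives `Q[n, j] * P[m, j]` for allele 1 but `Q[n, j] * (1 + P[m, j])` for allele 0, because the parenthesis closes in the wrong place (`abs(X1[n, m] - 1 + P[m, j])` would match R). Both the original and the rewritten `Z_cond` contain it. A minimal sketch of the update with the Bernoulli likelihood written out explicitly, using a hand-rolled categorical draw so it needs only the standard library (`rand(Categorical(w ./ sum(w)))` from Distributions.jl would do the same job); all function names are hypothetical:

```julia
using Random

# Bernoulli likelihood of allele x (0 or 1) given allele frequency p;
# equivalent to R's dbinom(x, 1, p).
bernoulli_lik(x, p) = x == 1 ? p : 1 - p

# Draw an index j with probability proportional to w[j] (unnormalised weights).
function draw_index(rng::AbstractRNG, w::AbstractVector{<:Real})
    u = rand(rng) * sum(w)
    acc = 0.0
    for (j, wj) in enumerate(w)
        acc += wj
        u < acc && return j
    end
    return length(w)    # guard against floating-point round-off
end

# Resample the population label of one allele copy at every (n, m).
# X is N x M with entries 0/1, P is M x J (allele frequencies),
# Q is N x J (admixture proportions); Z is overwritten in place.
function sample_Z!(Z::AbstractMatrix{<:Integer}, X::AbstractMatrix, P::AbstractMatrix,
                   Q::AbstractMatrix; rng::AbstractRNG = Random.default_rng())
    N, M = size(X)
    J = size(Q, 2)
    w = Vector{Float64}(undef, J)
    for m in 1:M, n in 1:N
        for j in 1:J
            w[j] = Q[n, j] * bernoulli_lik(X[n, m], P[m, j])
        end
        Z[n, m] = draw_index(rng, w)
    end
    return Z
end

# Usage: call once per allele copy, e.g. sample_Z!(Z1, X1, P, Q); sample_Z!(Z2, X2, P, Q)
X = ones(Int, 4, 5)
P = hcat(fill(1.0, 5), zeros(5), zeros(5))   # only population 1 carries the allele
Q = fill(1 / 3, 4, 3)
Z = zeros(Int, 4, 5)
sample_Z!(Z, X, P, Q)
```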

Sure, but I think it has been pretty much agreed that the scoping rules for global scope are a bit confusing and might need to change. My point was just that the way the code was structured might have caused you to hit this confusion a bit more than “average”.
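A related performance point in the rewritten functions: `Matrix(undef, N, M)` without an element type creates a `Matrix{Any}`, so every entry is boxed and code reading from it cannot be specialised. Giving the element type explicitly keeps the arrays concrete. A minimal illustration (sizes taken from the data set, variable names as in the rewrite):

```julia
N, M, J = 197, 100, 3

Z_any = Matrix(undef, N, M)          # element type Any, as in the rewritten Z_cond
Z1 = Matrix{Int}(undef, N, M)        # population labels 1..J
P = Matrix{Float64}(undef, M, J)     # allele frequencies
Q = zeros(N, J)                      # zeros(...) also gives a concrete Matrix{Float64}
```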

Hi @spinkney,
I’m happy to take a look at your implementation.

However, instead of implementing a Gibbs sampler by hand, I would suggest you have a look at one of the existing PPLs in Julia. In my experience, writing an MCMC sampler by hand is very error prone. But I understand that you want to try out Julia and that this is more of an exercise.

I might’ve misread the PDF, but the R code seems to be wrong. For example, when distributing G_com into two binary matrices X1, X2, we have the following:

X1 <- G
X2 <- G
for(i in 1:nrow(ones)){
  X1[ones[[i, 1]], ones[[i, 2]]] <- X1_ones[[i]]
  X2[ones[[i, 1]], ones[[i, 2]]] <- X2_ones[[i]]
}
X1 <- X1[, lapply(.SD, as.integer)]
X2 <- X2[, lapply(.SD, as.integer)]

This should guarantee X1 + X2 == G_com, but


> which(X1 + X2 != G_com)
 [1]   107  1904  2353  2374  2387  2422  2499  2531  3080  3141  5211  5269
[13]  6336  7469  7968  8384  8503  8546  9051  9685  9718  9742  9821 11415
[25] 13821 13836 13855 15277 15294 15316 15355 16767 17384 17467 18242 18255
[37] 18299 18383 19205 19368

So I am not really sure what’s going on here.

If I read it correctly, G_com is a matrix with values 0, 1, 2, and according to the text it is split into a pair X1, X2: if G_com = 0 then X1 = 0, X2 = 0; if G_com = 1 then (X1 = 1, X2 = 0) or (X1 = 0, X2 = 1); and if G_com = 2 then X1 = 1, X2 = 1.

Can you please verify the R code?

It should be similar to setting some of the missing values to 0. Because this problem came from a Hastie book, I thought he might have used his softImpute package to impute the missing values. This package doesn’t seem to return integer values; I think all the missing values in the data may have been imputed with non-integer values. However, the code replaces these with 0 in the last lines (not good programming):

X1 <- X1[, lapply(.SD, as.integer)]
X2 <- X2[, lapply(.SD, as.integer)]

element <- rep(0, 40)
G_com <- as.matrix(G_com)
test <- which(X1 + X2 != G_com, arr.ind = TRUE)
for(i in 1:40){
   element[i] <- G_com[test[i, 1][[1]], test[i, 2][[1]]]
}
element
 [1] 0.009430842 0.053122270 0.064903376 0.030921227 0.029578536 0.027831544 0.025968560 0.033315496 0.023399275 0.031129861 0.008249561 0.008678403 0.032868111
[14] 0.003769389 0.051818871 0.030420410 0.009280796 0.008062181 0.075039673 0.006171473 0.004758174 0.005352866 0.006854543 0.037255259 0.059098031 0.060792429
[27] 0.050306816 0.052951985 0.039636823 0.056479963 0.052731924 0.091726351 0.049561444 0.050672482 0.035101460 0.040886051 0.050022185 0.034397406 0.009059716
[40] 0.027249554

> as.integer(element)
 [1] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
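As the output shows, the 40 mismatches are small non-integer imputed values (all below 0.1 here) that as.integer truncates to 0, so X1 + X2 matches the truncated matrix rather than G_com itself. One way to make the invariant hold exactly is to snap imputed values to a valid genotype before splitting. A minimal Julia sketch; it rounds to the nearest genotype, which is a different choice from R’s as.integer (truncation toward zero), and the input values are made up for illustration. `snap_genotype` is a hypothetical name.

```julia
# Snap continuous imputed genotypes (e.g. output of a matrix-completion method)
# to the nearest valid genotype in {0, 1, 2}.
snap_genotype(g::Real) = clamp(round(Int, g), 0, 2)

G_imputed = [0.0 1.0 2.0;
             0.0094 1.97 0.53]       # made-up values for illustration
G = snap_genotype.(G_imputed)
```

Splitting `G` afterwards (rather than the continuous matrix) keeps `X1 + X2 == G` true at every entry.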

Yes, I would prefer a Turing.jl implementation, as I’m quite interested in using its discrete parameter estimation for something like this instead of thinking about marginalizing it out as in Stan.
