Hi Everyone,
I’m a director of data science and I’m checking out Julia as a tool for my group to use. I chose an example from the Computer Age Statistical Inference on population mixtures with gene data. The particular example can be found on page 256 of the linked pdf (https://web.stanford.edu/~hastie/CASI_files/PDF/casi.pdf). The data is freely available and the Julia code I provide below should enable you to easily get it into Julia.
The problem is to find what populations the individuals come from without using the labels. So it’s an unsupervised learning problem (though the actual labels are there to see how you’re doing). There are 100 genes that are coded as 0, 1, 2. The value is determined by two phenotypes X1 and X2. If X1 and X2 are both 0 then the value is 0, both 1 then 2 and if only one of the Xs is 1 then the value is 1. The estimation procedure is described in detail in the book and involves Gibbs sampling. My test was to see how my unoptimized R code compares to my unoptimized Julia code.
I’m very comfortable with R and was able to basically get what they have in the book (first image below). I’ve done a translation to Julia and the output is quite bad. I’m not sure where the issue is and I’m hoping some you guys can help out. On the plus side, the time to run is ~30x faster in Julia.
Here’s the mixture I got from R:
My attempt in Julia results in:
using CSVFiles, DataFrames, DataFramesMeta, CategoricalArrays, StatsBase, Distributions, Random, FreqTables, BenchmarkTools
data = DataFrame(load("https://web.stanford.edu/~hastie/CASI_files/DATA/haplotype.csv"))
for col in names(data)[3:102]
data[ismissing.(data[col]), col] = 0
end
data_snp = data[names(data)[3:102]]
Random.seed!(2342)
function prob_recode(x)
bin = Binomial()
local where_one, where_x, x1
where_x = findall(y -> y == 1, x)
where_two = findall(y -> y == 2, x)
x1 = copy(x)
x2 = copy(x)
x1[where_x] = rand(bin, length(where_x))
where_x1 = findall(y -> y == 1, x1)
other_ones = setdiff(where_x, where_x1)
x2[where_x1] = fill(0, length(where_x1))
x2[other_ones] = fill(1, length(other_ones))
x1[where_two] = fill(1, length(where_two))
x2[where_two] = fill(1, length(where_two))
x, x1, x2
end
X = colwise(prob_recode, data_snp);
col_names = names(data_snp)
a = hcat(X[1][:]...)[:, 2]
b = hcat(X[1][:]...)[:, 3]
for i in 2:100
a = hcat(a, hcat(X[i][:]...)[:, 2])
b = hcat(b, hcat(X[i][:]...)[:, 3])
end
X1 = DataFrame(a)
X2 = DataFrame(b)
rename!(X1, [f => t for (f, t) = zip(names(X1), col_names)])
rename!(X2, [f => t for (f, t) = zip(names(X2), col_names)]);
N = size(data_snp)[1] # num. of obs
M = size(data_snp)[2] # num. of variables
J = 3 # num. of latent populations
Q_prior = Dirichlet(ones(J))
Q = transpose(rand(Q_prior, N))
P_prior = Beta() # equivalent to Beta(1, 1)
P = hcat(rand(P_prior, N), rand(P_prior, N), rand(P_prior, N))
P = P ./ sum(P, dims = 2);
Z1 = Array{Float64}(undef, N, M)
Z2 = Array{Float64}(undef, N, M)
mult = Multinomial(1, Q[1, :])
z1 = rand(mult, M)
z2 = rand(mult, M)
for n in 1:N
for i in 1:M
Z1[n, i] = findall(x -> x == 1, z1[:, i])[1]
Z2[n, i] = findall(x -> x == 1, z2[:, i])[1]
end
end
Z = Z1, Z2
function Z_cond(Z1, Z2, X1, X2, P, Q)
N = nrow(X1)
M = size(Z1, 2)
z1 = Vector{Float64}(undef, 3)
z2 = Vector{Float64}(undef, 3)
for n in 1:N
for m in 1:M
for j in 1:3
z1[j] = Q[n, j] * (abs(X1[n, m] - 1) + P[m, j])
z2[j] = Q[n, j] * (abs(X2[n, m] - 1) + P[m, j])
end
z1 = z1 ./ sum(z1)
z2 = z2 ./ sum(z2)
mult1 = Multinomial(1, z1)
mult2 = Multinomial(1, z2)
Z1[n, m] = findall(x -> x == 1, rand(mult1, 1))[1][1]
Z2[n, m] = findall(x -> x == 1, rand(mult2, 1))[1][1]
end
end
return Z1, Z2
end
function P_cond(P, X1, X2, Z1, Z2)
p = Vector{Float64}(undef, 3)
for i in 1:size(P, 2)
n11 = freqtable(Z1[findall(x -> x == 1, X1[1]), i])
n01 = freqtable(Z1[findall(x -> x == 0, X1[1]), i])
n12 = freqtable(Z2[findall(x -> x == 1, X2[1]), i])
n02 = freqtable(Z2[findall(x -> x == 0, X2[1]), i])
n1 = Vector{Int64}(undef, 3)
n0 = Vector{Int64}(undef, 3)
for j in 1:3
n1[j] = get(n11, j, 0) + get(n12, j, 0)
n0[j] = get(n01, j, 0) + get(n02, j, 0)
end
p1 = Beta(1 + n1[1], 1 + n0[1])
p2 = Beta(1 + n1[2], 1 + n0[2])
p3 = Beta(1 + n1[3], 1 + n0[3])
p[1] = rand(p1, 1)[1]
p[2] = rand(p2, 1)[1]
p[3] = rand(p3, 1)[1]
p = p / sum(p)
P[i, :] = p
end
return P
end
function Q_cond(Q, X1, X2, Z1, Z2)
X1 = convert(Array, X1)
X2 = convert(Array, X2)
for i in 1:size(Q)[1]
m = Vector{Int64}(undef, 3)
m1 = freqtable(Z1[i, findall(x -> x == 1, X1[i, :])])
m2 = freqtable(Z2[i, findall(x -> x == 1, X2[i, :])])
for j in 1:3
m[j] = get(m1, j, 0) + get(m2, j, 0)
end
q = Dirichlet(1 .+ m)
Q[i, :] = transpose(rand(q, 1))
end
return Q
end
function gibb(Z, P, Q, n)
for i in 1:n
Z = Z_cond(Z[1], Z[2], X1, X2, P, Q)
P = P_cond(P, X1, X2, Z[1], Z[2])
Q = Q_cond(Q, X1, X2, Z[1], Z[2])
end
return Z, P, Q
end
burnin = gibb(Z, P, Q, 1000)
function gibb_acc(Z, P, Q, n)
Q_acc = copy(Q)
for i in 1:n
Z = Z_cond(Z[1], Z[2], X1, X2, P, Q)
P = P_cond(P, X1, X2, Z[1], Z[2])
Q = Q_cond(Q, X1, X2, Z[1], Z[2])
Q_acc += Q
end
return Q_acc/n, Q
end
out = gibb_acc(burning[1], burning[2], burning[3], 500)