Gibbs sampler works in R but not in Julia (full Julia code provided) — please help

Unfortunately, I don’t have the time to implement and test the model you are interested in. However, it could look along the lines of the following code example. It might not run as is, but with some minor corrections it should work for your purpose. Feel free to raise an issue on Turing's GitHub page if you run into problems.

using Turing, MCMCChain

@model popu_mixture(X, N, J, M, C; α = 1.0, a = 1.0, b = 1.0) = begin
    # Admixture model: N subjects, C allele copies, M loci, J latent populations.
    # `X` holds the observed 0/1 alleles indexed as X[i, c, m]. It is a model
    # argument and must NOT be reassigned inside this body; reassigning it would
    # turn the observations into latent variables and the data would be ignored.

    # Mixing weights for each subject (i ∈ {1, ..., N}).
    Q = tzeros(J, N)

    # Bernoulli parameter of the allele for each population (j ∈ {1, ..., J})
    # and locus (m ∈ {1, ..., M}).
    P = tzeros(J, M)

    # Draw the mixing weights for each subject.
    for i in 1:N
        Q[:, i] ~ Dirichlet(J, α)
    end

    # Draw the allele probabilities for each population and locus.
    for j in 1:J
        for m in 1:M
            P[j, m] ~ Beta(a, b)
        end
    end

    # Latent population labels for each subject, copy and locus.
    Z = tzeros(N, C, M)

    for i in 1:N
        for c in 1:C
            for m in 1:M
                Z[i, c, m] ~ Categorical(Q[:, i])
                # NOTE(review): `tzeros` presumably stores floating-point values;
                # convert the label to Int so it can be used as an index into P.
                j = Int(Z[i, c, m])
                X[i, c, m] ~ Bernoulli(P[j, m])
            end
        end
    end

    return Z
end

#
# ... Your code to extract and pre-process the data ...
#

J = 3 # Assuming three distinct groups.
C = 2 # Assuming two copies.

# Construct the model function and compile the model.
# `X` must be the N×C×M array of observed 0/1 alleles indexed as X[i, c, m];
# X, N and M are expected to come from the pre-processing step above.
model_fun = popu_mixture(X, N, J, M, C)

# Run a particle Gibbs sampler on the model
# for 10,000 iterations using 500 particles.
# (The number of iterations and particles might have to be adjusted.)
chain = sample(model_fun, PG(500,10_000))

# Summarize the sampling results.
describe(chain)

As a side note, this particular model has a very large number of discrete random variables (RVs), so inference may be more efficient if you marginalize (collapse) them out, as you already mentioned.

2 Likes

I got it to work! I couldn’t get the Turing implementation to run — it just spun and spun, even with 1 particle and 1 iteration (I’m on Windows 7) — but I did get the original code to work. The “get()” function only appended to a vector to pad it to the desired length, so it did not place each count at the index of its population. For example, if only populations 1 and 3 were present in X, the allele frequencies should have had entries for 1 and 3 and 0 for 2. Previously it put the data in 2 and set 3 to 0.

There was also an error in how X1 and X2 were indexed in P.


using CSVFiles, DataFrames, DataFramesMeta, CategoricalArrays, 
StatsBase, Distributions, Random, FreqTables, BenchmarkTools,
Turing, MCMCChain
data = DataFrame(load("https://web.stanford.edu/~hastie/CASI_files/DATA/haplotype.csv"));

Random.seed!(2342)
# Genotype matrix built from columns 3 onward of `data` (the first two columns
# are skipped). Uses public column indexing rather than the DataFrame's
# internal `:columns` field.
snps = reduce(hcat, [data[:, j] for j in 3:size(data, 2)])

a = similar(snps, Bool)
b = similar(snps, Bool)

# Split every genotype value into the two haplotype bits a and b:
# 0 or missing -> (false, false), 2 -> (true, true),
# 1 -> one true and one false, assigned at random (random phasing).
for index in eachindex(snps)
    snp = snps[index]
    a[index], b[index] =
        if ismissing(snp) || snp == 0
            false, false
        elseif snp == 1
            x = rand(Bool)
            x, !x
        elseif snp == 2
            true, true
        else
            error("All snps must be 0, 1, 2, or missing")
        end
end

# 0/1 integer tables of the two haplotype copies, one column per SNP.
# Column names are taken from the SNP columns of `data` (column 3 onward), the
# same columns `snps` was built from. Previously `col_names` was never defined.
col_names = names(data)[3:end]
X1 = DataFrame(a * 1, col_names)
X2 = DataFrame(b * 1, col_names);

 # Problem dimensions (previously read from the undefined `data_snp`).
 N, M = size(snps) # num. of obs, num. of variables (loci)
 J = 3 # num. of latent populations

 # One draw from the priors (exploratory only; `gibb` draws its own start values).
 # Rows of Q are per-subject mixing proportions over the J populations.
 Q_prior = Dirichlet(ones(J))
 Q = transpose(rand(Q_prior, N))

 # P[m, j] is the probability that the allele indicator is 1 at locus m in
 # population j. Entries are independent Beta draws, so rows are NOT normalized.
 P_prior = Beta() # equivalent to Beta(1, 1)
 P = rand(P_prior, M, J)

# Draw an index from 1:length(w) with probability proportional to the
# nonnegative weights `w` (inverse-CDF sampling; `w` need not be normalized).
function _draw_label(w)
    cw = cumsum(w)
    total = cw[end]
    total > 0 || throw(ArgumentError("all population weights are zero"))
    u = rand() * total
    # Strict `>` so that a zero-weight leading entry is never picked when u == 0.
    k = findfirst(c -> c > u, cw)
    # Guard against rounding at the top end of the CDF.
    return k === nothing ? findlast(x -> x > 0, w) : k
end

"""
    Z_cond(X1, X2, P, Q)

Draw the latent population label of every (subject, locus) pair for both
haplotype copies from their full conditional distributions.

`X1` and `X2` are N×M tables (matrix or `DataFrame`) of 0/1 alleles. `P` is
the M×J matrix whose entry `P[m, j]` is the probability of allele 1 at locus
`m` in population `j`. `Q` is the N×J matrix of per-subject mixing
proportions. Returns `(Z1, Z2)`, two N×M `Int` matrices with labels in `1:J`.
"""
function Z_cond(X1, X2, P, Q)
    N = size(X1, 1)
    M = size(X1, 2)
    J = size(Q, 2)

    w1 = Vector{Float64}(undef, J)
    w2 = Vector{Float64}(undef, J)
    Z1 = Matrix{Int}(undef, N, M)
    Z2 = Matrix{Int}(undef, N, M)

    for n in 1:N, m in 1:M
        x1 = X1[n, m]
        x2 = X2[n, m]
        for j in 1:J
            # Prior weight times Bernoulli likelihood of the observed allele:
            # P[m, j] for allele 1, 1 - P[m, j] for allele 0.
            w1[j] = Q[n, j] * (x1 == 1 ? P[m, j] : 1 - P[m, j])
            w2[j] = Q[n, j] * (x2 == 1 ? P[m, j] : 1 - P[m, j])
        end
        Z1[n, m] = _draw_label(w1)
        Z2[n, m] = _draw_label(w2)
    end
    return Z1, Z2
end

"""
    P_cond(X1, X2, Z1, Z2; J = 3)

Draw the allele-1 probability of every locus and population from its full
conditional `Beta(1 + n1, 1 + n0)`. Here `n1` (resp. `n0`) counts the
haplotype copies, over all subjects and both copies, whose label at that locus
is the population and whose allele there is 1 (resp. 0).

`X1`/`X2` are N×M tables of 0/1 alleles (matrix or `DataFrame`), and `Z1`/`Z2`
are the matching N×M label matrices. Returns an M×J `Float64` matrix. Entries
are independent Beta draws and are deliberately not normalized across
populations: normalizing would no longer sample the full conditional.
"""
function P_cond(X1, X2, Z1, Z2; J::Integer = 3)
    M = size(X1, 2)
    P = Matrix{Float64}(undef, M, J)
    for m in 1:M
        x1 = X1[:, m]
        x2 = X2[:, m]
        z1 = Z1[:, m]
        z2 = Z2[:, m]
        for j in 1:J
            in1 = z1 .== j
            in2 = z2 .== j
            n1 = count(in1 .& (x1 .== 1)) + count(in2 .& (x2 .== 1))
            n0 = count(in1 .& (x1 .== 0)) + count(in2 .& (x2 .== 0))
            P[m, j] = rand(Beta(1 + n1, 1 + n0))
        end
    end
    return P
end

"""
    Q_cond(X1, X2, Z1, Z2; J = 3)

Draw each subject's mixing proportions from the full conditional
`Dirichlet(1 .+ counts)`. Here `counts[j]` is the number of that subject's
labels in `Z1` and `Z2`, over all loci, that equal `j`.

`X1` and `X2` are kept in the signature for call compatibility. Given the
labels, the mixing proportions do not depend on the observed alleles, so every
label is counted, not only those at loci where the allele is 1.
Returns an N×J `Float64` matrix.
"""
function Q_cond(X1, X2, Z1, Z2; J::Integer = 3)
    N = size(Z1, 1)
    Q = Matrix{Float64}(undef, N, J)
    for i in 1:N
        z1 = Z1[i, :]
        z2 = Z2[i, :]
        counts = [count(z -> z == j, z1) + count(z -> z == j, z2) for j in 1:J]
        Q[i, :] = rand(Dirichlet(1.0 .+ counts))
    end
    return Q
end



"""
    gibb(X1, X2, n, m)

Run the Gibbs sampler for the 3-population admixture model on the two
haplotype tables `X1` and `X2` (N subjects × M loci, 0/1 entries).

The first `n` iterations are burn-in; the next `m` iterations are kept, and
their `Q` draws are averaged. Returns `(Z, P, Q, Q_mean)`, where:
- `Z = (Z1, Z2)`, `P` and `Q` are the draws from the final iteration;
- `Q_mean` is the N×J average of the `m` retained `Q` draws only (the random
  starting value is not included).

Throws `ArgumentError` if `m < 1`, since the average would be undefined.
"""
function gibb(X1, X2, n, m)
    m >= 1 || throw(ArgumentError("need at least one post burn-in draw (m >= 1), got m = $m"))
    N = size(X1, 1) # num. of obs
    M = size(X1, 2) # num. of variables
    J = 3           # num. of latent populations

    # Random starting values drawn from the priors. Entries of P are independent
    # Beta(1, 1) probabilities and are not normalized across populations.
    Q = transpose(rand(Dirichlet(ones(J)), N))
    P = rand(Beta(), M, J)

    Q_sum = zeros(N, J)
    local Z # declared here so the last draw is still visible after the loop

    for i in 1:(n + m)
        Z = Z_cond(X1, X2, P, Q)
        P = P_cond(X1, X2, Z[1], Z[2])
        Q = Q_cond(X1, X2, Z[1], Z[2])
        if i > n
            Q_sum .+= Q
        end
    end
    return Z, P, Q, Q_sum / m
end

"""
    to_fill(freq_table, J = 3)

Expand a one-dimensional frequency table (as produced by `freqtable`) into a
dense `Vector{Int}` of length `J`, where entry `j` is the count for label `j`.
Labels missing from the table get 0. Labels outside `1:J` are ignored.
"""
function to_fill(freq_table, J::Integer = 3)
    counts = zeros(Int, J)
    # Pair each count with its own label rather than with its position among
    # the valid labels, so an out-of-range label cannot shift later counts.
    for (label, c) in zip(names(freq_table)[1], freq_table)
        if label in 1:J
            counts[Int(label)] = c
        end
    end
    return counts
end


# Run 2000 burn-in iterations followed by 2000 retained iterations.
# Despite its name, `burnin` holds gibb's full return tuple (Z, P, Q, averaged Q).
burnin = gibb(X1, X2, 2000, 2000)

# Write element 4 (gibb's average of the Q draws, one row per subject) to CSV.
save("gibbs_output.csv", DataFrame(burnin[4]))

6 Likes

The behaviour you observed with Turing is probably caused by an error in your model definition. We currently have problems rethrowing exceptions when a particle-based sampler is used, but this issue should be fixed soon.

For now, you can test your implementation by drawing a sample from the prior with
`model_fun()`.

2 Likes