Help with Turing.jl model using Categorical, Normal and ordered constraint

I’m trying to model Allen Downey’s three “How tall is A?” problems (from “How tall is A?” on Probably Overthinking It):

1) Suppose you meet an adult resident of the U.S. who is 170 cm tall.
     What is the probability that they are male?

2) Suppose I choose two U.S. residents at random and A is taller than B. 
     How tall is A?

3) In a room of 10 randomly chosen U.S. residents, A is the second tallest. 
    How tall is A? 
    And what is the probability that A is male?
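As a quick sanity check (not part of my Turing models), problem 1 can be computed directly with Bayes’ rule, using the same gender prior and height distributions as in the models below:

using Distributions

# Problem 1: P(male | height = 170 cm), with P(male) = 0.49
p_male     = 0.49
lik_male   = pdf(Normal(178, 7.7), 170)   # density of 170 cm among males
lik_female = pdf(Normal(163, 7.3), 170)   # density of 170 cm among females
p_male * lik_male / (p_male * lik_male + (1 - p_male) * lik_female)
# ≈ 0.46, i.e. slightly more likely to be female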

The first two problems were easy to model in Turing.jl (http://hakank.org/julia/turing/how_tall_is_a.jl). Here’s the model for the second problem:

using Turing

@model function how_tall_is_2()
    male = 1
    female = 2
    gender = tzeros(2)
    height = tzeros(2)
    for p in 1:2
        gender[p] ~ Categorical([0.49, 0.51])  # P(male) = 0.49, P(female) = 0.51
        height[p] ~ gender[p] == male ? Normal(178,7.7) : Normal(163,7.3)
    end
    # observe the ordering: person 1 (A) is taller than person 2 (B)
    true ~ Dirac(height[1] > height[2])
end

model = how_tall_is_2()
chains = sample(model, PG(15), 1_000)
display(chains)

And it gives the expected heights of:

   parameters       mean       std   naive_se      mcse        ess      rhat   ess_per_sec
   ...
   height[1]   176.4384    9.0135     0.2850    0.1777   824.6929    1.0007      179.5152
   height[2]   164.5469    8.4415     0.2669    0.2017   828.7247    0.9998      180.3928
   ...

However, the third problem, where we have 10 people instead of 2, is much harder to get right. Here’s my model (about the same as above, but n is now 10):

using Turing

@model function how_tall_is_3()
    male = 1
    female = 2
    n = 10
    gender = tzeros(n)
    height = tzeros(n)
    for p in 1:n
        gender[p] ~ Categorical([0.49, 0.51])  # P(male) = 0.49, P(female) = 0.51
        height[p] ~ gender[p] == male ? Normal(178,7.7) : Normal(163,7.3)
    end

    # observe a full ordering: each person is taller than the next
    for p in 1:n-1
        true ~ Dirac(height[p] > height[p+1])
        # alternative: height[p] > height[p+1] || Turing.@addlogprob! -Inf
    end
 
end

model = how_tall_is_3()
chains = sample(model, MH(), 10_000)
# chains = sample(model, PG(15), 10_000)
# chains = sample(model, SMC(), 10_000)
# chains = sample(model, IS(), 10_000)
display(chains)

The correct answer should be that A (the second tallest person) is about 181.61 cm (and the tallest person about 186 cm).

But this model doesn’t give any ordered results. Here’s a (representative) example using MH() where the heights are all over the place instead of nicely ordered.

   height[1]   176.0685  
   height[2]   184.1998  
   height[3]   165.4559  
   height[4]   161.9992  
   height[5]   182.4204  
   height[6]   175.6319  
   height[7]   182.7167  
   height[8]   161.5589  
   height[9]   183.9556  
   height[10]   165.1349 
   ...

I’ve tested different samplers, e.g. MH(), PG(), SMC(), and IS(), but they all give some wrong answer: either different but unordered heights, or heights that barely differ, e.g. PG() gives all heights around 170 cm (i.e. the same as using Prior()).

I also tried a Gibbs mixing sampler, e.g. Gibbs(MH(:gender), NUTS(1000, 0.65, :height)), but it throws errors, probably because the parameters are not :gender and :height but arrays of genders and heights, and I’m not sure how to express that.

Also, as shown in the model, I tried both Dirac and Turing.@addlogprob! for observing the ordering, but there is no difference between the two.

I guess the real problem is the mix of Categorical and Normal, and that these samplers cannot handle the “observation” of the ordering properly.

Is there a (preferably simple) fix for this, or perhaps the problem should be modeled in a completely different way?

Here’s a model that does what you want:

using Turing

@model function how_tall_is_3(
    n,
    pfemale,
    h_mean_female,
    h_mean_male,
    h_std_female,
    h_std_male,
    height_A,
    isfemale_A,
    ::Type{T} = Float64,
) where {T}
    isfemale ~ filldist(Bernoulli(pfemale), n)

    height = Vector{T}(undef, n)
    for i in 1:n
        if isfemale[i]
            height[i] ~ Normal(h_mean_female, h_std_female)
        else
            height[i] ~ Normal(h_mean_male, h_std_male)
        end
    end
    # A is the second tallest, i.e. the second-to-last index in ascending height order
    id_A = sortperm(height)[n-1]
    height_A ~ Dirac(height[id_A])
    isfemale_A ~ Dirac(isfemale[id_A])
end

# set "observations" of A to `missing`, so that we generate them
mod = how_tall_is_3(10, 0.51, 163, 178, 7.3, 7.7, missing, missing)
# no inference necessary! We just sample directly from the prior-predictive distribution
chn = sample(mod, Prior(), 10_000)

Output:

Chains MCMC chain (10000×23×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
Wall duration     = 0.66 seconds
Compute duration  = 0.66 seconds
parameters        = height[10], height[1], height[2], isfemale[10], isfemale[5], isfemale[7], isfemale[4], isfemale[6], height_A, isfemale[8], height[7], height[8], isfemale[9], isfemale[3], height[4], height[6], height[9], isfemale_A, height[3], isfemale[2], isfemale[1], height[5]
internals         = lp

Summary Statistics
    parameters       mean       std   naive_se      mcse          ess      rhat   ess_per_sec 
        Symbol    Float64   Float64    Float64   Float64      Float64   Float64       Float64 

   isfemale[1]     0.5061    0.5000     0.0050    0.0048   10062.6371    1.0003    15269.5556
   isfemale[2]     0.5108    0.4999     0.0050    0.0047   10069.3941    0.9999    15279.8090
   isfemale[3]     0.5116    0.4999     0.0050    0.0049   10003.7008    0.9999    15180.1226
   isfemale[4]     0.5114    0.4999     0.0050    0.0049    9756.3594    0.9999    14804.7942
   isfemale[5]     0.4988    0.5000     0.0050    0.0051    9872.9610    0.9999    14981.7314
   isfemale[6]     0.5141    0.4998     0.0050    0.0061    9637.4286    0.9999    14624.3226
   isfemale[7]     0.5134    0.4998     0.0050    0.0046   10231.4234    0.9999    15525.6804
   isfemale[8]     0.5025    0.5000     0.0050    0.0047    9834.8171    0.9999    14923.8499
   isfemale[9]     0.5144    0.4998     0.0050    0.0048    9842.9033    0.9999    14936.1203
  isfemale[10]     0.5200    0.4996     0.0050    0.0052    9767.5500    0.9999    14821.7754
     height[1]   170.4011   10.5277     0.1053    0.1129    9857.0370    1.0000    14957.5675
     height[2]   170.2644   10.7043     0.1070    0.0996   10133.5496    1.0000    15377.1618
     height[3]   170.4807   10.6361     0.1064    0.1049    9740.9282    1.0002    14781.3781
     height[4]   170.2575   10.6044     0.1060    0.1070   10072.0987    0.9999    15283.9130
     height[5]   170.5118   10.5506     0.1055    0.1080    9737.4231    0.9999    14776.0594
     height[6]   170.1812   10.6326     0.1063    0.1153    9745.2381    0.9999    14787.9181
     height[7]   170.3461   10.6142     0.1061    0.1094    9893.0560    1.0001    15012.2247
     height[8]   170.4455   10.5213     0.1052    0.0958   10029.7619    0.9999    15219.6690
     height[9]   170.4502   10.6084     0.1061    0.0981    9863.6789    0.9999    14967.6462
    height[10]   170.1680   10.5865     0.1059    0.1128    9526.2742    0.9999    14455.6513
      height_A   181.3947    4.6701     0.0467    0.0453    9569.4982    0.9999    14521.2416
    isfemale_A     0.0908    0.2873     0.0029    0.0028    9956.6954    0.9999    15108.7942

Quantiles
    parameters       2.5%      25.0%      50.0%      75.0%      97.5% 
        Symbol    Float64    Float64    Float64    Float64    Float64 

   isfemale[1]     0.0000     0.0000     1.0000     1.0000     1.0000
   isfemale[2]     0.0000     0.0000     1.0000     1.0000     1.0000
   isfemale[3]     0.0000     0.0000     1.0000     1.0000     1.0000
   isfemale[4]     0.0000     0.0000     1.0000     1.0000     1.0000
   isfemale[5]     0.0000     0.0000     0.0000     1.0000     1.0000
   isfemale[6]     0.0000     0.0000     1.0000     1.0000     1.0000
   isfemale[7]     0.0000     0.0000     1.0000     1.0000     1.0000
   isfemale[8]     0.0000     0.0000     1.0000     1.0000     1.0000
   isfemale[9]     0.0000     0.0000     1.0000     1.0000     1.0000
  isfemale[10]     0.0000     0.0000     1.0000     1.0000     1.0000
     height[1]   151.0167   162.6419   170.2986   178.1220   190.5900
     height[2]   150.7162   162.2111   169.9460   178.2009   190.6453
     height[3]   150.9973   162.6603   170.1311   178.2281   190.7633
     height[4]   150.8729   162.4045   170.0385   178.1108   190.3861
     height[5]   151.2441   162.6010   170.2618   178.3460   190.7374
     height[6]   150.7536   162.2724   169.9650   177.9726   190.4018
     height[7]   150.4891   162.5470   170.2617   178.2245   190.2992
     height[8]   150.8853   162.7284   170.3147   178.1920   190.5763
     height[9]   151.0248   162.5023   170.1412   178.3231   190.8849
    height[10]   150.5369   162.2818   169.9072   177.9174   190.5773
      height_A   172.1184   178.2957   181.4849   184.5212   190.5233
    isfemale_A     0.0000     0.0000     0.0000     0.0000     1.0000
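If you just want the numbers for A, they can also be read straight off the chain, e.g. with the usual MCMCChains indexing (the values match the summary above):

using Statistics

mean(chn[:height_A])        # ≈ 181.4 cm, the expected height of the second tallest
1 - mean(chn[:isfemale_A])  # ≈ 0.91, the probability that A is male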

Thanks, @sethaxen.

I like your twist of skipping the explicit ordering constraint and instead querying the order of the heights with id_A = sortperm(height)[n-1].
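To convince myself what that line does, here’s a quick check outside the model (sortperm sorts ascending, so index n-1 picks the second tallest):

height = [170.0, 185.0, 160.0, 178.0]
perm = sortperm(height)          # indices in ascending height order: [3, 1, 4, 2]
id_A = perm[length(height)-1]    # second-to-last entry = index of the second tallest (4)
height[id_A]                     # 178.0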

At first I was a little confused that the listing of height shows all values around 170 cm, but then I understood what you are doing. And since I actually want the full listing of all the (ordered) heights, I added height_ordered:

    # list the heights tallest-first: height_ordered[1] is the tallest, height_ordered[n] the shortest
    height_ordered = Vector{T}(undef, n)
    perm = sortperm(height)
    for i in 1:n
        height_ordered[i] ~ Dirac(height[perm[n-i+1]])
    end

And now I get this nice result:

   height_ordered[1]   186.5568 
   height_ordered[2]   181.2781 
   height_ordered[3]   177.6543 
   height_ordered[4]   174.4240 
   height_ordered[5]   171.4671 
   height_ordered[6]   168.5977 
   height_ordered[7]   165.7553 
   height_ordered[8]   162.7587 
   height_ordered[9]   159.3365 
  height_ordered[10]   154.5416

Again, thanks!