At first I wasn’t able to make an MWE that got super slow (I eventually got one; see below). Here’s what I tried.

The problem is this: you have some staff and some patients. Throughout the day, the staff come inside, get patients, and take them outside to the garden. Sometimes they use a wheelchair, sometimes they walk, and the staff/patient pairings are random. A pressure plate in the exit door tells you their total weight and whether a wheelchair was used. From some observations of random pairings, estimate the weight of each patient and each staff member, as well as the weight of the wheelchair.
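In symbols (my own notation, nothing Turing-specific), the generative model is: for the $k$-th trip, with staff member $i_k$, patient $j_k$, wheelchair indicator $u_k \in \{0, 1\}$ and scale reading $y_k$,

$$
y_k = s_{i_k} + p_{j_k} + u_k\, c + \varepsilon_k, \qquad \varepsilon_k \sim \mathcal{N}(0, \sigma^2),
$$

where $s$ and $p$ are the staff and patient weights, $c$ is the wheelchair weight, and $\sigma$ is the scale error (20 in the simulation).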
This has the flavor of the kind of model I’m working with, though it doesn’t match exactly. Unfortunately, it runs perfectly fine in both versions (the `MvNormal` version and the version using an `arraydist` of Gammas). In my real model, when I switch to `arraydist` and/or truncated distributions and such, it slows to a standstill (no progress at all in tens of minutes).
I’m going to work on this example problem and try to make it bork…
```julia
using Turing, DataFrames

## every few hours a random staff member comes and gets a random
## patient to bring them outside to a garden through a door that has a
## scale. Sometimes using a wheelchair, sometimes not. Knowing the
## total weight of the two people and the wheelchair plus some errors
## (from the scale measurements), infer the individual weights of all
## the people and the weight of the wheelchair.
nstaff = 100
npat = 100
staffids = collect(1:nstaff)
patientids = collect(1:npat)
staffweights = rand(Normal(150,30),length(staffids))
patientweights = rand(Normal(150,30),length(patientids))
wheelchairwt = 15
nobs = 300
data = DataFrame(staff=rand(staffids,nobs),patient=rand(patientids,nobs))
data.usewch = rand(0:1,nobs)
# observed total = staff + patient + wheelchair (if used) + Normal(0, 20) scale error
data.totweights = [staffweights[data.staff[i]] + patientweights[data.patient[i]] for i in 1:nrow(data)] .+
    data.usewch .* wheelchairwt .+ rand(Normal(0.0,20.0),nrow(data))

# Version 1: Normal priors, isotropic MvNormal likelihood
@model function estweights(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(Normal(150,30),nstaff)
    patientweights ~ filldist(Normal(150,30),npatients)
    # isotropic Normal with std 20 (newer Distributions.jl versions spell this `20.0^2 * I`, with LinearAlgebra's `I`)
    totweight ~ MvNormal(view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt,20.0)
end

# Version 2: truncated priors, arraydist of Gamma likelihood terms
@model function estweights2(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    # Gamma(15, m/14) has its mode at the predicted total m
    totweight ~ arraydist([Gamma(15,(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt)/14) for i in 1:length(totweight)])
end

# NUTS(500, 0.75): 500 adaptation steps, target acceptance rate 0.75
ch1 = sample(estweights(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
ch2 = sample(estweights2(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
```
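A note on the Gamma parameterizations, since they are easy to misread: Distributions.jl’s `Gamma(α, θ)` takes shape and scale, and for α ≥ 1 its mode is (α − 1)θ, so `Gamma(α, m/(α − 1))` has its mode at m. The sketch below just checks this for the wheelchair prior and for one `estweights2` likelihood term; `gamma_with_mode` and the predicted total of 320 are hypothetical, chosen only for illustration.

```julia
using Distributions

# Distributions.jl's Gamma(α, θ) uses shape α and scale θ.
# For α ≥ 1 its mode is (α - 1) * θ, so Gamma(α, m / (α - 1)) has its mode at m.
# gamma_with_mode is a hypothetical helper for this note, not part of Distributions.jl.
gamma_with_mode(m, α) = Gamma(α, m / (α - 1))

# Prior on the wheelchair weight used in the models above
wc_prior = Gamma(20.0, 15.0 / 19)
println(mode(wc_prior))   # ≈ 15, the simulated wheelchair weight
println(mean(wc_prior))   # α * θ = 300/19 ≈ 15.79, a bit above the mode

# estweights2 likelihood term for an illustrative predicted total of 320
lik = gamma_with_mode(320.0, 15)
println(mode(lik))        # ≈ 320
println(std(lik))         # sqrt(α) * θ = sqrt(15) * 320 / 14 ≈ 88.5
```

So at a predicted total around 320, each Gamma likelihood term in `estweights2` is considerably wider than the σ = 20 noise used to simulate `totweights`.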
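AHA! This seems to do it. Compared with `estweights2`, this version adds a `measerr` parameter for the scale error and uses a truncated Normal likelihood instead of Gammas: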
```julia
@model function estweights3(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)  # scale-error std, prior mode at 20
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    # Normal scale error with unknown std, truncated to non-negative readings
    totweight ~ arraydist([truncated(Normal(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt, measerr),0.0,Inf) for i in 1:length(totweight)])
end

ch3 = sample(estweights3(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)
```
That has been sitting at `Sampling 0% ... EST: N/A` for a couple of minutes.
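One variant I may try next (I haven’t checked whether it is any faster) is to compute the same likelihood as a single vectorized log-density and add it with `Turing.@addlogprob!`, instead of building an `arraydist` of 300 separate truncated Normals. `estweights3_vec` and `trunc_normal_loglik` are hypothetical names for this sketch; the priors are copied from `estweights3`.

```julia
using Turing

# Hypothetical helper: total log-likelihood of y under Normal(μ[k], σ) truncated to [0, Inf).
# truncated.(Normal.(μ, σ), 0.0, Inf) broadcasts to one distribution per observation.
trunc_normal_loglik(μ, σ, y) = sum(logpdf.(truncated.(Normal.(μ, σ), 0.0, Inf), y))

# Hypothetical variant of estweights3 with the likelihood added directly
@model function estweights3_vec(nstaff, staffid, npatients, patientid, usewch, totweight)
    wcwt ~ Gamma(20.0, 15.0/19)
    measerr ~ Gamma(10.0, 20.0/9)
    staffweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), nstaff)
    patientweights ~ filldist(truncated(Normal(150, 30), 90.0, Inf), npatients)
    # predicted total for every observation at once
    μ = staffweights[staffid] .+ patientweights[patientid] .+ usewch .* wcwt
    # same density as `totweight ~ arraydist(...)`, added to the log joint directly
    Turing.@addlogprob! trunc_normal_loglik(μ, measerr, totweight)
end
```

One difference: `totweight` is no longer a `~` observation here, so things like `predict` won’t generate it.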
Since I’ve got an MWE, I’m going to tag a few people who replied above and see if any of them have an idea: @rikh @EvoArt @mohamed82008 @sethaxen. Obviously anyone is welcome to jump into the party, but if any of you have ideas I’d appreciate it.