Making Turing Fast with large numbers of parameters?

Yes, tape compilation fully eliminates the abstract-interpretation overhead of pushing forward to generate the tape. In [2109.12449] AbstractDifferentiation.jl: Backend-Agnostic Differentiable Programming in Julia we show that on some problems with heavy scalar usage it makes a difference of an order of magnitude or two. You have to record a tape from a call in order to do it, though.

Notice that the DiffEq heuristics include a Cassette pass, hasbranching, that checks whether the code has any branching, since tape compilation is only compatible with static representations (code whose control flow does not depend on the input values).
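A minimal sketch of that workflow with ReverseDiff's GradientTape, compile and gradient! (the toy function f is hypothetical) also shows why branching matters: the condition x[1] > 0 is evaluated only once, while the tape is recorded, so the compiled tape keeps replaying that branch for every later input.

```julia
using ReverseDiff

# Hypothetical toy function whose control flow depends on the input value.
f(x) = x[1] > 0 ? x[1]^2 + x[2]^2 : x[1] + x[2]

x0 = [1.0, 2.0]
tape = ReverseDiff.GradientTape(f, x0)  # records only the branch taken at x0
ctape = ReverseDiff.compile(tape)       # compile once, reuse for every call

g = similar(x0)
ReverseDiff.gradient!(g, ctape, [3.0, 4.0])   # same branch as x0: correct gradient
g_pos = copy(g)
ReverseDiff.gradient!(g, ctape, [-3.0, 4.0])  # other branch, but the tape still replays x1^2 + x2^2
g_stale = copy(g)
g_fresh = ReverseDiff.gradient(f, [-3.0, 4.0]) # re-records the tape, so it takes the correct branch
```

Here g_stale is silently wrong, which is the same failure mode the Turing docs warn about for cached tapes in models with run-time if statements.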

Is that the same as doing this?

using ReverseDiff
using Memoization
using Turing
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)


It seems like Turing doesn’t expose the tape compilation directly to the user; instead it is supposed to do this internally when you call “setrdcache(true)”.

According to the docs (the “Automatic Differentiation” page):
"When using ReverseDiff, to compile the tape only once and cache it for later use, the user needs to load Memoization.jl first with using Memoization then call Turing.setrdcache(true). However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine. To empty the cache, you can call Turing.emptyrdcache()."

Note that I DID do this in my original example. However, I’m wondering whether something I was doing before got in the way of proper compilation (such as a run-time if statement).

I am very late to the party here, but glad to see the model is now much faster than in the OP. If you really want to optimise the hell out of your model, you can always do the following:

  1. Define a custom distribution (e.g. in case of hierarchical models).
  2. Define a custom AD rule for the logpdf computation. This could also use AD to define the rule and then exploit properties such as the sparsity of the Jacobian to accelerate the rule.
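As a rough sketch of step 2, here is a hand-written reverse rule for a hypothetical summed Normal log-likelihood, written with ChainRulesCore's rrule mechanism. Zygote picks such rules up automatically, and recent ReverseDiff versions can import them with ReverseDiff.@grad_from_chainrules. The helper normal_loglik is made up for illustration, and the sketch does not attempt the Jacobian-sparsity trick mentioned above.

```julia
using ChainRulesCore

# Hypothetical helper: log-likelihood of y under independent Normal(μ[i], σ).
function normal_loglik(μ::AbstractVector, σ::Real, y::AbstractVector)
    n = length(y)
    return -n * log(σ) - n * log(2π) / 2 - sum(abs2, y .- μ) / (2σ^2)
end

# Hand-written reverse rule, so the AD system does not trace the sum element by element.
# With r = y - μ:  ∂/∂μ = r/σ²,  ∂/∂σ = -n/σ + Σr²/σ³,  ∂/∂y = -r/σ².
function ChainRulesCore.rrule(::typeof(normal_loglik), μ::AbstractVector, σ::Real, y::AbstractVector)
    n = length(y)
    r = y .- μ
    ss = sum(abs2, r)
    val = -n * log(σ) - n * log(2π) / 2 - ss / (2σ^2)
    function normal_loglik_pullback(Δ)
        Δ = unthunk(Δ)
        dμ = Δ .* r ./ σ^2
        dσ = Δ * (-n / σ + ss / σ^3)
        return NoTangent(), dμ, dσ, -dμ
    end
    return val, normal_loglik_pullback
end

# Check the rule against central finite differences.
μ, σ, y = [1.0, 2.0], 1.5, [0.5, 2.5]
val, pb = ChainRulesCore.rrule(normal_loglik, μ, σ, y)
_, dμ, dσ, dy = pb(1.0)
h = 1e-6
fd_σ = (normal_loglik(μ, σ + h, y) - normal_loglik(μ, σ - h, y)) / (2h)
fd_μ1 = (normal_loglik(μ .+ [h, 0.0], σ, y) - normal_loglik(μ .- [h, 0.0], σ, y)) / (2h)
```

Checking a hand-written rule against finite differences like this is cheap insurance, since a wrong rule produces wrong gradients without any error.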

We should probably have a blog post on that. It’s at this point that AD stops being Automatic Differentiation though and starts being Annoying Differentiation. But if you want extreme performance (at least until Diffractor or Enzyme comes along) you may have to resort to such extreme measures.


My favourite quote of the month


Yeah, for me the great thing about Turing is the ease with which you can express models using arbitrary Julia code (user defined or from other great packages). However, when it comes to “extreme speed” I think there is a lot of knowledge kicking about inside the heads of a few experts that definitely needs to be shared with the rest of us.


Yes, and Julia’s great semantics help too: things like “view”, broadcasting/iteration, and distributions as types are wonderful to have for building models.

Definitely agree with this. I feel like I frequently see great Turing performance tips here and on Slack that are apparently well-known to the experts, but not that well-documented elsewhere.


It’s a bit difficult to find, but PRs to the performance section of the website can be made against the performancetips.md file on the master branch of TuringLang/Turing.jl on GitHub.


So I’m coming back to this. I have a model where, if I use an MvNormal, the whole thing samples in like a minute or two. YAY. But MvNormal isn’t ideal because, for example, the data can’t be negative and the predictions can’t be negative.

I’ve tried using something like

data ~ arraydist([Gamma(...) for i in 1:n])
predictions ~ arraydist([Gamma(...) for i in 1:k])

and I’ve tried truncated normals and various other things. The sampling via NUTS goes from a minute or two to approximately infinity. For example, I once let it run for 40 minutes and it never even started filling in the “thermometer” or calculating an estimate of the time remaining. It basically did ZERO in 40 minutes.

This was all using ReverseDiff with Memoization and the rdcache.

I tried using Zygote but it crashed with an obscure error similar to what’s listed above. I tried ForwardDiff and it also didn’t go anywhere in many minutes.

This model (which I can’t copy here) has on the order of 500 parameters. How can I use anything other than an MvNormal on models with that many parameters? I can probably produce an MWE tomorrow.

At first I wasn’t able to make an MWE that got super slow (I eventually got one, see below). Here’s what I tried. The problem is this: you have some staff and some patients; the staff come inside, get patients, and take them outside to the garden throughout the day. Sometimes they use a wheelchair, sometimes they walk. The staff/patient pairings are random. A pressure plate in the exit door tells you how much they weigh in total and whether a wheelchair was used. From some observations of random pairings, estimate the weight of each patient and staff member as well as the weight of the wheelchair.

This has the flavor of the kind of model I’m working with, though it doesn’t match exactly, and unfortunately, it runs perfectly fine in both versions (the MvNormal and the version using arraydist of Gammas). In my real model when I switch to arraydist and/or truncated distributions and such it slows to a standstill (makes no progress at all in tens of minutes).

I’m going to work on this example problem and try to make it bork…

## every few hours a random staff member comes and gets a random
## patient to bring them outside to a garden through a door that has a
## scale. Sometimes using a wheelchair, sometimes not. knowing the
## total weight of the two people and the wheelchair plus some errors
## (from the scale measurements), infer the individual weights of all
## individuals and the weight of the wheelchair.

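using Turing, DataFrames  # Turing re-exports Distributions (Normal, Gamma, truncated, ...)

nstaff = 100
npat = 100
staffids = collect(1:nstaff)
patientids = collect(1:npat)
staffweights = rand(Normal(150,30),length(staffids))
patientweights = rand(Normal(150,30),length(patientids))  # was length(staffids)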
wheelchairwt = 15
nobs = 300

data = DataFrame(staff=rand(staffids,nobs),patient=rand(patientids,nobs))
data.usewch = rand(0:1,nobs)
data.totweights = [staffweights[data.staff[i]] + patientweights[data.patient[i]] for i in 1:nrow(data)] .+ data.usewch .* wheelchairwt .+ rand(Normal(0.0,20.0),nrow(data))


@model function estweights(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(Normal(150,30),nstaff)
    patientweights ~ filldist(Normal(150,30),npatients)
    
    totweight ~ MvNormal(view(staffweights,staffid) .+ view(patientweights,patientid) .+ usewch .* wcwt,20.0)
end



@model function estweights2(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([Gamma(15,(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt)/14) for i in 1:length(totweight)])
end



ch1 = sample(estweights(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


ch2 = sample(estweights2(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


AHA! This seems to do it:



@model function estweights3(nstaff,staffid,npatients,patientid,usewch,totweight)
    wcwt ~ Gamma(20.0,15.0/19)
    measerr ~ Gamma(10.0,20.0/9)
    staffweights ~ filldist(truncated(Normal(150,30),90.0,Inf),nstaff)
    patientweights ~ filldist(truncated(Normal(150,30),90.0,Inf),npatients)
    
    totweight ~ arraydist([truncated(Normal(staffweights[staffid[i]] + patientweights[patientid[i]] + usewch[i] * wcwt, measerr),0.0,Inf) for i in 1:length(totweight)])
end


ch3 = sample(estweights3(nstaff,data.staff,npat,data.patient,data.usewch,data.totweights),NUTS(500,.75),1000)


It has been sitting at “Sampling 0% ... EST: N/A” for a couple of minutes.

Since I’ve got a MWE I’m going to tag a few people who replied above and see if any of them have an idea… @rikh @EvoArt @mohamed82008 @sethaxen obviously anyone is welcome to jump in to the party, but if any of you have ideas I’d appreciate it.

I’ve been lurking; I’ve used Stan for years but am trying to switch to Julia. In my experience you should distinguish between the time it takes to do a gradient evaluation and the time it takes to reach 200 ESS. If the former is slow, you can address it with various optimizations in Julia; if the latter is slow, you need to reparametrize your model, and you might be well served by spelunking in the Stan forums.

Concretely, you added another Gamma hyperparameter and truncated distributions, and it takes longer to run. It seems unlikely that this made the gradient so much more costly to evaluate. This usually means you made the posterior geometry much more complex, i.e. parameters are nonlinearly correlated, truncation has the sampler chasing mass in strange places, etc. That’s something to look at, perhaps with diagnostic tools. I don’t yet know Turing.jl well, but with Stan I’d save the warmup iterations and look at the number of leapfrog steps, the step size, and the acceptance rate while it’s running. In summary, HMC/NUTS performance is hard to predict, since the order of magnitude of the runtime depends on both the data and the model.


This goes so slowly that it never produces a single sample in essentially “infinite” time (i.e. 40 minutes and it hasn’t gone past 0%). Because of that, it’s undoubtedly not just a geometry issue; it has to be something that makes the model 4 or 5 orders of magnitude slower to evaluate. I believe this can occur when ReverseDiff is unable to compile its “tape”, but I don’t know whether or why that’s really happening. I’m hoping someone with internals knowledge can say something about that.

With Stan, everything is compiled to C++, so if it compiles then it’ll run at machine-code speed. I believe Julia falls back to interpreting the code under some circumstances that I seem to trigger, but I don’t know how or why I’m triggering that.

I completely understand. The first number I look at when running a Stan model is the grad eval time (which CmdStan prints first). This can vary a lot, but it helps distinguish whether I should work on the coding or the math. If you can get a similar number from Turing, you’d be better positioned to understand what’s happening. Since it’s just Julia, shouldn’t you be able to @btime your grad eval?
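One Turing-independent way to get that number is to write the log density by hand as an ordinary Julia function and time its gradient with BenchmarkTools. The sketch below is a hypothetical hand-written version of the first (MvNormal) weights model, assuming a flat prior on the wheelchair weight for brevity and dropping additive constants; logdens and make_logdens are made-up names, not Turing API. It compares a compiled ReverseDiff tape with ForwardDiff on the same function.

```julia
using ForwardDiff, ReverseDiff, BenchmarkTools

# Hypothetical hand-written, unnormalized log density for the first weights model.
# θ = [wheelchair weight; nstaff staff weights; npatient patient weights].
# Normal(150, 30) priors on people's weights, flat prior on the wheelchair weight,
# Normal(μ_i, 20) measurement noise; additive constants are dropped.
function logdens(θ, staff, patient, usewch, tot, nstaff)
    wcwt = θ[1]
    lp = 0.0
    for j in 2:length(θ)
        lp -= ((θ[j] - 150) / 30)^2 / 2
    end
    for i in eachindex(tot)
        μ = θ[1 + staff[i]] + θ[1 + nstaff + patient[i]] + usewch[i] * wcwt
        lp -= ((tot[i] - μ) / 20)^2 / 2
    end
    return lp
end

# Close over the data so the gradient sees a function of θ alone.
make_logdens(staff, patient, usewch, tot, nstaff) =
    θ -> logdens(θ, staff, patient, usewch, tot, nstaff)

# Simulated data with the same sizes as the MWE.
nstaff, npat, nobs = 100, 100, 300
staff = rand(1:nstaff, nobs)
patient = rand(1:npat, nobs)
usewch = rand(0:1, nobs)
tot = 300 .+ 20 .* randn(nobs)
f = make_logdens(staff, patient, usewch, tot, nstaff)
θ = 150 .+ 10 .* randn(1 + nstaff + npat)

# Record and compile the tape once (valid here: loop lengths are fixed, no value-dependent branches).
ctape = ReverseDiff.compile(ReverseDiff.GradientTape(f, θ))
g = similar(θ)

@btime ReverseDiff.gradient!($g, $ctape, $θ)  # one gradient with the compiled tape
@btime ForwardDiff.gradient($f, $θ)           # forward mode, for comparison
```

If a hand-written density like this is fast while sampling the Turing model stalls, that points toward the model evaluation path or initialization rather than the raw cost of the math.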

I tried running the variants @dlakelan posted. The first one runs fine, but with the third variant I see a lot of warnings like this:

┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47

which suggest an initialisation problem. This is another common one in Stan as well: if you can’t get off on the right foot, the chain just isn’t going to make progress.

It’s a little mysterious how that works, since Turing kinda hides the internals. When you call estweights3(…) it returns an object of type DynamicPPL.Model. I honestly don’t know how to evaluate either the density or the gradient from that (it’s all done “inside” the “sample” function). I mean, I could chase it down by reading through the code, but I hope someone who already has the appropriate background could provide some boost here.

Also, it might be good to have some docs on the Turing.jl documentation that discusses how to time evaluations etc.

I just opened issue #1721 on TuringLang/Turing.jl, “Easy way to get gradient evaluation timing”. It’s low-hanging fruit, I’d guess, and I couldn’t find it in the docs.


Removing the distribution truncations from the third version leads to a model which runs fine. I’d definitely consider the MWE an initialisation/parametrisation problem even if it could run faster.

It’s very possible that the truncation is what’s forcing it to do interpretation rather than compilation though. I often get those initialization issues on other models that still run fast eventually (once they get going).

That’s a good point. Truncations would introduce branches in the code which IIRC are troublesome for some AD systems.
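To make the branching concrete: the log-density of a truncated distribution is piecewise, so Distributions.jl has to check whether the value lies inside the support. A minimal sketch (the particular distribution is arbitrary):

```julia
using Distributions

d = truncated(Normal(0.0, 1.0), 0.0, Inf)

# Inside the support: the untruncated log-density minus the log of the retained
# probability mass (0.5 here), i.e. logpdf(Normal(), 1.0) + log(2).
inside = logpdf(d, 1.0)

# Outside the support the density is zero, so the log-density is -Inf.
outside = logpdf(d, -1.0)
```

Whether such a branch actually changes between gradient evaluations depends on whether its condition involves parameters; a branch whose outcome never changes is harmless for a compiled tape.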