Improving performance of item response model in Turing.jl

That’s worth opening an issue


Hmm, is this (stepsize_jitter=1.0) true?

As far as I remember, StanSample.jl (and thus Stan.jl) has for at least 2 years (probably 5 years or so) used the settings below by default w.r.t. stepsize_jitter:

julia> bayesian_result_stan.model

Model name:
  name =                    parameter_estimation_model

C++ threads per forked process:
  num_threads =             8
  use_cpp_chains =          false
  check_num_chains =        true

C++ chains per forked process:
  num_cpp_chains =          1

No of forked Julia processes:
  num_julia_chains =        4

Actual number of chains:
  num_chains =              4

Sample section:
  num_samples =             1000
  num_warmups =             1000
  save_warmup =             false
  thin =                    1
  seed =                    -1
  refresh =                 100
  init_bound =              2
Adapt section:
  engaged =                 true
  gamma =                   0.05
  delta =                   0.8
  kappa =                   0.75
  t0 =                      10
  init_buffer =             75
  term_buffer =             50
  window =                  25

Algorithm section:

  algorithm =               hmc

    NUTS section:
      engine =              nuts
      max_depth =           10

  Metric section:
    metric =                diag_e
    stepsize =              1.0
    stepsize_jitter =       0.0

Data and init files:
  use_json =                true

Stansummary section:
  summary                   true
  print_summary             false

Other:
  output_base =             /Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model
  tmpdir =                  /Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp

`/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model sample num_samples=1000 num_warmup=1000 save_warmup=0 thin=1 adapt engaged=1 gamma=0.05 delta=0.8 kappa=0.75 t0=10 init_buffer=75 term_buffer=75 window=25 algorithm=hmc engine=nuts max_depth=10 metric=diag_e stepsize=1.0 stepsize_jitter=0.0 random seed=-1 init=2 id=1 data file=/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model_data_1.json output file=/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model_chain_1.csv refresh=100

This runs in about 8 seconds in a “normal” setup:

julia> include("/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/fn.jl");
  8.037559 seconds (2.07 k allocations: 202.500 KiB)

I’ll rerun the timings from about a year ago later today, but an “old” plot can be found in DiffEqBayesStan.jl:

(plot: performance_2021)

So. After a lot of digging, I’ve arrived at the following.

On Distributions@0.25.76 (and somewhere below), the model

@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    y ~ arraydist(BernoulliLogit(theta[p] - beta[i]))

    return (; theta, beta)
end

achieves Stan-level performance with ReverseDiff.jl in compiled mode (GitHub - TuringLang/TuringBenchmarking.jl).
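For reference, enabling compiled-mode ReverseDiff for sampling looks roughly like this (a sketch; `setadbackend`/`setrdcache` match the Turing API from around the time of this thread, and depending on your version you may also need Memoization.jl loaded for the tape cache):

```julia
using Turing, ReverseDiff

# Use ReverseDiff for gradients and cache (compile) the tape.
# Depending on your Turing version, Memoization.jl may also need to be
# loaded for `setrdcache(true)` to take effect.
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

# Sample the model defined above with NUTS.
model = irt(y, i, p)
chain = sample(model, NUTS(), 1_000)
```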

On more recent Distributions.jl versions, it’s 3X slower: Performance regression for BernoulliLogit · Issue #1934 · TuringLang/Turing.jl · GitHub. That issue (eventually) made me realize what’s actually going wrong: fast vs. slow in the reverse-mode AD scenario is just a matter of whether or not you hit the “fast” broadcasting pullback (which just uses ForwardDiff.jl under the hood).
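If you want to reproduce the fast timings in the meantime, one option is to pin Distributions.jl to the older version (a sketch using the standard Pkg API):

```julia
using Pkg

# Pin Distributions.jl to the version from before the regression,
# where BernoulliLogit is still a plain function rather than a type.
Pkg.add(PackageSpec(name="Distributions", version="0.25.76"))
```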

Whether you hit this “fast” path depends on a few things, and the conditions differ slightly between AD backends, but it really comes down to the following:

Broadcasting constructors is apparently not great for Zygote.jl and ReverseDiff.jl:

  • ReverseDiff.jl sees a UnionAll in the broadcast statement, e.g. a Normal without its type parameters is a UnionAll, and type inference breaks down, causing a significant perf degradation (and, in this scenario, making compilation of the tape useless).
    • On Distributions@0.25.76 there’s no explicit BernoulliLogit type; it is instead a function which diverts to Turing.BernoulliBinomial, hence there’s no UnionAll involved in the actual broadcast statement and all is good.
  • Zygote.jl does some checks on the inputs to decide whether or not to take the ForwardDiff path: Zygote.jl/broadcast.jl at master · FluxML/Zygote.jl · GitHub. If it decides not to take the ForwardDiff path, it’s very slow. When a constructor is involved, e.g. Normal, the element type T in this check will not be a subtype of Union{Real,Complex}, so the check fails and Zygote moves on to the slow path.
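The UnionAll point above is easy to check in the REPL (assuming Distributions.jl is loaded):

```julia
using Distributions

# A parametric type written without its parameters is a UnionAll...
@show Normal isa UnionAll            # true
# ...while the fully parameterized type is not:
@show Normal{Float64} isa UnionAll   # false

# Broadcasting the bare constructor thus puts a UnionAll into the broadcast
# expression, which is what trips up ReverseDiff's type inference:
dists = Normal.(randn(3))
@show eltype(dists)                  # Normal{Float64}
```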

Addressing these by, for example, using the PR Faster `arraydist` with LazyArrays.jl by torfjelde · Pull Request #231 · TuringLang/DistributionsAD.jl · GitHub, you end up seeing the following:

julia> using LazyArrays  # Necessary to trigger the faster paths using the above PR.

julia> # performant model
       @model function irt(y, i, p; I = maximum(i), P = maximum(p))
           theta ~ filldist(Normal(), P)
           beta ~ filldist(Normal(), I)
           y ~ arraydist(BroadcastArray(BernoulliLogit, theta[p] - beta[i]))

           return (; theta, beta)
       end
irt (generic function with 2 methods)

julia> # Instantiate
       model = irt(y, i, p);

julia> # Make the benchmark suite.
       suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [TuringBenchmarking.ForwardDiffAD{40}(), TuringBenchmarking.ReverseDiffAD{true}(), TuringBenchmarking.ZygoteAD()]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(536.721 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(780.870 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(137.089 ms)
          "Turing.Essential.ZygoteAD()" => Trial(1.374 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "evaluation" => Trial(537.782 μs)
          "Turing.Essential.ReverseDiffAD{true}()" => Trial(780.371 μs)
          "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(137.479 ms)
          "Turing.Essential.ZygoteAD()" => Trial(1.380 ms)

Stan is at ~1.2 ms on my laptop, so the ordering is ReverseDiff > Stan > Zygote > ForwardDiff, where the first three are fairly close.

Note that we’re not 100% certain Faster `arraydist` with LazyArrays.jl by torfjelde · Pull Request #231 · TuringLang/DistributionsAD.jl · GitHub is the way to go, but this is quite the performance improvement, so we’ll make sure to get it through somehow. It might just mean that if you really want to go “vrooom!” you need to write y ~ arraydist(BernoulliLogit, x) rather than y ~ arraydist(BernoulliLogit.(x)) (we can automate this, but at the beginning I think it’s going to be opt-in by being slightly different).


That’s awesome, thanks for the effort!

I will certainly try it on a real world example and post the differences/improvement here.
For now I will just mark your post as the answer as this seems to be the way forward.
