That’s worth opening an issue
Hmm, is this (`stepsize_jitter=1.0`) true?
As far as I remember, StanSample.jl (and thus Stan.jl) has for at least 2 years (probably 5 years or so) used the settings below by default w.r.t. `stepsize_jitter`:
```
julia> bayesian_result_stan.model
Model name:
  name = parameter_estimation_model
C++ threads per forked process:
  num_threads = 8
  use_cpp_chains = false
  check_num_chains = true
C++ chains per forked process:
  num_cpp_chains = 1
No of forked Julia processes:
  num_julia_chains = 4
Actual number of chains:
  num_chains = 4
Sample section:
  num_samples = 1000
  num_warmups = 1000
  save_warmup = false
  thin = 1
  seed = -1
  refresh = 100
  init_bound = 2
Adapt section:
  engaged = true
  gamma = 0.05
  delta = 0.8
  kappa = 0.75
  t0 = 10
  init_buffer = 75
  term_buffer = 50
  window = 25
Algorithm section:
  algorithm = hmc
NUTS section:
  engine = nuts
  max_depth = 10
Metric section:
  metric = diag_e
  stepsize = 1.0
  stepsize_jitter = 0.0
Data and init files:
  use_json = true
Stansummary section:
  summary true
  print_summary false
Other:
  output_base = /Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model
  tmpdir = /Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp

/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model sample num_samples=1000 num_warmup=1000 save_warmup=0 thin=1 adapt engaged=1 gamma=0.05 delta=0.8 kappa=0.75 t0=10 init_buffer=75 term_buffer=75 window=25 algorithm=hmc engine=nuts max_depth=10 metric=diag_e stepsize=1.0 stepsize_jitter=0.0 random seed=-1 init=2 id=1 data file=/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model_data_1.json output file=/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/tmp/parameter_estimation_model_chain_1.csv refresh=100
```
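For completeness, overriding this default would look something like the sketch below. This is a hypothetical config fragment: the `stepsize_jitter` keyword (and passing sampler settings as keywords to `stan_sample`) is my assumption about StanSample.jl's interface, and the toy Stan program is mine, not from the example above.

```julia
using StanSample  # assumes CmdStan is installed and discoverable

# A trivial Stan program, just to have something samplable.
model_code = "parameters { real x; } model { x ~ normal(0, 1); }"

sm = SampleModel("jitter_demo", model_code)

# ASSUMPTION: `stan_sample` accepts sampler settings as keywords;
# `stepsize_jitter = 1.0` would override the 0.0 default shown above.
rc = stan_sample(sm; num_chains = 4, stepsize_jitter = 1.0)
```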
This runs in about 8 seconds in a “normal” setup:

```julia
julia> include("/Users/rob/.julia/dev/DiffEqBayesStan/examples/Fitzhugh-Nagumo/fn.jl");
  8.037559 seconds (2.07 k allocations: 202.500 KiB)
```
I’ll rerun the timings from about a year ago later today, but an “old” plot can be found in DiffEqBayesStan.jl:
So. After a lot of digging, I’ve arrived at the following.
On Distributions@0.25.76 (and somewhat earlier versions), the model

```julia
@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    y ~ arraydist(BernoulliLogit.(theta[p] - beta[i]))
    return (; theta, beta)
end
```
achieves Stan-level performance with ReverseDiff.jl in compiled mode (GitHub - TuringLang/TuringBenchmarking.jl).
On more recent Distributions.jl versions, it’s 3x slower: Performance regression for BernoulliLogit · Issue #1934 · TuringLang/Turing.jl · GitHub. That issue eventually made me realize what’s actually going wrong: fast vs. slow in the reverse-mode AD scenario is just a matter of whether or not you hit the “fast” broadcasting pullback (which just uses ForwardDiff.jl under the hood).
Whether you hit this “fast” path depends on a few things, and the conditions differ slightly between AD backends, but it really comes down to the following:
Broadcasting constructors is apparently not great for Zygote.jl and ReverseDiff.jl:
- ReverseDiff.jl sees a `UnionAll` in the broadcast statement (e.g. a `Normal` without its type parameters is a `UnionAll`), and type inference breaks down, causing a significant perf degradation (and, in this scenario, making compilation of the tape useless).
  - On Distributions@0.25.76 there’s no explicit `BernoulliLogit` type; it’s instead a function which diverts to `Turing.BernoulliBinomial`, hence there’s no `UnionAll` involved in the actual broadcast statement and all is good.
- Zygote.jl will do some checks on the inputs to decide whether or not to take the ForwardDiff path: Zygote.jl/broadcast.jl at master · FluxML/Zygote.jl · GitHub. If it decides not to take the ForwardDiff path, it’s very slow. When a constructor is involved, e.g. `Normal`, the element type `T` in this check will not be of type `Union{Real,Complex}`, so the check fails and Zygote moves on to the slow path.
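The `UnionAll` point is easy to check directly; a minimal sketch, assuming only that Distributions.jl is loaded:

```julia
using Distributions

# The bare constructor `Normal` (no type parameters) is a `UnionAll`,
# which is what ends up in a broadcast like `Normal.(mu, sigma)`.
@assert Normal isa UnionAll

# The fully parameterized type is a concrete `DataType`, which type
# inference can handle without degradation.
@assert Normal{Float64} isa DataType
@assert !(Normal{Float64} isa UnionAll)
```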
Addressing these by, for example, using Faster `arraydist` with LazyArrays.jl by torfjelde · Pull Request #231 · TuringLang/DistributionsAD.jl · GitHub, you end up seeing the following:
```julia
julia> using LazyArrays  # necessary to trigger the faster paths using the above PR

julia> # performant model
       @model function irt(y, i, p; I = maximum(i), P = maximum(p))
           theta ~ filldist(Normal(), P)
           beta ~ filldist(Normal(), I)
           y ~ arraydist(BroadcastArray(BernoulliLogit, theta[p] - beta[i]))
           return (; theta, beta)
       end
irt (generic function with 2 methods)

julia> # Instantiate.
       model = irt(y, i, p);

julia> # Make the benchmark suite.
       suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [
               TuringBenchmarking.ForwardDiffAD{40}(),
               TuringBenchmarking.ReverseDiffAD{true}(),
               TuringBenchmarking.ZygoteAD(),
           ]
       );

julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "evaluation" => Trial(536.721 μs)
      "Turing.Essential.ReverseDiffAD{true}()" => Trial(780.870 μs)
      "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(137.089 ms)
      "Turing.Essential.ZygoteAD()" => Trial(1.374 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "evaluation" => Trial(537.782 μs)
      "Turing.Essential.ReverseDiffAD{true}()" => Trial(780.371 μs)
      "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(137.479 ms)
      "Turing.Essential.ZygoteAD()" => Trial(1.380 ms)
```
Stan is at ~1.2 ms on my laptop, so the ordering is ReverseDiff > Stan > Zygote > ForwardDiff, with the first three fairly close.
Note that we’re not 100% certain Faster `arraydist` with LazyArrays.jl by torfjelde · Pull Request #231 · TuringLang/DistributionsAD.jl · GitHub is the way to go, but this is quite the performance improvement, so we’ll make sure to get it through somehow. It might just mean that if you really want to go “vrooom!”, you need to write `y ~ arraydist(BernoulliLogit, x)` rather than `y ~ arraydist(BernoulliLogit.(x))` (we can automate this, but at the beginning I think it’s going to be opt-in by being slightly different).
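To make the eager vs. lazy distinction concrete, here is a small sketch using only LazyArrays.jl and Distributions.jl (it assumes a recent Distributions.jl with an explicit `BernoulliLogit` type; the `arraydist(f, x)` form itself depends on the PR above):

```julia
using Distributions, LazyArrays

x = randn(5)

# Eager: broadcasts the `BernoulliLogit` constructor (a UnionAll),
# materializing a Vector of distributions up front.
eager = BernoulliLogit.(x)

# Lazy: `BroadcastArray` stores the function and its arguments instead,
# so no UnionAll appears in the broadcast statement that AD sees;
# elements are constructed on demand when indexed.
lazy = BroadcastArray(BernoulliLogit, x)

@assert length(lazy) == length(eager)
@assert lazy[1] isa BernoulliLogit
```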
That’s awesome, thanks for the effort!
I will certainly try it on a real world example and post the differences/improvement here.
For now I will just mark your post as the answer as this seems to be the way forward.