So. After a lot of digging, I’ve arrived at the following.
On Distributions@0.25.76 (and some versions below), the model
@model function irt(y, i, p; I = maximum(i), P = maximum(p))
    theta ~ filldist(Normal(), P)
    beta ~ filldist(Normal(), I)
    y ~ arraydist(BernoulliLogit.(theta[p] - beta[i]))
    return (; theta, beta)
end
achieves Stan-level performance with ReverseDiff.jl in compiled mode (benchmarked with GitHub - TuringLang/TuringBenchmarking.jl).
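("Compiled mode" here means ReverseDiff.jl with a compiled and cached tape. Outside of TuringBenchmarking.jl you can flip this on globally, assuming a Turing.jl version that still exposes these switches:

using Turing, ReverseDiff
# Select ReverseDiff.jl as the AD backend and compile + cache the tape.
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
)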
On more recent Distributions.jl versions, it's 3X slower: Performance regression for BernoulliLogit · Issue #1934 · TuringLang/Turing.jl · GitHub. That issue (eventually) made me realize what's actually going wrong: fast vs. slow in the reverse-AD scenario is just a matter of whether or not you hit the "fast" broadcasting pullback (which is just using ForwardDiff.jl under the hood).
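To make "hit the fast broadcasting pullback" a bit more concrete, here's a minimal sketch of the trick (my own illustration, not the actual implementation in Zygote.jl or ReverseDiff.jl): the forward pass is run element-wise with ForwardDiff.jl dual numbers, so the reverse pass reduces to scaling the stored partials:

using ForwardDiff

# Minimal sketch of a ForwardDiff-based broadcast pullback for a unary f.
function fast_broadcast_pullback(f, x::AbstractVector{<:Real})
    # Forward pass with dual numbers: one sweep gives values *and* partials.
    duals = ForwardDiff.Dual.(x, one(eltype(x)))
    ys = f.(duals)
    values = ForwardDiff.value.(ys)
    partials = ForwardDiff.partials.(ys, 1)
    # Reverse pass: just scale the incoming cotangent by the stored partials.
    pullback(ȳ) = ȳ .* partials
    return values, pullback
end

y, pb = fast_broadcast_pullback(tanh, randn(5))
pb(ones(5))  # gradient of sum(tanh.(x))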
Whether or not you hit this "fast" path depends on a few things, and the conditions differ slightly between AD backends, but it really comes down to the following:
Broadcasting constructors is apparently not great for Zygote.jl and ReverseDiff.jl (see the snippet after this list):
- ReverseDiff.jl sees a `UnionAll` in the broadcast statement, e.g. `Normal` without its type parameters is a `UnionAll`, and type inference breaks down, causing a significant perf degradation (and, in this scenario, making compilation of the tape useless).
  - On Distributions@0.25.76 there's no explicit `BernoulliLogit` type; instead it's a function which diverts to `Turing.BernoulliBinomial`, hence there's no `UnionAll` involved in the actual broadcast statement and all is good.
- Zygote.jl does some checks on the inputs to decide whether or not to take the ForwardDiff path: Zygote.jl/broadcast.jl at master · FluxML/Zygote.jl · GitHub. If it decides not to take the ForwardDiff path, it's very slow. When a constructor is involved, e.g. `Normal`, the element type `T` in this check will not be of type `Union{Real,Complex}`, and so the check fails and Zygote moves on to the slow path.
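Both conditions are easy to verify from the REPL (the `Union{Real,Complex}` test below mirrors the check in the linked Zygote.jl source):

julia> using Distributions

julia> Normal isa UnionAll  # bare constructor without type parameters
true

julia> Normal() isa Normal{Float64}  # fully parameterized once constructed
true

julia> eltype(Normal.(randn(3))) <: Union{Real,Complex}  # Zygote's eltype check
false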
Addressing these, for example by using Faster `arraydist` with LazyArrays.jl by torfjelde · Pull Request #231 · TuringLang/DistributionsAD.jl · GitHub, you end up seeing the following:
julia> using LazyArrays  # Necessary to trigger the faster paths using the above PR.

julia> # performant model
       @model function irt(y, i, p; I = maximum(i), P = maximum(p))
           theta ~ filldist(Normal(), P)
           beta ~ filldist(Normal(), I)
           y ~ arraydist(BroadcastArray(BernoulliLogit, theta[p] - beta[i]))
           return (; theta, beta)
       end
irt (generic function with 2 methods)

julia> # Instantiate
       model = irt(y, i, p);

julia> # Make the benchmark suite.
       suite = TuringBenchmarking.make_turing_suite(
           model,
           adbackends = [
               TuringBenchmarking.ForwardDiffAD{40}(),
               TuringBenchmarking.ReverseDiffAD{true}(),
               TuringBenchmarking.ZygoteAD()
           ]
       );
julia> run(suite)
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "linked" => 4-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "evaluation" => Trial(536.721 μs)
      "Turing.Essential.ReverseDiffAD{true}()" => Trial(780.870 μs)
      "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(137.089 ms)
      "Turing.Essential.ZygoteAD()" => Trial(1.374 ms)
  "not_linked" => 4-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "evaluation" => Trial(537.782 μs)
      "Turing.Essential.ReverseDiffAD{true}()" => Trial(780.371 μs)
      "Turing.Essential.ForwardDiffAD{40, true}()" => Trial(137.479 ms)
      "Turing.Essential.ZygoteAD()" => Trial(1.380 ms)
Stan is at ~1.2 ms on my laptop, so the ordering is ReverseDiff > Stan > Zygote > ForwardDiff, where the first three are fairly close.
Note that we're not 100% certain Faster `arraydist` with LazyArrays.jl by torfjelde · Pull Request #231 · TuringLang/DistributionsAD.jl · GitHub is the way to go, but this is quite the performance improvement, so we'll make sure to get it through in some form. It might just mean that if you really want to go "vrooom!", you need to write `y ~ arraydist(BernoulliLogit, x)` rather than `y ~ arraydist(BernoulliLogit.(x))` (we can automate this, but initially I think it's going to be opt-in by being slightly different).
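For illustration, here's the difference between the two spellings side by side (a sketch assuming the LazyArrays.jl-based path from the PR above; the lazy variant's exact API may still change):

using Distributions, DistributionsAD, LazyArrays

x = randn(10)

# Eager: broadcasting the constructor materializes a Vector of distributions,
# which is the pattern that pushes ReverseDiff.jl/Zygote.jl onto the slow path.
d_eager = arraydist(BernoulliLogit.(x))

# Lazy: keep the (constructor, arguments) pair around, so the AD-friendly
# broadcast machinery can be used when evaluating logpdf.
d_lazy = arraydist(BroadcastArray(BernoulliLogit, x))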