I’m struggling to get my head around some benchmarking and @code_llvm inspection that I’ve been doing. The background is that I was curious what kind of overhead you get from creating an intermediate Normal distribution inside a custom function that calculates the PDF (as in R’s dnorm), like so:
julia> using Distributions
julia> dnorm(x, μ=0., σ=1.) = pdf(Normal(μ, σ), x)
dnorm (generic function with 3 methods)
My intuition was that, since Normal is immutable and just holds two floats (the mean and stddev), the compiler should be able to figure that out and compile away the wrapper altogether. A quick look at the @code_warntype output seems to confirm this:
julia> @code_warntype dnorm(1., 0., 1.)
Body::Float64
1 1 ─ %1 = (Base.lt_float)(0.0, σ)::Bool │╻╷╷╷╷ Type
│ %2 = (Base.not_int)(%1)::Bool ││╻ Type
└── goto #3 if not %2 │││┃ macro expansion
2 ─ %4 = invoke Distributions.string("Normal"::String, ": the condition "::String, "σ > zero(σ)"::Vararg{String,N} where N, " is not satisfied.")::String ││││
│ %5 = %new(Core.ArgumentError, %4)::ArgumentError ││││╻ Type
│ (Distributions.throw)(%5) ││││
└── $(Expr(:unreachable)) ││││
3 ─ goto #4 ││││
4 ─ goto #5 ││
5 ─ %10 = (Base.sub_float)(x, μ)::Float64 ││╻╷╷ normpdf
│ %11 = (Base.div_float)(%10, σ)::Float64 │││╻ zval
│ %12 = (Base.mul_float)(%11, %11)::Float64 ││││╻╷ abs2
│ %13 = (Base.neg_float)(%12)::Float64 ││││╻ -
│ %14 = (Base.div_float)(%13, 2.0)::Float64 │││││╻ /
│ %15 = invoke StatsFuns.exp(%14::Float64)::Float64 ││││
│ %16 = (Base.mul_float)(%15, 0.3989422804014327)::Float64 │││││╻ *
│ %17 = (Base.div_float)(%16, σ)::Float64 │││╻ /
└── return %17 │
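(And the wrapper type itself really is just a plain isbits pair of Float64 fields, so nothing about the struct layout should force an allocation, at least if I’m checking the right things:)
julia> isbitstype(Normal{Float64})
true
julia> sizeof(Normal{Float64})
16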
But. The performance is about 4x slower than when I implement the underlying calculation directly (or use the equivalent StatsFuns normpdf function):
julia> dnorm_manual(x, μ, σ) = exp(abs2((x-μ)/σ) * -0.5) / (σ*sqrt(2π))
dnorm_manual (generic function with 1 method)
julia> @btime dnorm(1., 0., 1.)
6.704 ns (0 allocations: 0 bytes)
0.24197072451914337
julia> @btime dnorm_manual(1., 0., 1.)
1.637 ns (0 allocations: 0 bytes)
0.24197072451914337
julia> using StatsFuns; @btime normpdf(0., 1., 1.)
1.637 ns (0 allocations: 0 bytes)
0.24197072451914337
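(In case passing literal constants lets @btime constant-fold something away, my understanding is that the usual way to rule that out is to interpolate Ref-wrapped arguments, e.g.:)
julia> @btime dnorm($(Ref(1.))[], $(Ref(0.))[], $(Ref(1.))[])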
As far as I can tell, the only real differences here are the check from the Normal constructor (that the stddev is positive), along with the fact that the manual version isn’t using the precomputed StatsFuns constant 1/√(2π):
julia> @code_warntype dnorm_manual(1., 0., 1.)
Body::Float64
1 1 ─ %1 = (Base.sub_float)(x, μ)::Float64 │╻ -
│ %2 = (Base.div_float)(%1, σ)::Float64 │╻ /
│ %3 = (Base.mul_float)(%2, %2)::Float64 ││╻ *
│ %4 = (Base.mul_float)(%3, -0.5)::Float64 │╻ *
│ %5 = invoke Main.exp(%4::Float64)::Float64 │
│ %6 = (Base.mul_float)(2.0, 3.141592653589793)::Float64 ││╻ *
│ %7 = (Base.Math.sqrt_llvm)(%6)::Float64 │╻ sqrt
│ %8 = (Base.mul_float)(σ, %7)::Float64 │╻ *
│ %9 = (Base.div_float)(%5, %8)::Float64 │╻ /
└── return %9 │
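(One way I can think of to isolate those two differences would be a manual version that adds the same σ > 0 guard — with a hard-coded error message, so not quite what @check_args generates — and uses the same precomputed StatsFuns constant, something like this, which could then be benchmarked against the other variants:)
julia> using StatsFuns: invsqrt2π
julia> function dnorm_manual_checked(x, μ, σ)
           σ > zero(σ) || throw(ArgumentError("σ > zero(σ) is not satisfied."))
           exp(-abs2((x - μ) / σ) / 2) * invsqrt2π / σ
       end
dnorm_manual_checked (generic function with 1 method)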
The @code_llvm for the manual version looks sensible to me (knowing nothing about LLVM IR):
julia> @code_llvm dnorm_manual(1., 0., 1.)
; Function dnorm_manual
; Location: REPL[16]:1
define double @julia_dnorm_manual_35642(double, double, double) {
top:
; Function -; {
; Location: float.jl:397
%3 = fsub double %0, %1
;}
; Function /; {
; Location: float.jl:401
%4 = fdiv double %3, %2
;}
; Function abs2; {
; Location: number.jl:157
; Function *; {
; Location: float.jl:399
%5 = fmul double %4, %4
;}}
; Function *; {
; Location: float.jl:399
%6 = fmul double %5, -5.000000e-01
;}
%7 = call double @julia_exp_35233(double %6)
; Function *; {
; Location: float.jl:399
%8 = fmul double %2, 0x40040D931FF62705
;}
; Function /; {
; Location: float.jl:401
%9 = fdiv double %7, %8
;}
ret double %9
}
But for the Normal wrapper version there’s a bunch of stuff that I don’t quite understand but that seems maybe related to the GC, along with the parts that do the actual computation (which look similar to what I’d expect based on the @code_llvm for the manual version):
julia> @code_llvm dnorm(1., 0., 1.)
; Function dnorm
; Location: REPL[4]:1
define double @julia_dnorm_35643(double, double, double) {
top:
%3 = alloca %jl_value_t addrspace(10)*, i32 4
%gcframe = alloca %jl_value_t addrspace(10)*, i32 3
%4 = bitcast %jl_value_t addrspace(10)** %gcframe to i8*
call void @llvm.memset.p0i8.i32(i8* %4, i8 0, i32 24, i32 0, i1 false)
%thread_ptr = call i8* asm "movq %fs:0, $0", "=r"()
%ptls_i8 = getelementptr i8, i8* %thread_ptr, i64 -10920
%ptls = bitcast i8* %ptls_i8 to %jl_value_t***
; Function Type; {
; Location: /home/dave/.julia/dev/Distributions/src/univariate/continuous/normal.jl:34
; Function Type; {
; Location: /home/dave/.julia/dev/Distributions/src/univariate/continuous/normal.jl:30
; Function macro expansion; {
; Location: /home/dave/.julia/dev/Distributions/src/utils.jl:5
; Function >; {
; Location: operators.jl:286
; Function <; {
; Location: float.jl:452
%5 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 0
%6 = bitcast %jl_value_t addrspace(10)** %5 to i64*
store i64 2, i64* %6
%7 = getelementptr %jl_value_t**, %jl_value_t*** %ptls, i32 0
%8 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
%9 = bitcast %jl_value_t addrspace(10)** %8 to %jl_value_t***
%10 = load %jl_value_t**, %jl_value_t*** %7
store %jl_value_t** %10, %jl_value_t*** %9
%11 = bitcast %jl_value_t*** %7 to %jl_value_t addrspace(10)***
store %jl_value_t addrspace(10)** %gcframe, %jl_value_t addrspace(10)*** %11
%12 = fcmp ogt double %2, 0.000000e+00
;}}
br i1 %12, label %L10, label %L4
L4: ; preds = %top
; Location: /home/dave/.julia/dev/Distributions/src/utils.jl:6
%13 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %3, i32 0
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140298815038800 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %13
%14 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %3, i32 1
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140298792047504 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %14
%15 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %3, i32 2
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140298815038832 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %15
%16 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %3, i32 3
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140298792047552 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %16
%17 = call nonnull %jl_value_t addrspace(10)* @jsys1_string_24444(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140298857741968 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %3, i32 4)
; Function Type; {
; Location: boot.jl:276
%18 = bitcast %jl_value_t*** %ptls to i8*
%19 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
store %jl_value_t addrspace(10)* %17, %jl_value_t addrspace(10)** %19
%20 = call noalias nonnull %jl_value_t addrspace(10)* @jl_gc_pool_alloc(i8* %18, i32 1424, i32 16) #1
%21 = bitcast %jl_value_t addrspace(10)* %20 to %jl_value_t addrspace(10)* addrspace(10)*
%22 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)* addrspace(10)* %21, i64 -1
store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140298853852704 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)* addrspace(10)* %22
%23 = bitcast %jl_value_t addrspace(10)* %20 to %jl_value_t addrspace(10)* addrspace(10)*
store %jl_value_t addrspace(10)* %17, %jl_value_t addrspace(10)* addrspace(10)* %23, align 8
;}
%24 = addrspacecast %jl_value_t addrspace(10)* %20 to %jl_value_t addrspace(12)*
%25 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
store %jl_value_t addrspace(10)* %20, %jl_value_t addrspace(10)** %25
call void @jl_throw(%jl_value_t addrspace(12)* %24)
unreachable
L10: ; preds = %top
;}}}
; Function pdf; {
; Location: /home/dave/.julia/dev/Distributions/src/univariates.jl:525
; Function normpdf; {
; Location: /home/dave/.julia/packages/StatsFuns/0W2sM/src/distrs/norm.jl:8
; Function zval; {
; Location: /home/dave/.julia/packages/StatsFuns/0W2sM/src/distrs/norm.jl:4
; Function -; {
; Location: float.jl:397
%26 = fsub double %0, %1
;}
; Function /; {
; Location: float.jl:401
%27 = fdiv double %26, %2
;}}
; Function normpdf; {
; Location: /home/dave/.julia/packages/StatsFuns/0W2sM/src/distrs/norm.jl:7
; Function abs2; {
; Location: number.jl:157
; Function *; {
; Location: float.jl:399
%28 = fmul double %27, %27
;}}
; Function /; {
; Location: promotion.jl:316
; Function /; {
; Location: float.jl:401
%29 = fmul double %28, -5.000000e-01
;}}
%30 = call double @julia_exp_35233(double %29)
; Function *; {
; Location: promotion.jl:314
; Function *; {
; Location: float.jl:399
%31 = fmul double %30, 0x3FD9884533D43651
;}}}
; Function /; {
; Location: float.jl:401
%32 = fdiv double %31, %2
;}}}
%33 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
%34 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %33
%35 = getelementptr %jl_value_t**, %jl_value_t*** %ptls, i32 0
%36 = bitcast %jl_value_t*** %35 to %jl_value_t addrspace(10)**
store %jl_value_t addrspace(10)* %34, %jl_value_t addrspace(10)** %36
ret double %32
}
So I guess I’m wondering where my intuitions have led me astray, and (possibly orthogonally) where the performance hit from using Normal in dnorm vs. dnorm_manual is coming from.
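In case it helps narrow things down, the next experiment I had in mind was a bare-bones wrapper struct with no argument check at all (MyNormal below is just a stand-in, obviously not the real Distributions type), to separate the cost of the wrapper itself from the cost of the check:
julia> struct MyNormal
           μ::Float64
           σ::Float64
       end
julia> mypdf(d::MyNormal, x) = exp(-abs2((x - d.μ) / d.σ) / 2) / (d.σ * sqrt(2π))
mypdf (generic function with 1 method)
julia> dnorm_wrapped(x, μ=0., σ=1.) = mypdf(MyNormal(μ, σ), x)
dnorm_wrapped (generic function with 3 methods)
If that benchmarks like dnorm_manual, I’d take it as pointing at the check (or the error branch it drags in) rather than the wrapper itself.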