Save dispatched function for later call

Note that you don’t want to hardcode out = 0, as that will cause type instability if op(a,b) is not an Int. It’s better to either restrict a and b to the same type T and initialize with out = zero(T) (assuming op(a,b) also returns a value of type T), or use a function barrier and promote:

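# Mixed argument types: promote them to a common type, then call the method below
l(ops, a::T, b::S) where {T,S} = l(ops, promote(a,b)...)
# Function barrier: inside this method a and b share the concrete type T
function l(ops, a::T, b::T) where T
	out = zero(T)  # accumulator starts with the same type as the inputs
	for op in ops
		out += op(a,b)
	end
	out
end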

You’ll still want op(a,b) to return a T, though, to avoid further type instability (or convert explicitly via convert(T, op(a,b))).
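A minimal sketch of the promote-plus-barrier option with the explicit conversion folded in, assuming every op's result can be converted to the promoted type T. The name sum_ops is hypothetical (it plays the role of l above); the trailing ::T assertion is what lets the compiler treat the accumulator as a T even when ops is only known to hold Functions.

```julia
# Hypothetical name; same structure as `l` above.
# Entry point for mixed argument types: promote, then hit the barrier below.
sum_ops(ops, a, b) = sum_ops(ops, promote(a, b)...)

# Function barrier: a and b now share one concrete type T.
function sum_ops(ops, a::T, b::T) where T
    out = zero(T)
    for op in ops
        # convert keeps `out` a T; the ::T assertion lets inference rely on it
        out += convert(T, op(a, b))::T
    end
    return out
end

sum_ops([+, *], 4, 7)        # 11 + 28 = 39, an Int
sum_ops([+, -, /], 1, 2.0)   # inputs promoted to Float64: 3.0 - 1.0 + 0.5 = 2.5
```

Note that convert throws an InexactError when a result can't be represented exactly in T, e.g. sum_ops([/], 4, 7) fails because 4/7 is not an integer.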

Well, at least it actually has to do work now ^^ You may also want to look at the output of @benchmark instead of just @btime: you’ll get a histogram of running times rather than only the minimum time reported by @btime, which gives you a better picture of the tradeoffs.

Shameless plug:

Inlined execution of functions listed in an array is possible with FunctionWranglers.jl; it does the metaprogramming for you:

julia> w=FunctionWrangler(opp)
FunctionWrangler with 8 items: +, +, *, +, -, -, *, *,

julia> outs=zeros(Float64, length(opp))
8-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

julia> using BenchmarkTools

julia> @btime begin smap!($outs, $w, $a, $b); sum($outs) end
  9.306 ns (0 allocations: 0 bytes)
111.0

https://github.com/tisztamo/FunctionWranglers.jl
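The core idea can be sketched without any package, assuming the list of ops is fixed and stored in a Tuple instead of a Vector{Function}: every element of a tuple keeps its own concrete type, so the compiler can specialize and inline each call. This is a plain tuple-recursion sketch with a hypothetical function name, not FunctionWranglers’ actual implementation.

```julia
# Base case: no ops left, contribute nothing.
sum_tuple_ops(::Tuple{}, a, b) = zero(a)

# Recursive case: call the first op, then recurse on the remaining ops.
# Base.tail(t) returns the tuple t without its first element.
sum_tuple_ops(ops::Tuple, a, b) = first(ops)(a, b) + sum_tuple_ops(Base.tail(ops), a, b)

example_ops = (+, +, *, +, -, -, *, *)   # same ops as listed by the wrangler above
sum_tuple_ops(example_ops, 4, 7)         # 111
```

Each distinct tuple type gets its own specialization, so the amount of compiled code grows with the length of the op list.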

Catwalk.jl may also help here, and the “Alternatives” section of its docs lists other packages for speeding up dispatch:

https://tisztamo.github.io/Catwalk.jl/dev/#Alternatives

I hope this helps!

(edit: example added)


Hmm… basically, all I wanted to do is what smap!() does, but with some dependencies coming from the computation graph.

I want to know how the hell this works so fast… :smiley: This is sick! :smiley:

Your benchmark is still suboptimal. It’s type unstable due to the / operation (which produces a floating-point number).
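To make the instability concrete: in Julia, / on two Ints always returns a Float64, so an accumulator that starts out as an Int changes type as soon as a division result is added to it. Integer division with div (or ÷) keeps everything an Int, though it truncates; that swap is only a suggestion here, and the benchmark below simply leaves division out. The helper name accumulate_ops is hypothetical.

```julia
# Hypothetical helper: sum op(a, b) over all ops, starting from an Int zero.
function accumulate_ops(ops, a::Int, b::Int)
    s = zero(a)
    for op in ops
        s += op(a, b)
    end
    return s
end

typeof(8 / 2)                          # Float64, even though the division is exact
typeof(div(7, 2))                      # Int (integer division, same as 7 ÷ 2)
typeof(accumulate_ops([+, *], 4, 7))   # Int
typeof(accumulate_ops([+, /], 4, 7))   # Float64: the `/` result turned s into a float
```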

Here is a corrected benchmark. It compares four cases:
1. A sequence of jumps: each opcode is encoded as a consecutive integer (we don’t use symbols here, since symbols require additional loads to perform the equality test).
2. FunctionWrappers, which uses Julia’s cfunction. It has some overhead to check whether the input types match the function signature.
3. Raw function pointers, which have no additional overhead at all.
4. Native dynamic dispatch.

Some comments on the code: it consists of three parts. The first part is a fix for a bug in FunctionWrappers.jl: it seems that in Julia 1.6 and later, LLVM doesn’t inline the llvmcall to assume, which causes a ~10 ns slowdown. The second part implements case 3. The third part is the actual benchmark code. Case 1 uses if-elseif-else to match each opcode manually, while cases 2 to 4 directly call op(a,b) to evaluate the result.

To run it, just copy the code into benchmark.jl and include that file.
The results (tested on an old, slow computer):

bundle of jump:  12.935 ns (0 allocations: 0 bytes)
function wrapper:  48.559 ns (0 allocations: 0 bytes)
raw pointer:  17.155 ns (0 allocations: 0 bytes)
dynamic dispatch:  362.976 ns (0 allocations: 0 bytes)

So using a function pointer is only slightly slower than the inlined call. FunctionWrapper is roughly 3x slower than the raw pointer and nearly 4x slower than the jumps, and dynamic dispatch is 20x-30x slower (none of these results are unexpected). That said, the slowdown is not a big deal if dispatch doesn’t happen often.
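The ratios behind these statements can be checked directly from the timings above. The numbers are copied from the results; nothing is re-measured here, and the variable names are just labels for the four cases.

```julia
# Minimum times in ns, copied from the benchmark output above
jump     = 12.935    # bundle of jump
wrapper  = 48.559    # function wrapper
rawptr   = 17.155    # raw pointer
dispatch = 362.976   # dynamic dispatch

ratio(x, y) = round(x / y; digits = 2)

ratio(rawptr, jump)       # 1.33
ratio(wrapper, jump)      # 3.75
ratio(wrapper, rawptr)    # 2.83
ratio(dispatch, jump)     # 28.06
ratio(dispatch, rawptr)   # 21.16
```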

The benchmark could be made more accurate by writing assembly directly to bypass LLVM’s optimizations (for example, by manually constructing a linear jump table to test whether a jump table is beneficial). For now, though, these results should be enough.

benchmark.jl

# *** Part 1 ***
# A monkey patch to FunctionWrappers's assume function
# assume must have an alwaysinline attribute, otherwise it doesn't get inlined
using FunctionWrappers
import FunctionWrappers.FunctionWrapper
if VERSION >= v"1.6.0-DEV.663"
    @inline function assume(v::Bool)
        Base.llvmcall(
            ("""
             declare void @llvm.assume(i1)
             define void @fw_assume(i8) alwaysinline
             {
                 %v = trunc i8 %0 to i1
                 call void @llvm.assume(i1 %v)
                 ret void
             }
             """, "fw_assume"), Cvoid, Tuple{Bool}, v)
    end
else
    @inline function assume(v::Bool)
        Base.llvmcall(("declare void @llvm.assume(i1)",
                       """
                       %v = trunc i8 %0 to i1
                       call void @llvm.assume(i1 %v)
                       ret void
                       """), Cvoid, Tuple{Bool}, v)
    end
end
# method redefinition
@inline FunctionWrappers.assume(v::Bool) = Main.assume(v)

# *** Part 2, A simple wrapper around function pointers, to make life easier ***
# ArithBiOp{T} is a binary function with type T x T -> T
struct ArithBiOp{T}
    fptr::Ptr{Nothing}
end

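# get the function pointer by looking up the code instance
function get_biop_ptr(f,::Type{T}) where T
    # trigger compilation of the function if needed
    # NOTE: relies on Julia internals (specializations, cache, specptr) that change
    # between versions; specializations[1] is assumed to be the (T,T) instance
    m = which(f,(T,T)).specializations[1]
    if !isdefined(m,:cache)
        precompile(f,(T,T))
    end
    @assert isdefined(m,:cache)
    # get the function pointer
    ptr = m.cache.specptr
    @assert ptr != C_NULL
    return ArithBiOp{T}(ptr)
end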

# unsafely call the function by following the calling convention
# this only works if the inputs are trivial enough that we don't need to worry about GC
@inline function (op::ArithBiOp{T})(i1::T,i2::T) where T
    unsafe_call(op,i1,i2)
end

# assume is used to bypass null pointer checking.
@inline @generated function unsafe_call(op::ArithBiOp{T},i1::T,i2::T) where T
    :(fptr = op.fptr; assume(fptr != C_NULL);ccall(fptr,$T,($T,$T),i1,i2))
end

# ***Part 3: set up benchmark***
# we don't use divide here, since it's type unstable
encode_dict = Dict{Symbol,Int}([:ADD=>0,:SUBS=>1,:MUL=>2])
# symbol inputs
sym_ops = Symbol[:ADD, :ADD, :MUL, :ADD, :SUBS, :SUBS, :MUL, :MUL]
# encode inputs
int_ops = Int[encode_dict[i] for i in sym_ops]
# function inputs
f_ops = Function[+, +, *, +, -, -, *, *] 
# raw function pointers from Julia's generic function
ptr_ops = [get_biop_ptr(i,Int64) for i in f_ops]
# function wrappers
funcwrap_ops = [FunctionWrapper{Int, Tuple{Int, Int}}(op) for op in f_ops]

# opcodes: 0 => +, 1 => -, 2 => * (see encode_dict)
function condjump(ops,a,b)
    s = zero(typeof(a))
    for op in ops
        if op == 0
            s += +(a, b)
        elseif op == 1
            s += -(a, b)
        elseif op == 2
            s += *(a, b)
        end
    end
    return s
end

function direct_eval(ops,a,b)
    s = zero(typeof(a))
    for op in ops
        s += op(a,b)
    end
    return s
end

using BenchmarkTools
a=4
b=7

print("bundle of jump:");@btime condjump($int_ops,$a,$b);
print("function wrapper:");@btime direct_eval($funcwrap_ops,$a,$b);
print("raw pointer:");@btime direct_eval($ptr_ops,$a,$b);
print("dynamic dispatch:");@btime direct_eval($f_ops,$a,$b);
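For comparison, a raw function pointer can also be obtained without reaching into code-instance internals, using Julia’s documented @cfunction macro and calling the pointer with ccall. This is a swapped-in technique, not the specptr approach above (it goes through a C-callable wrapper, so its timing may differ), and the add_int/sub_int/mul_int and cfunc_eval names are hypothetical.

```julia
# Hypothetical concrete helpers: @cfunction needs fixed argument and return types
add_int(a::Int, b::Int)::Int = a + b
sub_int(a::Int, b::Int)::Int = a - b
mul_int(a::Int, b::Int)::Int = a * b

# C-callable pointers to the helpers, in the same order as f_ops in benchmark.jl
ptr_list = Ptr{Cvoid}[
    @cfunction(add_int, Int, (Int, Int)),
    @cfunction(add_int, Int, (Int, Int)),
    @cfunction(mul_int, Int, (Int, Int)),
    @cfunction(add_int, Int, (Int, Int)),
    @cfunction(sub_int, Int, (Int, Int)),
    @cfunction(sub_int, Int, (Int, Int)),
    @cfunction(mul_int, Int, (Int, Int)),
    @cfunction(mul_int, Int, (Int, Int)),
]

# ccall with a runtime pointer: the return type is fixed to Int, so the loop
# is type-stable even though the callee is only known at run time
function cfunc_eval(ptrs::Vector{Ptr{Cvoid}}, a::Int, b::Int)
    s = 0
    for p in ptrs
        s += ccall(p, Int, (Int, Int), a, b)
    end
    return s
end

cfunc_eval(ptr_list, 4, 7)   # 111
```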

Okay, FunctionWranglers crashes with 1024 operations. Interesting…
It also seems that compilation time scales linearly with the number of ops, which is pretty sad.

It sounds like it generates all the operations, one by one, into a single function that is called afterwards. That would explain both the speed and the crash with 1024 operations…
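The “generate every operation, one by one, into a single function” idea can be sketched with a @generated function over a tuple of ops. This only illustrates that idea, assuming the ops are stored in a Tuple; it is not taken from FunctionWranglers’ source, and unrolled_sum is a hypothetical name.

```julia
# Inside a @generated function, `ops` is bound to the argument's *type*
# (a Tuple type), so fieldcount(ops) is the number of ops, known at compile time.
@generated function unrolled_sum(ops::Tuple, a, b)
    body = Expr(:block, :(s = zero(a)))
    for i in 1:fieldcount(ops)
        # one straight-line call per op; ops[$i] uses a constant index,
        # so each call site sees the concrete function type
        push!(body.args, :(s += ops[$i](a, b)))
    end
    push!(body.args, :(return s))
    return body
end

unrolled_sum((+, +, *, +, -, -, *, *), 4, 7)   # 111
```

The generated body has one line per op, so the amount of code to compile grows linearly with the length of the op list.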

Wow, that code! I need time to digest it. :smiley:

I increased the number of operations to 16, and the gap for the raw pointer grew too: it is now about 2-3x slower than the “bundle of jump”.

I think I finally get it. The key point is that the compiler can’t generate performant code because return-type stability probably isn’t part of the optimisation workflow for FunctionWrappers/FunctionWranglers. I guess that since the return types of function pointers aren’t fixed in the general case, the developers didn’t focus on optimising code to deduce the return types of these function calls. So basically there is still some type checking, even though we removed the / function. There is one more reason: the ops array could hold functions with any return type, so it would be unsafe not to check the type at runtime in this case. I don’t know if you understand me, @anon56330260? So basically the compiler cannot generate code that skips the type check in that for loop before doing the += addition. Even if we call it with inputs that would allow dropping that type check, the compiler mustn’t specialise on the assumption that it will always get operations like that, because that could cause very serious problems. So it isn’t about inlining and so on. Is this a viable explanation? Maybe there is a way to say “don’t check the types, it will be all right every time”, but I don’t know if we have that.
It would also be wise to check the native code… but I have no time for these details now. :frowning:
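One way to test this explanation without reading native code is to ask type inference what it concludes (Base.return_types, or @code_warntype from InteractiveUtils). A sketch, assuming the same direct_eval loop as in benchmark.jl: with a Vector{Function}, each op(a,b) is inferred as Any, and so is the accumulator. A ::T type assertion on each call result is the closest thing to “trust me, it’s always a T”: the runtime still performs a cheap check, and the dynamic dispatch on op remains, but everything after the assertion is inferred as T. The asserted variant’s name is hypothetical.

```julia
# Same loop as direct_eval in benchmark.jl
function direct_eval(ops, a, b)
    s = zero(typeof(a))
    for op in ops
        s += op(a, b)
    end
    return s
end

# Hypothetical variant: assert that every op returns the input type T
function direct_eval_asserted(ops, a::T, b::T) where T
    s = zero(T)
    for op in ops
        s += op(a, b)::T   # checked at runtime; downstream code can assume T
    end
    return s
end

f_ops = Function[+, +, *, +, -, -, *, *]

Base.return_types(direct_eval, (Vector{Function}, Int, Int))           # inferred: Any
Base.return_types(direct_eval_asserted, (Vector{Function}, Int, Int))  # inferred: Int
```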
