Save dispatched function for later call

@rdeits This is exactly what I was looking for. I can’t believe this already exists. :smiley: One day I really have to give back to the community, this is just crazy…

@dlakelan the function is already compiled. The key is to call the appropriate one while avoiding the dynamic dispatch. This is hyper-optimisation… sort of. But dispatching is actually quite pricey, so this ends up being roughly a 100x speed-up in this simple case.

1 Like

I have one more observation to note.

If I compare speeds:


using FunctionWrappers: FunctionWrapper
using BenchmarkTools

function l(ops, a, b)
	for op in ops
		op(a,b)
	end
end

a=4
b=7
ops_s=[:ADD, :ADD, :MUL, :ADD, :SUBS, :SUBS, :MUL, :MUL]          # opcodes as Symbols
opp=[+, +, *, +, -, -, *, *]                                      # a Vector{Function} (abstract element type)
w_op=[FunctionWrapper{Int, Tuple{Int, Int}}(op) for op in opp]    # concretely typed wrappers around the same functions

@btime $l($opp, $a, $b)
@btime $l($w_op, $a, $b)

function h(ops, a, b)
	for op in ops
		if op == :ADD
			+(a, b)
		elseif op == :SUBS
			-(a, b)
		elseif op == :MUL
			*(a, b)
		elseif op == :DIV
			/(a, b)
		end
	end
end
@btime $h($ops_s, $a, $b)

The results:

   155.907 ns (0 allocations: 0 bytes)
   40.175 ns (0 allocations: 0 bytes)
   3.876 ns (0 allocations: 0 bytes)

I notice we do at least 2-3 extra if comparisons per call in h(…), and that cost scales roughly linearly with the number of operation types. Yet h is still quite a bit faster. Do we still lose some speed during the FunctionWrapper calls?

h(...) is a fully compilable function, so I understand why it is really fast… but shouldn’t l(...) with the wrapped version be fine too? Where do we lose speed with the FunctionWrapper version? Shouldn’t it be comparable in speed?

Yeah, some loss is unavoidable here. The problem is that a FunctionWrapper call can’t be inlined (just like a function pointer in C/C++), so it won’t be as fast as a normal function call which can be inlined. This is exactly the same problem that virtual methods have in C++.
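For anyone who wants to see that barrier concretely, here is a minimal sketch (the call_direct/call_wrapped names are just for illustration, assuming FunctionWrappers.jl is installed):

using FunctionWrappers: FunctionWrapper

add_direct(a, b) = a + b                                      # an ordinary method, inlinable at call sites
const add_wrapped = FunctionWrapper{Int, Tuple{Int, Int}}(+)  # stores a raw pointer to compiled code

call_direct(a, b)  = add_direct(a, b)    # compiles down to essentially a single integer add
call_wrapped(a, b) = add_wrapped(a, b)   # always an indirect call through the stored pointer

# @code_native call_direct(4, 7)   # inlined: basically one add instruction and a return
# @code_native call_wrapped(4, 7)  # contains a call through a loaded pointer, never inlined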

1 Like

It’s still unclear why you need this. At the beginning of this thread you said you are trying to design a fast computation graph (I guess a static one, since people use other names for a dynamic one). That’s why I said you should just compile your computation graph as a whole. Yes, you have already compiled each individual function, but you can still compile them together. For example, if you have already compiled f and you want to invoke f in a loop, then you can just compile the loop. This is the right way to avoid all the overhead, as long as you can get the full computation graph.
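To make that concrete, here is a minimal sketch of “compiling the loop”: build one specialized Julia function from the list of operations with metaprogramming. The compile_ops name and the accumulate-into-out shape are my own illustrative choices, not something from the thread.

# Build a single function whose body hardcodes every call, then let Julia compile it once.
function compile_ops(ops)
	body = Expr(:block)
	for op in ops
		push!(body.args, :(out += $op(a, b)))   # splice each function object directly into the AST
	end
	@eval function graph_fn(a, b)
		out = 0
		$body
		return out
	end
end

g = compile_ops([+, +, *, +, -, -, *, *])
g(4, 7)   # returns 111: every operation is a direct, inlinable call, no runtime dispatch
# Note: if g is called from inside another function right after being defined,
# Base.invokelatest may be needed because of world-age rules.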

However, the code sample you presented before is not a (static) computation graph. It’s a virtual machine (or interpreter) over a sequence of instructions. That’s a totally different thing, so we need to perform runtime dispatch (since it’s impossible to get the full code ahead of time). But it’s really expensive compared to a compiled version, since you have to decode the opcode, set up a call frame on the stack, and so on. We can do this by using function pointers in Julia. To avoid these costs, many virtual machines have specialized opcodes for frequently used operations, like boolean and arithmetic operations. If you interpret a loop this way you will be faster than dynamic dispatch, but you still lose a lot of performance. So people sometimes put a bytecode compiler in front of the interpreter: it compiles the code into a lower form (much like machine code) instead of directly walking an embedded representation. That code is much cheaper to interpret (but still much slower than statically compiled code).
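As a tiny illustration of the “specialized opcode” idea, here is a sketch that encodes the program as a vector of small integers and dispatches on them in a tight loop. The opcode values and the run_bytecode name are made up for the example.

const OP_ADD, OP_SUB, OP_MUL, OP_DIV = 0x01, 0x02, 0x03, 0x04

function run_bytecode(code::Vector{UInt8}, a, b)
	out = 0
	for op in code
		if op == OP_ADD
			out += a + b
		elseif op == OP_SUB
			out += a - b
		elseif op == OP_MUL
			out += a * b
		elseif op == OP_DIV
			out += a ÷ b   # integer division, to stay in Int
		end
	end
	out
end

run_bytecode(UInt8[0x01, 0x03, 0x02], 4, 7)   # (4+7) + (4*7) + (4-7) = 36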

It’s possible to come up with a hybrid of static compilation and dynamic interpretation, regardless of whether the graph is static or dynamic. That’s what Julia’s REPL and other profiling JIT compilers do under the hood: they interpret unimportant code and compile the “hot” code. But implementing this is not easy (it requires much more non-trivial Julia knowledge).

The core problem here is how you get the “computation graph” and how low-level it is. If you can statically get the graph and control how it is generated (which I believe is the case here), you should use the compiler. For example, if someone is designing a node-editor system like Blueprint in Blender, they should translate the graph to Julia’s AST (or lowered IR) and compile it before running. If it’s fully dynamic (though I really doubt that), then you have to use an interpreter.

2 Likes

To be honest, from a purely semantic point of view, what you’re asking for should be impossible (if it weren’t for the access to individual method instances that Julia’s reflection capabilities give you). The reason is pretty simple: Julia is a dynamic language, meaning that in theory/semantically Julia “does” a method lookup for each call. It’s just that, thanks to type inference and static compilation of inferred code (enabled by a limited eval), there really are compiled method instances to look at, even though semantically Julia still “does” a lookup. If Julia were a static language, you wouldn’t have the luxury of just putting a generic + in a Vector and calling it (well, not in any sane language with properly typed functions rather than bare function pointers! That’s a disaster waiting to happen) - you would have to specify which + method you want there and basically implement the dynamic dispatch yourself.

All of that is just a cherry on top of the dynamic-computation-graph arguments being made here as well, though.

1 Like

@rdeits thank you! What I had in mind is that I just call the already-dispatched function directly each time, so I only save the pointer of where to go and don’t waste time on dispatch. In the case of + and - arithmetic, the dispatch costs far more than the operation itself.

@anon56330260, yeah, you are right: this is effectively a virtual machine or interpreter that just does what the computation graph describes, which would be just pointers to functions and their inputs to call.
The way I get the computation graph is just random at the moment. I generate it from the 4 operations [+, -, *, /]; in time I will control the creation of the computation graph.

@Sukera thank you for the clarification. Yeah, I don’t know if there is a way to call a function essentially instantly, given that I already have the pointer to the compiled code ready in the list. I feel like this is really low level but… is it possible to just move the program counter to the compiled function with the inputs and call it… :smiley:

Do you guys think it is possible to call a function without any dynamic dispatch, by just assembling something like fn_ptr + [inputs] and then calling it…?

That’s exactly what FunctionWrapper is doing–it’s actually a function pointer under the hood, just like in C/C++. It just turns out that function pointers have a small (few nanosecond) overhead. That’s not because function pointers themselves are slow but simply because calling any function has a nonzero cost if that function is not inlined.
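If you want to see the raw mechanism without the package, a rough sketch using Julia’s own C-interop tools looks like this (FunctionWrappers.jl essentially does this for you behind a safer, nicer interface; the names here are just for the example):

add(a::Int, b::Int) = a + b

const add_ptr = @cfunction(add, Int, (Int, Int))   # pointer to compiled native code for `add`
result = ccall(add_ptr, Int, (Int, Int), 4, 7)     # indirect call through the pointer; result == 11
# The pointer is only valid within the current session, so it should never be serialized.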

1 Like

As someone working in computer security, I will now insert my horrified face:

ಠ_ಠ

If you’re willing to dig very deep into Julia internals and you’re not afraid of llvmcall and ccall - sure! :smiley: But at that point, why use Julia in the first place? llvmcall and ccall are very powerful, but they also mean you’re not going to be able to leverage a lot of things Julia has to offer.

Like I mentioned, if you want to avoid dynamic dispatch at all costs - Julia’s semantics kind of work against that (how the implementation works is a different matter).

@rdeits FunctionWrappers.jl talks to Julia internals on the user’s behalf, right? What happens when those assumptions/usages break, since the internals are free to change with any Julia version?

1 Like

So basically FunctionWrappers is built on a cfunction, which introduces a compilation barrier: for example, we can’t propagate constants into the wrapped function from its const call arguments.
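A small sketch of what that barrier means in practice (the w_add name and the two toy functions are mine, just for illustration):

using FunctionWrappers: FunctionWrapper

const w_add = FunctionWrapper{Int, Tuple{Int, Int}}(+)

direct()  = 2 + 3        # the compiler constant-folds this to `return 5`
wrapped() = w_add(2, 3)  # the opaque pointer call survives; no constant propagation through it

# @code_typed direct()   # body reduces to the literal 5
# @code_typed wrapped()  # still contains the wrapper call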

What I would also consider is using opaque closures; aren’t they pretty much the new native way of calling fixed functions? :slight_smile:
I hope they are less of an optimization barrier, but @Keno probably knows how useful they would be for this problem.
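For reference, a minimal opaque-closure sketch looks roughly like this (experimental API in Base.Experimental, Julia 1.7+, so the details may change between versions):

oc = Base.Experimental.@opaque (a::Int, b::Int) -> a + b

oc(4, 7)     # returns 11; the call target is fixed when the closure is created
typeof(oc)   # roughly Core.OpaqueClosure{Tuple{Int64, Int64}, Int64}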

What really caught my eye in your benchmarks is how a switch-like chain of branches can be so fast, ~10x faster. :open_mouth:

1 Like

@rdeits I just realised what you are talking about, sorry. But does inlining really yield a 10x speedup then? :o That just seems like nonsense… something is strange to me. How can a jump through a function pointer cause this big a slowdown… Damn :o

Yeah that 10x is damn crazy.

There are other people here who are far more qualified than I am on this topic, but basically: yes. When the operation is as simple as adding two integers (a single CPU instruction), then the relative cost of a function call can be large. Inlining is important…

For me it is only interesting because, as I measured in my case, the indirection costs 0.7-0.9 ns for an array and 3.5-5.5 ns for a dict. I know this is something different… a dict is much more complex, so how does this become 7 ns (in the case of a length-1 array, see Save dispatched function for later call - #22 by Marcell_Havlik)… it is just interesting (the relative difference is around 30x when longer arrays are used).

Reason #1 for this is that inlining allows other optimizations. For example, if your function is +, inlining might allow loop re-ordering, which can let the compiler replace scalar addition with vectorized addition instructions, which is a few times faster.
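A small benchmark sketch of that effect (the function names and setup are mine, just to illustrate the point):

using FunctionWrappers: FunctionWrapper
using BenchmarkTools

xs = rand(Int, 1_000)

function direct_sum(xs)          # `+` is inlined, so the loop can be SIMD-vectorized
	s = 0
	for x in xs
		s += x
	end
	s
end

const wrapped_add = FunctionWrapper{Int, Tuple{Int, Int}}(+)
function wrapped_sum(xs)         # each addition is an opaque pointer call, so no SIMD
	s = 0
	for x in xs
		s = wrapped_add(s, x)
	end
	s
end

@btime direct_sum($xs)
@btime wrapped_sum($xs)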

2 Likes

In this case neither of those (reordering, vectorized addition) is applied. I could also imagine some optimisation like branch prediction and pre-caching of the inlined code, but that shouldn’t make this much difference…

I really feel like there is something here we just don’t understand well enough yet.

I haven’t looked at the compiled LLVM or native code, but it’s bound to be a mix of a jump table (the speedup from the first to the second case: in the first the op(a,b) target has to be resolved at runtime, vs. jumping through a table but still having to make the call in the second) and inlined primitive operations (second to third case). The loop may even be completely unrolled and eliminated in the last case.

Do you know about @code_llvm and @code_native?

I was afraid to check @code_llvm on the FunctionWrappers version… :smiley: It wouldn’t tell me a lot, I guess. But I will check that too!

Not sure if I’m misunderstanding, but the second one can’t be inlined, as rdeits explained, and we think that causes the slowdown.

The loop can’t be unrolled in these scenarios because it would need higher-order optimization; also, the compiler doesn’t compile the input into the function, so that optimisation is impossible. Even with the latest static analyser in GCC it would still be hard; if the input were hardcoded at compile time it would of course do it (but then compilation takes significant time in that scenario). So I would rule that option out in Julia too.

The second one isn’t what I was talking about with inlining, I was talking about the third one.

  • The first one is type unstable with respect to what op is in each iteration, so it has to do a method-table lookup for each op.
  • The second one obviously can’t be inlined because of the opaque function pointer, which is what led me to think it was using a jump table instead. It still has to make the real call, though, so it’s slower than the third one.
  • The third one has all possible calls directly in the function, so they can be inlined, and based on op it jumps straight to the inlined code (a quick way to check each case is sketched below).
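Assuming the definitions from the benchmark above are in scope, a quick check of each case could look like this:

@code_warntype l(opp, a, b)    # op is a ::Function (abstract), so every call is a runtime dispatch
@code_warntype l(w_op, a, b)   # op is a concrete FunctionWrapper, but the call itself is an opaque pointer call
@code_warntype h(ops_s, a, b)  # op is a ::Symbol, and the branches inline +, -, *, /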

Why not? None of the invocations in the third function return anything observable outside of each iteration, so the compiler would be free to only do the last iteration of the loop (eliminating it in the process). Heck, since loops don’t have a return value, unlike other expressions in Julia, the compiler is free to notice that your function doesn’t compute anything once all those primitive operations are inlined and just replace it with a return nothing.

I suspect that’s what’s going on in the third version, since 3 ns for 3 additions, 3 multiplications and 2 subtractions is waaaay too fast (unless they’re done at the same time in an unrolled loop, since there’s no loop-carried dependency between iterations :upside_down_face:).

Sorry that I can’t explain it in more detail, but there is no loop unrolling in this scenario.

Also, I think if it just returned nothing it would be even faster; a return nothing would take less than 1 ns. If we increase the number of operations the time scales linearly with them, so I guess it really does the work. Why is it so fast… it comes to roughly 0.5 ns/op (3.876 ns / 8 ops), which sounds close to fair to me.

This really is maddening to me, because in the code you posted above, the third example does not have loop-carried dependencies as far as I can tell :smiley: That may be different in your real code, but it isn’t in what you posted here. Why shouldn’t the compiler be allowed to unroll the third function?

Because it isn’t doing any work:

code_llvm
julia> @code_llvm h(ops_s, 4,7)                                                                           
;  @ REPL[2]:1 within `h`                                                                                 
define void @julia_h_403({}* nonnull align 16 dereferenceable(40) %0, i64 signext %1, i64 signext %2) #0 {
top:                                                                                                      
;  @ REPL[2]:2 within `h`                                                                                 
; ┌ @ array.jl:809 within `iterate` @ array.jl:809                                                        
; │┌ @ array.jl:215 within `length`                                                                       
    %3 = bitcast {}* %0 to { i8*, i64, i16, i16, i32 }*                                                   
    %4 = getelementptr inbounds { i8*, i64, i16, i16, i32 }, { i8*, i64, i16, i16, i32 }* %3, i64 0, i32 1
    %5 = load i64, i64* %4, align 8                                                                       
; │└                                                                                                      
; │┌ @ int.jl:477 within `<` @ int.jl:470                                                                 
    %.not = icmp eq i64 %5, 0                                                                             
; │└                                                                                                      
   br i1 %.not, label %L53, label %L9                                                                     
                                                                                                          
L9:                                               ; preds = %top                                          
; │┌ @ array.jl:835 within `getindex`                                                                     
    %6 = bitcast {}* %0 to {}***                                                                          
    %7 = load {}**, {}*** %6, align 8                                                                     
    %8 = load {}*, {}** %7, align 8                                                                       
    %.not13 = icmp eq {}* %8, null                                                                        
    br i1 %.not13, label %fail, label %L18                                                                
                                                                                                          
L18:                                              ; preds = %L9                                           
; └└                                                                                                      
;  @ REPL[2]:10 within `h`                                                                                
; ┌ @ array.jl:809 within `iterate`                                                                       
; │┌ @ int.jl:477 within `<` @ int.jl:470                                                                 
    %.not1417 = icmp ugt i64 %5, 1                                                                        
; │└                                                                                                      
   br i1 %.not1417, label %L42, label %L53                                                                
                                                                                                          
L20:                                              ; preds = %L42                                          
; │┌ @ int.jl:87 within `+`                                                                               
    %9 = add nuw i64 %value_phi418, 1                                                                     
; │└                                                                                                      
; │┌ @ int.jl:477 within `<` @ int.jl:470                                                                 
    %exitcond.not = icmp eq i64 %value_phi418, %5                                                         
; │└                                                                                                      
   br i1 %exitcond.not, label %L53, label %L42                                                            
                                                                                                          
L42:                                              ; preds = %L20, %L18                                    
   %10 = phi i64 [ %value_phi418, %L20 ], [ 1, %L18 ]                                                     
   %value_phi418 = phi i64 [ %9, %L20 ], [ 2, %L18 ]                                                      
; │┌ @ array.jl:835 within `getindex`                                                                     
    %11 = getelementptr inbounds {}*, {}** %7, i64 %10                                                    
    %12 = load {}*, {}** %11, align 8                                                                     
    %.not15 = icmp eq {}* %12, null                                                                       
    br i1 %.not15, label %fail5, label %L20                                                               
                                                                                                          
L53:                                              ; preds = %L20, %L18, %top                              
; └└                                                                                                      
  ret void                                                                                                
                                                                                                          
fail:                                             ; preds = %L9                                           
;  @ REPL[2]:2 within `h`                                                                                 
; ┌ @ array.jl:809 within `iterate` @ array.jl:809                                                        
; │┌ @ array.jl:835 within `getindex`                                                                     
    call void @jl_throw({}* inttoptr (i64 140332589220768 to {}*))                                        
    unreachable                                                                                           
                                                                                                          
fail5:                                            ; preds = %L42                                          
; └└                                                                                                      
;  @ REPL[2]:10 within `h`                                                                                
; ┌ @ array.jl:809 within `iterate`                                                                       
; │┌ @ array.jl:835 within `getindex`                                                                     
    call void @jl_throw({}* inttoptr (i64 140332589220768 to {}*))                                        
    unreachable                                                                                           
; └└                                                                                                      
}                                                                                                         

L20 to L42 is the main loop body, which loads an element, compares it with null (if my LLVM reading isn’t too rusty) and then jumps right back to the start of the loop or exits.

code_native
julia> @code_native h(ops_s, 4,7)                  
        .text                                      
; ┌ @ REPL[2]:1 within `h`                         
        subq    $8, %rsp                           
; │ @ REPL[2]:2 within `h`                         
; │┌ @ array.jl:809 within `iterate` @ array.jl:809
; ││┌ @ array.jl:215 within `length`               
        movq    8(%rdi), %rax                      
; ││└                                              
; ││┌ @ int.jl:477 within `<` @ int.jl:470         
        testq   %rax, %rax                         
; ││└                                              
        je      L63                                
; ││┌ @ array.jl:835 within `getindex`             
        movq    (%rdi), %rcx                       
        cmpq    $0, (%rcx)                         
        je      L87                                
; │└└                                              
; │ @ REPL[2]:10 within `h`                        
; │┌ @ array.jl:809 within `iterate`               
; ││┌ @ int.jl:477 within `<` @ int.jl:470         
        cmpq    $2, %rax                           
; ││└                                              
        jb      L63                                
; │└                                               
; │┌ @ array.jl within `iterate`                   
        movl    $1, %edx                           
        nopw    %cs:(%rax,%rax)                    
; │└                                               
; │┌ @ array.jl:809 within `iterate`               
; ││┌ @ array.jl:835 within `getindex`             
L48:                                               
        cmpq    $0, (%rcx,%rdx,8)                  
        je      L65                                
; ││└                                              
; ││┌ @ int.jl:477 within `<` @ int.jl:470         
        incq    %rdx                               
        cmpq    %rdx, %rax                         
; ││└                                              
        jne     L48                                
; │└                                               
L63:                                               
        popq    %rax                               
        retq                                       
; │┌ @ array.jl:809 within `iterate`               
; ││┌ @ array.jl:835 within `getindex`             
L65:                                               
        movabsq $jl_throw, %rax                    
        movabsq $jl_system_image_data, %rdi        
        callq   *%rax                              
; │└└                                              
; │ @ REPL[2]:2 within `h`                         
; │┌ @ array.jl:809 within `iterate` @ array.jl:809
; ││┌ @ array.jl:835 within `getindex`             
L87:                                               
        movabsq $jl_throw, %rax                    
        movabsq $jl_system_image_data, %rdi        
        callq   *%rax                              
        nopl    (%rax)                             
; └└└                                              

The same goes for the native code, except here the loop body is labelled L48: it compares with 0, exits if equal, otherwise increments a counter. Probably because the size of the input array is not fixed/known to the compiler, otherwise it would elide that as well.

All this is possible because the result of the calculations isn’t stored anywhere. The compiler knows that the functions are side-effect free, so it simply removes calls that quite literally don’t do anything.

1 Like

Yeah, sorry for not explaining the unrolling thing, I didn’t have time at midnight.

This is crazy, it just dropped the inner part of the loop, damn…

Then, modifying the whole code to be 100% unremovable:

using FunctionWrappers: FunctionWrapper
using BenchmarkTools

function l(ops, a, b)
	out = 0   # accumulate the results so the work can’t be optimized away
	for op in ops
		out += op(a,b)
	end
	out
end
a=4
b=7
ops_s=[:ADD, :ADD, :MUL, :ADD, :SUBS, :SUBS, :MUL, :MUL]
opp=[+, +, *, +, -, -, *, *]
w_op=[FunctionWrapper{Int, Tuple{Int, Int}}(op) for op in opp]

@btime $l($opp, $a, $b)
@btime $l($w_op, $a, $b)

function h(ops, a, b)
	out = 0
	for op in ops
		if op == :ADD
			out += +(a, b)
		elseif op == :SUBS
			out += -(a, b)
		elseif op == :MUL
			out += *(a, b)
		elseif op == :DIV
			out += /(a, b)
		end
	end
	out
end
@btime $h($ops_s, $a, $b)

Now l and h both give 111 as the result. The timings:

  390.603 ns (0 allocations: 0 bytes)
  84.689 ns (0 allocations: 0 bytes)
  19.133 ns (0 allocations: 0 bytes)

I guess that due to the += it does 1 or 2 more operations per iteration.

I guess I should now check what this 4x difference is, as you did with code_native and code_llvm.