Optimising function for broadcast

function f( a, b )
     c = 2a
     b + c
end

f.(   1,     1:100  )

I believe c = 2a will be calculated 100 times (once per iteration of the broadcast).
Is there a way to have c evaluated once and cached, without rewriting the whole function?
More generally, can the compiler change the way, or the order in which, it calculates parts of a function based on which arguments are vectorised / broadcast?

If I’m reading @code_llvm correctly, it looks like LLVM hoists the c = 2a calculation out of the loop for you in this particular case.

Can you show me the bit that says this? I find it very difficult to read.

Here is the LLVM IR I get:

julia> @code_llvm debuginfo=:none f.(   1,     1:100  )
; Function Attrs: uwtable
define nonnull {}* @"julia_##dotfunction#315#4_257"(i64 signext %0, [2 x i64]* nocapture nonnull readonly align 8 dereferenceable(16) %1) #0 {
top:
  %2 = alloca [1 x [1 x i64]], align 8
  %3 = alloca [1 x [1 x i64]], align 8
  %4 = getelementptr inbounds [2 x i64], [2 x i64]* %1, i64 0, i64 0
  %5 = getelementptr inbounds [2 x i64], [2 x i64]* %1, i64 0, i64 1
  %6 = load i64, i64* %5, align 8
  %7 = load i64, i64* %4, align 8
  %8 = sub i64 %6, %7
  %9 = add i64 %8, 1
  %10 = icmp ult i64 %8, 9223372036854775807
  %11 = select i1 %10, i64 %9, i64 0
  %12 = getelementptr inbounds [1 x [1 x i64]], [1 x [1 x i64]]* %2, i64 0, i64 0, i64 0
  store i64 %11, i64* %12, align 8
  %13 = call nonnull {}* inttoptr (i64 1698961392 to {}* ({}*, i64)*)({}* inttoptr (i64 271088912 to {}*), i64 %11)
  %14 = bitcast {}* %13 to { i8*, i64, i16, i16, i32 }*
  %15 = getelementptr inbounds { i8*, i64, i16, i16, i32 }, { i8*, i64, i16, i16, i32 }* %14, i64 0, i32 1
  %16 = load i64, i64* %15, align 8
  %.not.not = icmp eq i64 %16, %11
  br i1 %.not.not, label %L23, label %L101

L23:                                              ; preds = %top
  %.not18 = icmp sgt i64 %11, 0
  %or.cond = select i1 %10, i1 %.not18, i1 false
  br i1 %or.cond, label %L43.lr.ph, label %L112

L43.lr.ph:                                        ; preds = %L23
  %.not17 = icmp eq i64 %6, %7
  %17 = shl i64 %0, 1
  %18 = bitcast {}* %13 to i64**
  %19 = load i64*, i64** %18, align 8
  br i1 %.not17, label %L43.lr.ph.split.us, label %L43.preheader

L43.preheader:                                    ; preds = %L43.lr.ph
  %min.iters.check = icmp ult i64 %11, 16
  br i1 %min.iters.check, label %L43, label %vector.ph

vector.ph:                                        ; preds = %L43.preheader
  %n.vec = and i64 %11, -16
  %broadcast.splatinsert = insertelement <4 x i64> poison, i64 %17, i32 0
  %broadcast.splatinsert31 = insertelement <4 x i64> poison, i64 %7, i32 0
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %vec.ind = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %vector.ph ], [ %vec.ind.next, %vector.body ]
  %20 = add <4 x i64> %broadcast.splatinsert, <i64 4, i64 poison, i64 poison, i64 poison>
  %21 = add <4 x i64> %broadcast.splatinsert, <i64 8, i64 poison, i64 poison, i64 poison>
  %22 = add <4 x i64> %broadcast.splatinsert31, %broadcast.splatinsert
  %23 = shufflevector <4 x i64> %22, <4 x i64> poison, <4 x i32> zeroinitializer
  %24 = add <4 x i64> %23, %vec.ind
  %25 = add <4 x i64> %broadcast.splatinsert31, %20
  %26 = shufflevector <4 x i64> %25, <4 x i64> poison, <4 x i32> zeroinitializer
  %27 = add <4 x i64> %26, %vec.ind
  %28 = add <4 x i64> %broadcast.splatinsert31, %21
  %29 = shufflevector <4 x i64> %28, <4 x i64> poison, <4 x i32> zeroinitializer
  %30 = add <4 x i64> %29, %vec.ind
  %31 = add <4 x i64> %22, <i64 12, i64 poison, i64 poison, i64 poison>
  %32 = shufflevector <4 x i64> %31, <4 x i64> poison, <4 x i32> zeroinitializer
  %33 = add <4 x i64> %32, %vec.ind
  %34 = getelementptr inbounds i64, i64* %19, i64 %index
  %35 = bitcast i64* %34 to <4 x i64>*
  store <4 x i64> %24, <4 x i64>* %35, align 8
  %36 = getelementptr inbounds i64, i64* %34, i64 4
  %37 = bitcast i64* %36 to <4 x i64>*
  store <4 x i64> %27, <4 x i64>* %37, align 8
  %38 = getelementptr inbounds i64, i64* %34, i64 8
  %39 = bitcast i64* %38 to <4 x i64>*
  store <4 x i64> %30, <4 x i64>* %39, align 8
  %40 = getelementptr inbounds i64, i64* %34, i64 12
  %41 = bitcast i64* %40 to <4 x i64>*
  store <4 x i64> %33, <4 x i64>* %41, align 8
  %index.next = add nuw i64 %index, 16
  %vec.ind.next = add <4 x i64> %vec.ind, <i64 16, i64 16, i64 16, i64 16>
  %42 = icmp eq i64 %index.next, %n.vec
  br i1 %42, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %cmp.n = icmp eq i64 %11, %n.vec
  br i1 %cmp.n, label %L112, label %L43

L43.lr.ph.split.us:                               ; preds = %L43.lr.ph
  %43 = add i64 %17, %6
  %min.iters.check42 = icmp ult i64 %11, 16
  br i1 %min.iters.check42, label %L43.us, label %vector.ph43

vector.ph43:                                      ; preds = %L43.lr.ph.split.us
  %n.vec45 = and i64 %11, -16
  %broadcast.splatinsert50 = insertelement <4 x i64> poison, i64 %43, i32 0
  %broadcast.splat51 = shufflevector <4 x i64> %broadcast.splatinsert50, <4 x i64> poison, <4 x i32> zeroinitializer
  br label %vector.body41

vector.body41:                                    ; preds = %vector.body41, %vector.ph43
  %index46 = phi i64 [ 0, %vector.ph43 ], [ %index.next47, %vector.body41 ]
  %44 = getelementptr inbounds i64, i64* %19, i64 %index46
  %45 = bitcast i64* %44 to <4 x i64>*
  store <4 x i64> %broadcast.splat51, <4 x i64>* %45, align 8
  %46 = getelementptr inbounds i64, i64* %44, i64 4
  %47 = bitcast i64* %46 to <4 x i64>*
  store <4 x i64> %broadcast.splat51, <4 x i64>* %47, align 8
  %48 = getelementptr inbounds i64, i64* %44, i64 8
  %49 = bitcast i64* %48 to <4 x i64>*
  store <4 x i64> %broadcast.splat51, <4 x i64>* %49, align 8
  %50 = getelementptr inbounds i64, i64* %44, i64 12
  %51 = bitcast i64* %50 to <4 x i64>*
  store <4 x i64> %broadcast.splat51, <4 x i64>* %51, align 8
  %index.next47 = add nuw i64 %index46, 16
  %52 = icmp eq i64 %index.next47, %n.vec45
  br i1 %52, label %middle.block39, label %vector.body41

middle.block39:                                   ; preds = %vector.body41
  %cmp.n49 = icmp eq i64 %11, %n.vec45
  br i1 %cmp.n49, label %L112, label %L43.us

L43.us:                                           ; preds = %L43.us, %middle.block39, %L43.lr.ph.split.us
  %value_phi119.us = phi i64 [ %54, %L43.us ], [ %n.vec45, %middle.block39 ], [ 0, %L43.lr.ph.split.us ]
  %53 = getelementptr inbounds i64, i64* %19, i64 %value_phi119.us
  store i64 %43, i64* %53, align 8
  %54 = add nuw nsw i64 %value_phi119.us, 1
  %exitcond20.not = icmp eq i64 %54, %11
  br i1 %exitcond20.not, label %L112, label %L43.us

L43:                                              ; preds = %L43, %middle.block, %L43.preheader
  %value_phi119 = phi i64 [ %58, %L43 ], [ %n.vec, %middle.block ], [ 0, %L43.preheader ]
  %55 = add i64 %value_phi119, %17
  %56 = add i64 %55, %7
  %57 = getelementptr inbounds i64, i64* %19, i64 %value_phi119
  store i64 %56, i64* %57, align 8
  %58 = add nuw nsw i64 %value_phi119, 1
  %exitcond.not = icmp eq i64 %58, %11
  br i1 %exitcond.not, label %L112, label %L43

L101:                                             ; preds = %top
  %59 = getelementptr inbounds [1 x [1 x i64]], [1 x [1 x i64]]* %3, i64 0, i64 0, i64 0
  store i64 %16, i64* %59, align 8
  %60 = call nonnull {}* @j_throwdm_262([1 x [1 x i64]]* nocapture readonly %3, [1 x [1 x i64]]* nocapture readonly %2) #0
  call void @llvm.trap()
  unreachable

L112:                                             ; preds = %L43, %L43.us, %middle.block39, %middle.block, %L23
  ret {}* %13
}

The c = 2a part is %17 = shl i64 %0, 1: the multiply by two has been reduced to a left shift by one bit. Note that it sits in the L43.lr.ph block, the loop preheader, so it is computed once rather than on every iteration of the loop bodies below.

If we change it to c = 4a, then it becomes %17 = shl i64 %0, 2: the multiply by four shifts the bits to the left by two.

Here is the x86_64 assembly:

julia> @code_native debuginfo=:none f.(1, 1:100)
        .text
        .file   "##dotfunction#326#15"
        .section        .rodata.cst32,"aM",@progbits,32
        .p2align        5                               # -- Begin function julia_##dotfunction#326#15_344
.LCPI0_0:
        .quad   0                               # 0x0
        .quad   1                               # 0x1
        .quad   2                               # 0x2
        .quad   3                               # 0x3
        .section        .rodata.cst16,"aM",@progbits,16
        .p2align        4
.LCPI0_1:
        .quad   4                               # 0x4
        .quad   4                               # 0x4
.LCPI0_2:
        .quad   8                               # 0x8
        .quad   8                               # 0x8
.LCPI0_3:
        .quad   12                              # 0xc
        .quad   12                              # 0xc
        .section        .rodata.cst8,"aM",@progbits,8
        .p2align        3
.LCPI0_4:
        .quad   16                              # 0x10
        .text
        .globl  "julia_##dotfunction#326#15_344"
        .p2align        4, 0x90
        .type   "julia_##dotfunction#326#15_344",@function
"julia_##dotfunction#326#15_344":       # @"julia_##dotfunction#326#15_344"
        .cfi_startproc
# %bb.0:                                # %top
        pushq   %rbp
        .cfi_def_cfa_offset 16
        .cfi_offset %rbp, -16
        movq    %rsp, %rbp
        .cfi_def_cfa_register %rbp
        pushq   %r15
        pushq   %r14
        pushq   %r12
        pushq   %rsi
        pushq   %rdi
        pushq   %rbx
        subq    $48, %rsp
        .cfi_offset %rbx, -64
        .cfi_offset %rdi, -56
        .cfi_offset %rsi, -48
        .cfi_offset %r12, -40
        .cfi_offset %r14, -32
        .cfi_offset %r15, -24
        movq    %rcx, %r15
        movq    (%rdx), %r12
        movq    8(%rdx), %r14
        movq    %r14, %rdi
        subq    %r12, %rdi
        leaq    1(%rdi), %rax
        xorl    %esi, %esi
        movabsq $9223372036854775807, %rcx      # imm = 0x7FFFFFFFFFFFFFFF
        cmpq    %rcx, %rdi
        cmovbq  %rax, %rsi
        movq    %rsi, -64(%rbp)
        movl    $1698961392, %eax               # imm = 0x654417F0
        movl    $271088912, %ecx                # imm = 0x10287D10
        movq    %rsi, %rdx
        callq   *%rax
        movq    8(%rax), %rcx
        cmpq    %rsi, %rcx
        jne     .LBB0_18
# %bb.1:                                # %L23
        movabsq $9223372036854775806, %rcx      # imm = 0x7FFFFFFFFFFFFFFE
        cmpq    %rcx, %rdi
        ja      .LBB0_17
# %bb.2:                                # %L23
        testq   %rsi, %rsi
        jle     .LBB0_17
# %bb.3:                                # %L43.lr.ph
        addq    %r15, %r15
        movq    (%rax), %rcx
        cmpq    %r12, %r14
        jne     .LBB0_4
# %bb.11:                               # %L43.lr.ph.split.us
        addq    %r14, %r15
        cmpq    $16, %rsi
        jae     .LBB0_13
# %bb.12:
        xorl    %edx, %edx
        jmp     .LBB0_16
.LBB0_4:                                # %L43.preheader
        cmpq    $16, %rsi
        jae     .LBB0_6
# %bb.5:
        xorl    %edx, %edx
        jmp     .LBB0_9
.LBB0_13:                               # %vector.ph43
        movq    %rsi, %rdx
        andq    $-16, %rdx
        vpbroadcastq    %r15, %ymm0
        xorl    %ebx, %ebx
        .p2align        4, 0x90
.LBB0_14:                               # %vector.body41
                                        # =>This Inner Loop Header: Depth=1
        vmovdqu %ymm0, (%rcx,%rbx,8)
        vmovdqu %ymm0, 32(%rcx,%rbx,8)
        vmovdqu %ymm0, 64(%rcx,%rbx,8)
        vmovdqu %ymm0, 96(%rcx,%rbx,8)
        addq    $16, %rbx
        cmpq    %rbx, %rdx
        jne     .LBB0_14
# %bb.15:                               # %middle.block39
        cmpq    %rdx, %rsi
        je      .LBB0_17
        .p2align        4, 0x90
.LBB0_16:                               # %L43.us
                                        # =>This Inner Loop Header: Depth=1
        movq    %r15, (%rcx,%rdx,8)
        incq    %rdx
        cmpq    %rdx, %rsi
        jne     .LBB0_16
        jmp     .LBB0_17
.LBB0_6:                                # %vector.ph
        movq    %rsi, %rdx
        andq    $-16, %rdx
        vmovq   %r15, %xmm1
        vmovq   %r12, %xmm2
        movabsq $.LCPI0_0, %rdi
        vmovdqa (%rdi), %ymm0
        vpaddq  %ymm1, %ymm2, %ymm4
        xorl    %edi, %edi
        vpbroadcastq    %xmm4, %ymm1
        movabsq $.LCPI0_1, %rbx
        vpaddq  (%rbx), %xmm4, %xmm2
        vpbroadcastq    %xmm2, %ymm2
        movabsq $.LCPI0_2, %rbx
        vpaddq  (%rbx), %xmm4, %xmm3
        vpbroadcastq    %xmm3, %ymm3
        movabsq $.LCPI0_3, %rbx
        vpaddq  (%rbx), %xmm4, %xmm4
        vpbroadcastq    %xmm4, %ymm4
        movabsq $.LCPI0_4, %rbx
        vpbroadcastq    (%rbx), %ymm5
        .p2align        4, 0x90
.LBB0_7:                                # %vector.body
                                        # =>This Inner Loop Header: Depth=1
        vpaddq  %ymm0, %ymm1, %ymm16
        vpaddq  %ymm0, %ymm2, %ymm17
        vpaddq  %ymm0, %ymm3, %ymm18
        vmovdqu64       %ymm16, (%rcx,%rdi,8)
        vmovdqu64       %ymm17, 32(%rcx,%rdi,8)
        vmovdqu64       %ymm18, 64(%rcx,%rdi,8)
        vpaddq  %ymm0, %ymm4, %ymm16
        vmovdqu64       %ymm16, 96(%rcx,%rdi,8)
        addq    $16, %rdi
        vpaddq  %ymm5, %ymm0, %ymm0
        cmpq    %rdi, %rdx
        jne     .LBB0_7
# %bb.8:                                # %middle.block
        cmpq    %rdx, %rsi
        je      .LBB0_17
.LBB0_9:                                # %L43.preheader58
        addq    %r15, %r12
        .p2align        4, 0x90
.LBB0_10:                               # %L43
                                        # =>This Inner Loop Header: Depth=1
        leaq    (%r12,%rdx), %rbx
        movq    %rbx, (%rcx,%rdx,8)
        incq    %rdx
        cmpq    %rdx, %rsi
        jne     .LBB0_10
.LBB0_17:                               # %L112
        addq    $48, %rsp
        popq    %rbx
        popq    %rdi
        popq    %rsi
        popq    %r12
        popq    %r14
        popq    %r15
        popq    %rbp
        vzeroupper
        retq
.LBB0_18:                               # %L101
        movq    %rcx, -56(%rbp)
        movabsq $j_throwdm_349, %rax
        leaq    -56(%rbp), %rcx
        leaq    -64(%rbp), %rdx
        callq   *%rax
        ud2
.Lfunc_end0:
        .size   "julia_##dotfunction#326#15_344", .Lfunc_end0-"julia_##dotfunction#326#15_344"
        .cfi_endproc
                                        # -- End function
        .section        ".note.GNU-stack","",@progbits

The first argument a is loaded from %rcx into %r15 and then added to itself. This happens in the %L43.lr.ph block, before either inner loop, so the doubling is done only once here too.

        movq    %rcx, %r15
        ...
        addq    %r15, %r15

Thanks guys.
Here is a better example: value_fast.() is about 20 times faster than value_basic.().
I presume this is because interp_fast does its first 6 lines once per deal, so the only part that runs per deal and per element of ys is the final line of the interpolation, w1*y[i1] + w2*y[i2], whereas value_basic runs linear_interpolation for every deal and every element of ys.

using Interpolations

x = [ [0 1 2 7 14]./365   [1 2 3 6 9]./12    1 2 3 4 5 7  (10:5:50)'] |> vec
ys =  rand(length(x),1000)  |> eachcol |> collect

function interp_fast( T,  x,  y )
    i1 = searchsortedlast(x,T)
    i2 = searchsortedfirst(x,T)
    x1 = x[i1]
    x2 = x[i2]  
    w2 = (T - x1)/(x2-x1)
    w1 = 1-w2
    w1*y[i1] + w2*y[i2]
end 

struct Deal
    N::Float64
    T::Float64
end

n       = 10000
deals   = Deal.( 1000rand(n)', 50rand(n)' )

value_basic(d::Deal, x, y )  =  d.N * linear_interpolation(x,y)(d.T)
value_fast( d::Deal, x, y )  =  d.N * interp_fast(d.T,x,y)

value_basic.( deals, (x,), ys)      
value_fast.(  deals, (x,), ys)

Actually, no. I think linear_interpolation is just a slower function; it probably does a lot of input checking.
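
One way to check that presumption is to time the two broadcasts and a single call of each building block (a rough sketch with BenchmarkTools; no numbers claimed here):

using BenchmarkTools

@btime value_basic.(deals, (x,), ys);                # builds a fresh interpolant per element
@btime value_fast.(deals, (x,), ys);                 # two sorted searches and a lerp per element

@btime linear_interpolation(x, ys[1])(deals[1].T);   # one construction plus one evaluation
@btime interp_fast(deals[1].T, x, ys[1]);            # one plain interpolation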

In any case, at best LLVM will only be able to hoist simple calculations (ones it can prove are pure) out of loops. If you want to cache more complicated calculations in general, your best bet is to refactor into a lower-level version of the function that takes those precomputed values as arguments. For example, instead of:

function foo(x, y)
     z = expensive(x)
     return bar(x,y,z)
end
foo.(xscalar, yarray)

do

zscalar = expensive(xscalar)
bar.(xscalar, yarray, zscalar)
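
Applied to the original example, that refactoring might look like this (a minimal sketch; g is just a hypothetical name for the per-element part of f):

g(b, c) = b + c
c = 2 * 1            # the former c = 2a, computed once outside the broadcast
g.(1:100, c)         # same result as f.(1, 1:100)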

More generally, you will get the greatest possible flexibility if you write your own loops. Loops are fast in Julia.
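
For the toy example above, a hand-written loop might look like this (again just a sketch); the hoisting is then explicit rather than something you hope the optimizer does for you:

function f_loop(a, bs)
    c = 2a                     # computed once
    out = similar(bs)
    for i in eachindex(bs)
        out[i] = bs[i] + c     # only per-element work inside the loop
    end
    return out
end

f_loop(1, 1:100)               # same result as f.(1, 1:100)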

Note that linear_interpolation is also trying to handle extrapolation and is setting up to do multiple interpolations.

help?> linear_interpolation
search: linear_interpolation

  etp = linear_interpolation(knots, A; extrapolation_bc=Throw())

  A shorthand for one of the following.

    •  extrapolate(scale(interpolate(A, BSpline(Linear())), knots), extrapolation_bc)

    •  extrapolate(interpolate(knots, A, Gridded(Linear())), extrapolation_bc),

  depending on whether knots are ranges or vectors.

  Consider using interpolate, scale, or extrapolate explicitly as needed rather than using this convenience
  constructor. Performance will improve without scaling or extrapolation.

To review basic usage for a single interpolator, I would do

julia> itp = interpolate((x,), ys[1], Gridded(Linear()));

julia> deals[1].N * itp(deals[1].T)
593.9842288550919

In your case, you only need to create 1,000 interpolators (one per y in ys), but you are creating 10,000,000 of them (one for every combination of deal and y).

julia> value_faster(d::Deal, itp) = d.N * itp(d.T)
value_faster (generic function with 2 methods)

julia> itps = [interpolate((x,), y, Gridded(Linear())) for y in ys];

julia> @time value_faster.(deals, itps);
  0.253503 seconds (163.40 k allocations: 84.118 MiB, 4.32% gc time, 21.24% compilation time: 100% of which was recompilation)

julia> @time value_faster.(deals, itps);
  0.200224 seconds (4 allocations: 76.294 MiB)

julia> @time value_fast.(deals, (x,), ys);
  0.249927 seconds (5 allocations: 76.294 MiB)

julia> @btime value_faster.(deals, itps);
  168.242 ms (4 allocations: 76.29 MiB)

julia> @btime value_fast.(deals, (x,), ys);
  220.430 ms (5 allocations: 76.29 MiB)

julia> value_fast.(deals, (x,), ys) == value_faster.(deals, itps)
true

You could argue that’s not quite a fair comparison since I made the itps earlier. Here is the combined function and benchmark.

julia> function value_faster_broadcasted(deals, x, ys)
           itps = [interpolate((x,), y, Gridded(Linear())) for y in ys]
           value_faster.(deals, itps)
       end
value_faster_broadcasted (generic function with 1 method)

julia> @btime value_faster_broadcasted(deals, x, ys);
  170.541 ms (8003 allocations: 77.65 MiB)

Also, I’m the maintainer of Interpolations.jl. If you have any complaints or questions, please let me know.

Thanks Mark. Yes, that is faster.
In reality I have many deal types, each with its own valuation function.
Under your method I’d need a pre_value() and a value() for each deal type.
Maybe this is the only way to ensure efficiency.
But maybe, as the compiler is continually improved, it will become able to reliably hoist parts of a function out of the broadcast loop.
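
Concretely, I suppose the pattern would look something like this per deal type (a rough sketch; the deal types here are made up for illustration):

struct Swap;    N::Float64; T::Float64;             end
struct Forward; N::Float64; T::Float64; K::Float64; end

pre_value(x, y) = interpolate((x,), y, Gridded(Linear()))   # the expensive part, built once per y

value(d::Swap,    itp) = d.N * itp(d.T)
value(d::Forward, itp) = d.N * (itp(d.T) - d.K)

# itps = pre_value.((x,), ys)   # 1,000 interpolants, built once
# value.(deals, itps)           # broadcast only the cheap per-element work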

Thanks for the heads-up re Interpolations.jl. I had been using Dierckx, but Interpolations has three times as many stars, so I have switched.