2D nested loop gives a big performance hit in an interpolation routine when it should not

I have written a macro that generates fast trilinear interpolation functions for 2x2x2 upsampling with threading. The code runs fine and is on average an order of magnitude faster than Interpolations.jl (which I am probably not using in its fastest settings), but I notice an enormous penalty (~50%) if I include one additional 2-dimensional loop that handles the top boundary (boolean switch, the last argument in @upsample_lin_fast("222", 2, 2, 2, 0, 0, 0.5, true)), and @btime shows a large allocation overhead that is not there if I only run the 3D part (switch to false). The generated code is rather straightforward and is printed to the screen (and is probably much easier to read than the macro). What am I overlooking? Logically, the additional computations should be tiny compared to the total computational cost. You can test the performance difference by switching between true and false in the statement above.
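To make the comparison concrete, the two variants can also be timed side by side. A minimal sketch, assuming made-up suffixes ("222_top" and "222_notop") so that both generated functions can coexist; the arrays and index variables are the ones defined in the script below:

# Hypothetical comparison: generate both variants under different names
# (suffixes invented for this sketch) and benchmark them on the same data.
@upsample_lin_fast("222_top", 2, 2, 2, 0, 0, 0.5, true)
@upsample_lin_fast("222_notop", 2, 2, 2, 0, 0, 0.5, false)

@btime upsample_lin_222_top!(
    w_hi_fast, w_lo, itot_lo, jtot_lo, ktot_lo,
    is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)
@btime upsample_lin_222_notop!(
    w_hi_fast, w_lo, itot_lo, jtot_lo, ktot_lo,
    is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)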

## Load the packages.
using LoopVectorization
using Interpolations
using Statistics
using BenchmarkTools
using PyPlot
using Printf


## Define the upsample functions.
function make_inner_loop!(ex_list, ii, jj, kk, ifac, jfac, kfac, ioff, joff, koff, is_top)
    push!(ex_list, :(i_hi = $ifac*i+$ii+is_hi), :(j_hi = $jfac*j+$jj+js_hi))
    if is_top
        push!(ex_list, :(k_hi = ke_hi))
    else
        push!(ex_list, :(k_hi = $kfac*k+$kk+ks_hi))
    end

    i_pos = -1/2 + ioff + (1/2 - ioff + ii)/ifac
    j_pos = -1/2 + joff + (1/2 - joff + jj)/jfac
    k_pos = -1/2 + koff + (1/2 - koff + kk)/kfac

    i_lo_off = floor(Int, i_pos) + is_lo
    j_lo_off = floor(Int, j_pos) + js_lo
    k_lo_off = floor(Int, k_pos) + ks_lo

    if is_top
        push!(ex_list, :(i_lo = i + $i_lo_off), :(j_lo = j + $j_lo_off), :(k_lo = ke_lo))
    else
        push!(ex_list, :(i_lo = i + $i_lo_off), :(j_lo = j + $j_lo_off), :(k_lo = k + $k_lo_off))
    end

    fi = mod(i_pos, 1)
    fj = mod(j_pos, 1)
    fk = mod(k_pos, 1)

    coef = zeros(2, 2, 2)
    coef[1, 1, 1] = (((1-fi)*fj+fi-1)*fk+(fi-1)*fj-fi+1)
    coef[2, 1, 1] = ((fi*fj-fi)*fk-fi*fj+fi)
    coef[1, 2, 1] = ((fi-1)*fj*fk+(1-fi)*fj)
    coef[2, 2, 1] = (fi*fj-fi*fj*fk)
    coef[1, 1, 2] = ((fi-1)*fj-fi+1)*fk
    coef[2, 1, 2] = (fi-fi*fj)*fk
    coef[1, 2, 2] = (1-fi)*fj*fk
    coef[2, 2, 2] = fi*fj*fk

    kc_end = is_top ? 1 : 2

    rhs_list = []
    for kc in 1:kc_end, jc in 1:2, ic in 1:2
        if !(coef[ic, jc, kc] ≈ 0)
            push!(rhs_list, :( $(coef[ic, jc, kc])*lo[i_lo+$(ic-1), j_lo+$(jc-1), k_lo+$(kc-1)] ))
        end
    end

    rhs = Expr(:call, :+, rhs_list...)

    push!(ex_list, :(hi[i_hi, j_hi, k_hi] = $rhs))
end

macro upsample_lin_fast(suffix, ifac, jfac, kfac, ioff, joff, koff, add_top)
    ex_inner = []
    for kk in 0:kfac-1, jj in 0:jfac-1, ii in 0:ifac-1
        make_inner_loop!(ex_inner, ii, jj, kk, ifac, jfac, kfac, ioff, joff, koff, false)
    end
    ex_inner_block = quote
        Threads.@threads for k in 0:ktot-1
            for j in 0:jtot-1
                @inbounds @simd for i in 0:itot-1
                    $(Expr(:block, ex_inner...))
                end
            end
        end
    end

    if add_top
        ex_top = []
        kk = 0
        for jj in 0:jfac-1, ii in 0:ifac-1
            make_inner_loop!(ex_top, ii, jj, kk, ifac, jfac, kfac, ioff, joff, koff, true)
        end
        ex_top_block = quote
            Threads.@threads for j in 0:jtot-1
                @inbounds @simd for i in 0:itot-1
                    $(Expr(:block, ex_top...))
                end
            end
        end
    else
        ex_top_block = :()
    end

    name = Symbol(@sprintf "upsample_lin_%s!" suffix)
    ex = quote
        function $name(hi, lo, itot, jtot, ktot, is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)
            $ex_inner_block
            $ex_top_block
        end
    end
    println(ex)
    return esc(ex)
end

function upsample_lin_ref!(hi, lo, x_hi, y_hi, z_hi, x_lo, y_lo, z_lo, is_hi, js_hi, ks_hi, ie_hi, je_hi, ke_hi)
    interp = interpolate((x_lo, y_lo, z_lo), lo, (Gridded(Linear()), Gridded(Linear()), Gridded(Linear())))
    hi[is_hi:ie_hi, js_hi:je_hi, ks_hi:ke_hi] .= interp(x_hi[is_hi:ie_hi], y_hi[js_hi:je_hi], z_hi[ks_hi:ke_hi])
end


## Set up the grids.
itot_lo = 256; jtot_lo = 192; ktot_lo = 128
itot_hi = 512; jtot_hi = 384; ktot_hi = 256
igc_lo = 1; jgc_lo = 1; kgc_lo = 1
igc_hi = 1; jgc_hi = 1; kgc_hi = 1

# Derive dimensions
ifac = itot_hi÷itot_lo; jfac = jtot_hi÷jtot_lo; kfac = ktot_hi÷ktot_lo;
icells_lo = itot_lo + 2igc_lo; jcells_lo = jtot_lo + 2jgc_lo; kcells_lo = ktot_lo + 2kgc_lo
icells_hi = itot_hi + 2igc_hi; jcells_hi = jtot_hi + 2jgc_hi; kcells_hi = ktot_hi + 2kgc_hi
is_lo = igc_lo+1; js_lo = jgc_lo+1; ks_lo = kgc_lo+1
is_hi = igc_hi+1; js_hi = jgc_hi+1; ks_hi = kgc_hi+1
ie_lo = igc_lo+itot_lo; je_lo = jgc_lo+jtot_lo; ke_lo = kgc_lo+ktot_lo+1
ie_hi = igc_hi+itot_hi; je_hi = jgc_hi+jtot_hi; ke_hi = kgc_hi+ktot_hi+1

w_lo = rand(icells_lo, jcells_lo, kcells_lo)
w_hi_ref = zeros(icells_hi, jcells_hi, kcells_hi)
w_hi_fast = zeros(icells_hi, jcells_hi, kcells_hi)

dx_lo = 1/itot_lo; dy_lo = 1/jtot_lo; dz_lo = 1/ktot_lo
x_lo = collect(range(0.5-igc_lo, length=icells_lo)) .* dx_lo
y_lo = collect(range(0.5-jgc_lo, length=jcells_lo)) .* dy_lo
z_lo = collect(range(0.5-kgc_lo, length=kcells_lo)) .* dz_lo
xh_lo = collect(range(-igc_lo, length=icells_lo)) .* dx_lo
yh_lo = collect(range(-jgc_lo, length=jcells_lo)) .* dy_lo
zh_lo = collect(range(-kgc_lo, length=kcells_lo)) .* dz_lo

dx_hi = 1/itot_hi; dy_hi = 1/jtot_hi; dz_hi = 1/ktot_hi
x_hi = collect(range(0.5-igc_hi, length=icells_hi)) .* dx_hi
y_hi = collect(range(0.5-jgc_hi, length=jcells_hi)) .* dy_hi
z_hi = collect(range(0.5-kgc_hi, length=kcells_hi)) .* dz_hi
xh_hi = collect(range(-igc_hi, length=icells_hi)) .* dx_hi
yh_hi = collect(range(-jgc_hi, length=jcells_hi)) .* dy_hi
zh_hi = collect(range(-kgc_hi, length=kcells_hi)) .* dz_hi


## Generate the fast function and compute a reference using Interpolations.jl
@upsample_lin_fast("222", 2, 2, 2, 0, 0, 0.5, true)
@btime upsample_lin_ref!(
    w_hi_ref, w_lo, x_hi, y_hi, zh_hi, x_lo, y_lo, zh_lo,
    is_hi, js_hi, ks_hi, ie_hi, je_hi, ke_hi)
@btime upsample_lin_222!(
    w_hi_fast, w_lo, itot_lo, jtot_lo, ktot_lo,
    is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)

w_lo_int = @view w_lo[is_lo:ie_lo, js_lo:je_lo, ks_lo:ke_lo]
w_hi_ref_int = @view w_hi_ref[is_hi:ie_hi, js_hi:je_hi, ks_hi:ke_hi]
w_hi_fast_int = @view w_hi_fast[is_hi:ie_hi, js_hi:je_hi, ks_hi:ke_hi]

println("Fast equal to ref: ", w_hi_fast_int ≈ w_hi_ref_int)


## Plot the output.
zh_lo_int = @view zh_lo[ks_lo:ke_lo]
zh_hi_int = @view zh_hi[ks_hi:ke_hi]

figure()
plot(zh_hi_int, w_hi_ref_int[1, 1, :], "C1-^")
plot(zh_hi_int, w_hi_fast_int[1, 1, :], "C2-*")
plot(zh_lo_int, w_lo_int[1, 1, :], "k:+")
tight_layout()
display(gcf())
show()

@code_warntype shows

@code_warntype upsample_lin_222!(
    w_hi_fast, w_lo, itot_lo, jtot_lo, ktot_lo,
    is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)
Variables
  #self#::Core.Const(upsample_lin_222!)
  hi::Array{Float64, 3}
  lo::Array{Float64, 3}
  itot::Int64
  jtot::Int64
  ktot::Int64
  is_hi::Int64
  js_hi::Int64
  ks_hi::Int64
  is_lo::Int64
  js_lo::Int64
  ks_lo::Int64
  @_13::Union{Nothing, Tuple{Int64, Int64}}
  threadsfor_fun::var"#47#threadsfor_fun#7"{Array{Float64, 3}, Array{Float64, 3}, Int64, Int64, Int64, Int64, Int64, UnitRange{Int64}}
  range::UnitRange{Int64}
  @_16::Union{Nothing, Tuple{Int64, Int64}}
  j::Int64
  i::Int64
  k_lo::Any
  j_lo::Int64
  i_lo::Int64
  k_hi::Any
  j_hi::Int64
  i_hi::Int64
[...]

that k_lo and k_hi are of type Any. Maybe this results from the line

        push!(ex_list, :(i_lo = i + $i_lo_off), :(j_lo = j + $j_lo_off), :(k_lo = ke_lo))

and the fact that ke_lo is a global variable?
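
The same effect can be reproduced outside the macro. A minimal sketch, assuming a made-up global ke_demo and demo functions that are not part of the original code:

# Hypothetical minimal example: ke_demo is a non-const global, so k_lo is
# inferred as Any inside the function, which causes boxing and allocations.
ke_demo = 128

function sum_top_demo(lo)
    s = 0.0
    for j in axes(lo, 2), i in axes(lo, 1)
        k_lo = ke_demo          # type of a non-const global cannot be inferred
        s += lo[i, j, k_lo]
    end
    return s
end

# Passing the index as an argument (or declaring `const ke_demo = 128`)
# removes the instability.
function sum_top_demo_fixed(lo, ke)
    s = 0.0
    for j in axes(lo, 2), i in axes(lo, 1)
        s += lo[i, j, ke]
    end
    return s
end

Running @code_warntype on sum_top_demo flags k_lo as Any, while the fixed version infers concrete types throughout.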


Indeed, that's it; performance is now nearly identical between true and false. I forgot to pass the ke_ variables as arguments when I adjusted the functions later. I wasn't aware of using @code_warntype in this way. Thanks a lot!
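
For completeness, a sketch of what that fix could look like (the argument order here is an assumption): the generated function simply gains ke_hi and ke_lo as arguments, so the top-boundary expressions no longer read globals. Inside the macro:

ex = quote
    function $name(hi, lo, itot, jtot, ktot,
                   is_hi, js_hi, ks_hi, ke_hi,
                   is_lo, js_lo, ks_lo, ke_lo)
        $ex_inner_block
        $ex_top_block
    end
end

and the call site becomes

upsample_lin_222!(w_hi_fast, w_lo, itot_lo, jtot_lo, ktot_lo,
                  is_hi, js_hi, ks_hi, ke_hi, is_lo, js_lo, ks_lo, ke_lo)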