# 2D nested loop gives big performance hit in interpolation routine while it should not

I have written a macro that generates fast trilinear interpolation functions for 2x2x2 upsampling with threading. The code runs fine and is on average an order of magnitude faster than `Interpolations.jl` (which I probably not use in its fastest settings), but I notice an enormous penalty (~50%) if I take into account one additional 2-dimensional loop that takes care of the top boundary (boolean switch, last argument in `@upsample_lin_fast("222", 2, 2, 2, 0, 0, 0.5, true)` and `@btime` shows a large allocation overhead, that is not there if I only run the 3D part (switch to `false`). The generated code is rather straightforward and is printed to the screen (and probably much easier to look at than the macro). What am I overlooking, because logically the additional computations should be very small compare to the total computational cost. You can test the performance difference by switching between `true` and `false` in the statement above.

``````## Load the packages.
using LoopVectorization
using Interpolations
using Statistics
using BenchmarkTools
using PyPlot
using Printf

## Define the upsample functions.
function make_inner_loop!(ex_list, ii, jj, kk, ifac, jfac, kfac, ioff, joff, koff, is_top)
push!(ex_list, :(i_hi = \$ifac*i+\$ii+is_hi), :(j_hi = \$jfac*j+\$jj+js_hi))
if is_top
push!(ex_list, :(k_hi = ke_hi))
else
push!(ex_list, :(k_hi = \$kfac*k+\$kk+ks_hi))
end

i_pos = -1/2 + ioff + (1/2 - ioff + ii)/ifac
j_pos = -1/2 + joff + (1/2 - joff + jj)/jfac
k_pos = -1/2 + koff + (1/2 - koff + kk)/kfac

i_lo_off = floor(Int, i_pos) + is_lo
j_lo_off = floor(Int, j_pos) + js_lo
k_lo_off = floor(Int, k_pos) + ks_lo

if is_top
push!(ex_list, :(i_lo = i + \$i_lo_off), :(j_lo = j + \$j_lo_off), :(k_lo = ke_lo))
else
push!(ex_list, :(i_lo = i + \$i_lo_off), :(j_lo = j + \$j_lo_off), :(k_lo = k + \$k_lo_off))
end

fi = mod(i_pos, 1)
fj = mod(j_pos, 1)
fk = mod(k_pos, 1)

coef = zeros(2, 2, 2)
coef[1, 1, 1] = (((1-fi)*fj+fi-1)*fk+(fi-1)*fj-fi+1)
coef[2, 1, 1] = ((fi*fj-fi)*fk-fi*fj+fi)
coef[1, 2, 1] = ((fi-1)*fj*fk+(1-fi)*fj)
coef[2, 2, 1] = (fi*fj-fi*fj*fk)
coef[1, 1, 2] = ((fi-1)*fj-fi+1)*fk
coef[2, 1, 2] = (fi-fi*fj)*fk
coef[1, 2, 2] = (1-fi)*fj*fk
coef[2, 2, 2] = fi*fj*fk

kc_end = is_top ? 1 : 2

rhs_list = []
for kc in 1:kc_end, jc in 1:2, ic in 1:2
if !(coef[ic, jc, kc] ≈ 0)
push!(rhs_list, :( \$(coef[ic, jc, kc])*lo[i_lo+\$(ic-1), j_lo+\$(jc-1), k_lo+\$(kc-1)] ))
end
end

rhs = Expr(:call, :+, rhs_list...)

push!(ex_list, :(hi[i_hi, j_hi, k_hi] = \$rhs))
end

macro upsample_lin_fast(suffix, ifac, jfac, kfac, ioff, joff, koff, add_top)
ex_inner = []
for kk in 0:kfac-1, jj in 0:jfac-1, ii in 0:ifac-1
make_inner_loop!(ex_inner, ii, jj, kk, ifac, jfac, kfac, ioff, joff, koff, false)
end
ex_inner_block = quote
for j in 0:jtot-1
@inbounds @simd for i in 0:itot-1
\$(Expr(:block, ex_inner...))
end
end
end
end

ex_top = []
kk = 0
for jj in 0:jfac-1, ii in 0:ifac-1
make_inner_loop!(ex_top, ii, jj, kk, ifac, jfac, kfac, ioff, joff, koff, true)
end
ex_top_block = quote
@inbounds @simd for i in 0:itot-1
\$(Expr(:block, ex_top...))
end
end
end
else
ex_top_block = :()
end

name = Symbol(@sprintf "upsample_lin_%s!" suffix)
ex = quote
function \$name(hi, lo, itot, jtot, ktot, is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)
\$ex_inner_block
\$ex_top_block
end
end
println(ex)
return esc(ex)
end

function upsample_lin_ref!(hi, lo, x_hi, y_hi, z_hi, x_lo, y_lo, z_lo, is_hi, js_hi, ks_hi, ie_hi, je_hi, ke_hi)
interp = interpolate((x_lo, y_lo, z_lo), lo, (Gridded(Linear()), Gridded(Linear()), Gridded(Linear())))
hi[is_hi:ie_hi, js_hi:je_hi, ks_hi:ke_hi] .= interp(x_hi[is_hi:ie_hi], y_hi[js_hi:je_hi], z_hi[ks_hi:ke_hi])
end

## Set up the grids.
itot_lo = 256; jtot_lo = 192; ktot_lo = 128
itot_hi = 512; jtot_hi = 384; ktot_hi = 256
igc_lo = 1; jgc_lo = 1; kgc_lo = 1
igc_hi = 1; jgc_hi = 1; kgc_hi = 1

# Derive dimensions
ifac = itot_hi÷itot_lo; jfac = jtot_hi÷jtot_lo; kfac = ktot_hi÷ktot_lo;
icells_lo = itot_lo + 2igc_lo; jcells_lo = jtot_lo + 2jgc_lo; kcells_lo = ktot_lo + 2kgc_lo
icells_hi = itot_hi + 2igc_hi; jcells_hi = jtot_hi + 2jgc_hi; kcells_hi = ktot_hi + 2kgc_hi
is_lo = igc_lo+1; js_lo = jgc_lo+1; ks_lo = kgc_lo+1
is_hi = igc_hi+1; js_hi = jgc_hi+1; ks_hi = kgc_hi+1
ie_lo = igc_lo+itot_lo; je_lo = jgc_lo+jtot_lo; ke_lo = kgc_lo+ktot_lo+1
ie_hi = igc_hi+itot_hi; je_hi = jgc_hi+jtot_hi; ke_hi = kgc_hi+ktot_hi+1

w_lo = rand(icells_lo, jcells_lo, kcells_lo)
w_hi_ref = zeros(icells_hi, jcells_hi, kcells_hi)
w_hi_fast = zeros(icells_hi, jcells_hi, kcells_hi)

dx_lo = 1/itot_lo; dy_lo = 1/jtot_lo; dz_lo = 1/ktot_lo
x_lo = collect(range(0.5-igc_lo, length=icells_lo)) .* dx_lo
y_lo = collect(range(0.5-jgc_lo, length=jcells_lo)) .* dy_lo
z_lo = collect(range(0.5-kgc_lo, length=kcells_lo)) .* dz_lo
xh_lo = collect(range(-igc_lo, length=icells_lo)) .* dx_lo
yh_lo = collect(range(-jgc_lo, length=jcells_lo)) .* dy_lo
zh_lo = collect(range(-kgc_lo, length=kcells_lo)) .* dz_lo

dx_hi = 1/itot_hi; dy_hi = 1/jtot_hi; dz_hi = 1/ktot_hi
x_hi = collect(range(0.5-igc_hi, length=icells_hi)) .* dx_hi
y_hi = collect(range(0.5-jgc_hi, length=jcells_hi)) .* dy_hi
z_hi = collect(range(0.5-kgc_hi, length=kcells_hi)) .* dz_hi
xh_hi = collect(range(-igc_hi, length=icells_hi)) .* dx_hi
yh_hi = collect(range(-jgc_hi, length=jcells_hi)) .* dy_hi
zh_hi = collect(range(-kgc_hi, length=kcells_hi)) .* dz_hi

## Compute a reference using Interpolations.jl
@upsample_lin_fast("222", 2, 2, 2, 0, 0, 0.5, true)
@btime upsample_lin_ref!(
w_hi_ref, w_lo, x_hi, y_hi, zh_hi, x_lo, y_lo, zh_lo,
is_hi, js_hi, ks_hi, ie_hi, je_hi, ke_hi)
@btime upsample_lin_222!(
w_hi_fast, w_lo, itot_lo, jtot_lo, ktot_lo,
is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)

w_lo_int = @view w_lo[is_lo:ie_lo, js_lo:je_lo, ks_lo:ke_lo]
w_hi_ref_int = @view w_hi_ref[is_hi:ie_hi, js_hi:je_hi, ks_hi:ke_hi]
w_hi_fast_int = @view w_hi_fast[is_hi:ie_hi, js_hi:je_hi, ks_hi:ke_hi]

println("Fast equal to ref: ", w_hi_fast_int ≈ w_hi_ref_int)

## Plot the output.
zh_lo_int = @view zh_lo[ks_lo:ke_lo]
zh_hi_int = @view zh_hi[ks_hi:ke_hi]

figure()
plot(zh_hi_int, w_hi_ref_int[1, 1, :], "C1-^")
plot(zh_hi_int, w_hi_fast_int[1, 1, :], "C2-*")
plot(zh_lo_int, w_lo_int[1, 1, :], "k:+")
tight_layout()
display(gcf())
show()
``````

`@code_warntype` shows

``````@code_warntype upsample_lin_222!(

w_hi_fast, w_lo, itot_lo, jtot_lo, ktot_lo,

is_hi, js_hi, ks_hi, is_lo, js_lo, ks_lo)
Variables
#self#::Core.Const(upsample_lin_222!)
hi::Array{Float64, 3}
lo::Array{Float64, 3}
itot::Int64
jtot::Int64
ktot::Int64
is_hi::Int64
js_hi::Int64
ks_hi::Int64
is_lo::Int64
js_lo::Int64
ks_lo::Int64
@_13::Union{Nothing, Tuple{Int64, Int64}}
range::UnitRange{Int64}
@_16::Union{Nothing, Tuple{Int64, Int64}}
j::Int64
i::Int64
k_lo::Any
j_lo::Int64
i_lo::Int64
k_hi::Any
j_hi::Int64
i_hi::Int64
[...]
``````

that `k_lo` and `k_hi` are of type `Any`. Maybe this results from the line

``````        push!(ex_list, :(i_lo = i + \$i_lo_off), :(j_lo = j + \$j_lo_off), :(k_lo = ke_lo))
``````

and the fact that `ke_lo` is a global variable?

2 Likes

Indeed, thats it, performance is now nearly identical between `true` and `false`. I forgot to pass the `ke_` variables as arguments as I adjusted the functions later. Thanks a lot, I wasn’t aware of using `code_warntype` in this way. Thanks a lot!