Searchsortedlast performance

I am trying to write a Julia port of the Weno4Interpolation python package, an algorithm based on Janett et al. (2019) to perform higher order interpolations.

The code itself is not complex, but has many statements, mostly simple algebraic expressions. This is the type of problem where I expected Julia to shine, but I was surprised to find that the Python version (using numba) consistently outperforms Julia, finishing in about 2/3 of the time on my system. I have implementations in Julia and numba that are nearly identical line by line.

Since it is so simple, I’m at a bit of a loss as to how to optimise it. I’ve tried the usual tricks: additional type annotations, @inbounds, etc., but just can’t optimise the Julia version any further. Any suggestions on how this code could be improved?

This is the Julia version:

function weno4_impl(xs, xp, fp)
    Ngrid = size(xp)[1]
    ε = 1e-6
    fs = zeros(typeof(xs[1]), size(xs))
    left = fp[1]
    right = fp[end]
    prevβidx = -1
    β2 = 0.0
    β3 = 0.0
    for (idx, x) in enumerate(xs)
        i = searchsortedlast(xp, x)
        if x < xp[1]
            fs[idx] = left
            continue
        elseif x > xp[end]
            fs[idx] = right
        end

        if i == Ngrid
            i -= 1
        end

        if i == 1
            xi = xp[i]
            xip = xp[i+1]
            xipp = xp[i+2]

            hi = xip - xi
            hip = xipp - xip

            yi = fp[i]
            yip = fp[i+1]
            yipp = fp[i+2]

            q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
            q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
            q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)

            fs[idx] = q3
            continue
        elseif i == Ngrid - 1
            xim = xp[i-1]
            xi = xp[i]
            xip = xp[i+1]

            him = xi - xim
            hi = xip - xi

            yim = fp[i-1]
            yi = fp[i]
            yip = fp[i+1]

            q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
            q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
            q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)

            fs[idx] = q2
            continue
        end

        xim = xp[i-1]
        xi = xp[i]
        xip = xp[i+1]
        xipp = xp[i+2]

        him = xi - xim
        hi = xip - xi
        hip = xipp - xip

        yim = fp[i-1]
        yi = fp[i]
        yip = fp[i+1]
        yipp = fp[i+2]

        q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
        q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
        q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)

        q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
        q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
        q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)

        if i != prevβidx
            H = him + hi + hip
            yyim = - ((2*him + hi)*H + him*(him + hi)) / (him*(him + hi)*H) * yim
            yyim += ((him + hi)*H) / (him*hi*(hi + hip)) * yi
            yyim -= (him*H) / ((him + hi)*hi*hip) * yip
            yyim += (him*(him + hi)) / ((hi + hip)*hip*H) * yipp

            yyi = - (hi*(hi + hip)) / (him*(him + hi)*H) * yim
            yyi += (hi*(hi + hip) - him*(2*hi + hip)) / (him*hi*(hi + hip)) * yi
            yyi += (him*(hi + hip)) / ((him + hi)*hi*hip) * yip
            yyi -= (him*hi) / ((hi + hip)*hip*H) * yipp

            yyip = (hi*hip) / (him*(him + hi)*H) * yim
            yyip -= (hip*(him + hi)) / (him*hi*(hi + hip)) * yi
            yyip += ((him + 2*hi)*hip - (him + hi)*hi) / ((him + hi)*hi*hip) * yip
            yyip += ((him + hi)*hi) / ((hi + hip)*hip*H) * yipp

            yyipp = - ((hi + hip)*hip) / (him*(him + hi)*H) * yim
            yyipp += (hip*H) / (him*hi*(hi + hip)) * yi
            yyipp -= ((hi + hip) * H) / ((him + hi)*hi*hip) * yip
            yyipp += ((2*hip + hi)*H + hip*(hi + hip)) / ((hi + hip)*hip*H) * yipp

            β2 = (hi + hip)^2 * (abs(yyip - yyi) / hi - abs(yyi - yyim) / him)^2
            β3 = (him + hi)^2 * (abs(yyipp - yyip) / hip - abs(yyip - yyi) / hi)^2

            prevβidx = i
        end
        γ2 = - (x - xipp) / (xipp - xim)
        γ3 = (x - xim) / (xipp - xim)

        α2 = γ2 / (ε + β2)
        α3 = γ3 / (ε + β3)

        ω2 = α2 / (α2 + α3)
        ω3 = α3 / (α2 + α3)

        fs[idx] = ω2*q2 + ω3*q3
    end
    return fs
end

And the Python version:

from numba import njit

@njit(cache=True)
def weno4_impl(xs, xp, fp):
    Ngrid = xp.shape[0]
    Eps = 1e-6
    fs = np.zeros_like(xs)

    left = fp[0]
    right = fp[-1]

    prevBetaIdx = -1
    for idx, x in enumerate(xs):
        i = np.searchsorted(xp, x, side='right') - 1

        if x < xp[0]:
            fs[idx] = left
            continue
        elif x > xp[-1]:
            fs[idx] = right
            continue

        if i == Ngrid - 1:
            i -= 1

        if i == 0:
            xi = xp[i]
            xip = xp[i+1]
            xipp = xp[i+2]

            hi = xip - xi
            hip = xipp - xip

            yi = fp[i]
            yip = fp[i+1]
            yipp = fp[i+2]

            q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
            q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
            q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)

            fs[idx] = q3
            continue
        elif i == Ngrid - 2:
            xim = xp[i-1]
            xi = xp[i]
            xip = xp[i+1]

            him = xi - xim
            hi = xip - xi

            yim = fp[i-1]
            yi = fp[i]
            yip = fp[i+1]

            q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
            q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
            q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)

            fs[idx] = q2
            continue

        xim = xp[i-1]
        xi = xp[i]
        xip = xp[i+1]
        xipp = xp[i+2]

        him = xi - xim
        hi = xip - xi
        hip = xipp - xip

        yim = fp[i-1]
        yi = fp[i]
        yip = fp[i+1]
        yipp = fp[i+2]

        q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
        q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
        q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)

        q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
        q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
        q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)

        if i != prevBetaIdx:
            H = him + hi + hip
            yyim = - ((2*him + hi)*H + him*(him + hi)) / (him*(him + hi)*H) * yim
            yyim += ((him + hi)*H) / (him*hi*(hi + hip)) * yi
            yyim -= (him*H) / ((him + hi)*hi*hip) * yip
            yyim += (him*(him + hi)) / ((hi + hip)*hip*H) * yipp

            yyi = - (hi*(hi + hip)) / (him*(him + hi)*H) * yim
            yyi += (hi*(hi + hip) - him*(2*hi + hip)) / (him*hi*(hi + hip)) * yi
            yyi += (him*(hi + hip)) / ((him + hi)*hi*hip) * yip
            yyi -= (him*hi) / ((hi + hip)*hip*H) * yipp

            yyip = (hi*hip) / (him*(him + hi)*H) * yim
            yyip -= (hip*(him + hi)) / (him*hi*(hi + hip)) * yi
            yyip += ((him + 2*hi)*hip - (him + hi)*hi) / ((him + hi)*hi*hip) * yip
            yyip += ((him + hi)*hi) / ((hi + hip)*hip*H) * yipp

            yyipp = - ((hi + hip)*hip) / (him*(him + hi)*H) * yim
            yyipp += (hip*H) / (him*hi*(hi + hip)) * yi
            yyipp -= ((hi + hip) * H) / ((him + hi)*hi*hip) * yip
            yyipp += ((2*hip + hi)*H + hip*(hi + hip)) / ((hi + hip)*hip*H) * yipp

            beta2 = (hi + hip)**2 * (abs(yyip - yyi) / hi - abs(yyi - yyim) / him)**2
            beta3 = (him + hi)**2 * (abs(yyipp - yyip) / hip - abs(yyip - yyi) / hi)**2

            prevBetaIdx = i

        gamma2 = - (x - xipp) / (xipp - xim)
        gamma3 = (x - xim) / (xipp - xim)

        alpha2 = gamma2 / (Eps + beta2)
        alpha3 = gamma3 / (Eps + beta3)

        omega2 = alpha2 / (alpha2 + alpha3)
        omega3 = alpha3 / (alpha2 + alpha3)

        fs[idx] = omega2 * q2 + omega3 * q3

    return fs

I’ve tested the Julia version on the REPL with the following data and calls:

const xx = [-0.92187417, -0.89761267, -0.87991064, -0.86419928, -0.61695843,
            -0.43402368, -0.41066999, -0.28437279, -0.03294275,  0.06117351,
             0.10350274,  0.15120579,  0.19502651,  0.27504179,  0.30483723,
             0.31266704,  0.57397092,  0.72808421,  0.75458105,  0.89136637]

const yy = [3.26564982, 3.13231235, 3.03788469, 2.95633998, 2.07793277,
            2.07177976, 2.11368258, 2.49333035, 3.80266509, 0.36498422,
            0.61108517, 0.87643985, 1.10453135, 1.4692657 , 1.58452945,
            1.61275516, 1.97720725, 1.63532625, 1.5387338 , 0.90130036]

const xi = collect(LinRange(minimum(xx), maximum(xx), 10001))

using BenchmarkTools

@benchmark weno4_impl(xi, xx, yy)

And the python version in ipython with:

xp = np.array([-0.92187417, -0.89761267, -0.87991064, -0.86419928, -0.61695843,
               -0.43402368, -0.41066999, -0.28437279, -0.03294275,  0.06117351,
                0.10350274,  0.15120579,  0.19502651,  0.27504179,  0.30483723,
                0.31266704,  0.57397092,  0.72808421,  0.75458105,  0.89136637])

yp = np.array([3.26564982, 3.13231235, 3.03788469, 2.95633998, 2.07793277,
               2.07177976, 2.11368258, 2.49333035, 3.80266509, 0.36498422,
               0.61108517, 0.87643985, 1.10453135, 1.4692657 , 1.58452945,
               1.61275516, 1.97720725, 1.63532625, 1.5387338 , 0.90130036])

x = np.linspace(xp.min(), xp.max(), 10001)

%timeit weno4_impl(x, xp, yp)

For this example with 10,001 points I get about 300 μs in Julia and 180 μs in Python.

Leaving here a few details on the system:

Julia: 1.6.4
Python: 3.9.6, numpy 1.20.3, numba 0.54.1
Darwin 20.6.0, all tests run with single thread

Any help would be most appreciated!

I haven’t found the source yet, but if you do @code_warntype weno4_impl(xi, xx, yy) there is a type instability in the code somewhere that could be the problem.

What happens if you stick @fastmath on this? I think Julia is more strict about what floating point rearrangement is allowed. I see a 10% speedup from this.
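
To illustrate why reordering floating-point operations is something a compiler normally has to be told it may do, here is a minimal sketch (in Python, with arbitrary values chosen only for illustration) showing that addition in double precision is not associative, so regrouping a sum can change the rounded result:

```python
# Floating-point addition is not associative: regrouping changes the rounding.
a, b, c = 0.1, 0.2, 0.3

left_to_right = (a + b) + c   # rounds after a + b, then again after adding c
regrouped = a + (b + c)       # rounds after b + c, then again after adding a

print(left_to_right == regrouped)        # False
print(abs(left_to_right - regrouped))    # about 1.1e-16
```

The difference is tiny here, which is consistent with the small changes in output one would expect from enabling this kind of optimisation, but it is not zero.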

I’m not seeing type instability here. What are you seeing?

[Image: excerpt of the @code_warntype output]

That’s part of the @code_warntype output. Admittedly, I am not an expert at reading these so I don’t know if that one is a big deal or not.

The Julia version is also missing a matching continue statement for the elseif x > xp[end] condition near the beginning of the loop.

That’s just iteration. Yellow lines in @code_warntype output are generally fine.

You can also save some time switching

fs = zeros(typeof(xs[1]), size(xs))   # 4.9 us 

to

fs = similar(xs)  # 680 ns

since there seems to be no need to fill the fs array with zeros.

Thanks. That was indeed a bug. But it only affects the end point, so there is not much change in performance.

Thanks for the suggestions so far. I could find no type instability. Using @fastmath speeds it up by 12% or so, although I’m always a bit wary of using it because of accuracy (after all, the whole point of this interpolation is to be more accurate), even if in this example the difference in values is quite small. Swapping zeros() for similar() had also occurred to me, but it makes very little difference to the performance, and I usually prefer to play it safe with zeros rather than start from an array of arbitrary values.

I wonder if a possible optimisation is to rebuild the for loop without the continue statements and then use @simd, although from previous experience I find that Julia code is already so optimised that @simd makes little difference. Maybe some fused multiply-add, since there are so many statements? Although I’m not sure how best to do it, since the statements are all different.

Most of the time is spent on searchsortedlast, so that one appears to be suboptimal (any chance the numba version is parallel in some way? Is it doing the same thing?).

If I change that to something as simple as:

        #i = searchsortedlast(xp, x)
        i = firstindex(xp) + 1
        while @inbounds xp[i] < x
            i += 1
        end
        i -= 1

I already get a significant improvement:

julia> include("./numba.jl")
  209.847 μs (2 allocations: 78.23 KiB)
  166.307 μs (2 allocations: 78.23 KiB) # now one

Probably not enough to explain the complete difference.

Yet that is unsafe: with no upper bound on i, the loop can run past the end of xp when x is larger than every element. With the proper condition, the performance becomes similar to the previous one:

        i = firstindex(xp) + 1
        @inbounds while xp[i] < x && i < lastindex(xp)
            i += 1
        end
        i -= 1
  200.369 μs (2 allocations: 78.23 KiB)
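
For readers comparing the two lookups, here is a minimal sketch of the guarded scan in Python (0-based indices; the grid and query values are hypothetical, and the standard-library bisect_right stands in for the binary search). It suggests that the scan and the binary search agree for points strictly between grid nodes, but pick different neighbouring intervals when x falls exactly on a node:

```python
from bisect import bisect_right

def scan_index(xp, x):
    """Guarded linear scan, 0-based: left node of the interval assigned to x."""
    i = 1
    # The first condition keeps i inside xp even when x > xp[-1].
    while i < len(xp) - 1 and xp[i] < x:
        i += 1
    return i - 1

xp = [0.0, 1.0, 2.5, 4.0]           # hypothetical sorted grid
for x in (0.5, 1.7, 3.9):           # points strictly between nodes
    assert scan_index(xp, x) == bisect_right(xp, x) - 1

print(scan_index(xp, 2.5))          # 1: at a node the scan picks the interval to the left
print(bisect_right(xp, 2.5) - 1)    # 2: the binary search picks the interval to the right
print(scan_index(xp, 9.0))          # 2: clamped to the last interval
```

Either interval contains the node, so both choices are usable for locating a stencil, but the results of the two lookups are not guaranteed to be identical index for index.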

Is xs always sorted? If so, you can get a small asymptotic speedup since you know that i always increases.
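
A minimal sketch of that idea, in Python with 0-based indices (the function name and the data are hypothetical): the index is kept between queries instead of being reset, so the total work is proportional to len(xs) + len(xp) rather than len(xs) * log(len(xp)).

```python
def interval_indices_sorted(xs, xp):
    """For ascending xs, return the 0-based left-node index of each x's interval in xp."""
    out = []
    i = 1                      # persists across queries instead of restarting at 1
    last = len(xp) - 1
    for x in xs:
        while i < last and xp[i] < x:
            i += 1             # only ever moves forward because xs is ascending
        out.append(i - 1)
    return out

xp = [0.0, 1.0, 2.5, 4.0]                  # hypothetical sorted grid
xs = [0.1, 0.9, 1.2, 2.6, 3.0, 4.0]        # hypothetical sorted query points
print(interval_indices_sorted(xs, xp))     # [0, 0, 1, 2, 2, 2]
```

This only works under the assumption that xs is sorted in ascending order; an unsorted xs would need the index reset (or a binary search) for each point.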

For comparison, here is the searchsortedlast implementation in Base. The invocation is essentially searchsortedlast(xp, x, firstindex(xp), lastindex(xp), Base.Order.Forward).

https://github.com/JuliaLang/julia/blob/c2e93115e7f33673ea5843b9caa90390380f032c/base/sort.jl#L195-L208
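
As a rough illustration of what such a search does, here is a generic bisection sketched in Python. It is not a transcription of the linked code: the function name is hypothetical and the indices are 0-based, so the result corresponds to the Julia index minus one. It is checked against the standard library's bisect_right.

```python
from bisect import bisect_right

def search_sorted_last(v, x):
    """0-based index of the last element of sorted v that is <= x, or -1 if none."""
    lo, hi = -1, len(v)        # virtual sentinels: v[lo] <= x < v[hi]
    while lo < hi - 1:
        m = (lo + hi) // 2     # lo < m < hi, so v[m] is always a valid access
        if x < v[m]:
            hi = m
        else:
            lo = m
    return lo

v = [0.0, 1.0, 2.5, 4.0]       # hypothetical sorted data
for x in (-1.0, 0.0, 0.5, 2.5, 3.9, 4.0, 9.0):
    assert search_sorted_last(v, x) == bisect_right(v, x) - 1
```

Each call takes about log2(len(v)) iterations with a data-dependent branch in each, which is one plausible reason a short linear scan can compete on a grid of only 20 points.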

leandromartinez98 said:
"Most of the time is spent on searchsortedlast "

For x < xp[1] || x > xp[end] the result of that search is not used, so the search can be moved after the bounds checks. That is, change

        i = searchsortedlast(xp, x)
        if x < xp[1]
            fs[idx] = left
            continue
        elseif x > xp[end]
            fs[idx] = right
            continue
        end

to

        if x < xp[1]
            fs[idx] = left
            continue
        elseif x > xp[end]
            fs[idx] = right
            continue
        end

        i = searchsortedlast(xp, x)

That saved nearly 2 μs or 0.3%!

julia> @benchmark weno4_impl(xi, xx, yy)
BenchmarkTools.Trial: 8501 samples with 1 evaluation.
 Range (min … max):  577.757 μs …  3.303 ms  ┊ GC (min … max): 0.00% … 81.19%
 Time  (median):     579.479 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   583.961 μs ± 45.629 μs  ┊ GC (mean ± σ):  0.25% ±  2.35%

   ▆█▇▅▁▄▄▃▂▁            ▁▃▃▂               ▁▁▁▁               ▂
  ▇███████████▆▄▅▁▃▃▃▁▃▃▇█████▇█▇▇▆▅▅▅▄▅▃▆█████████████▅▆▄▄▄▆▄ █
  578 μs        Histogram: log(frequency) by time       610 μs <

 Memory estimate: 78.23 KiB, allocs estimate: 2.

julia> @benchmark weno4_impl(xi, xx, yy)
BenchmarkTools.Trial: 8529 samples with 1 evaluation.
 Range (min … max):  576.715 μs …  1.633 ms  ┊ GC (min … max): 0.00% … 61.80%
 Time  (median):     578.243 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   581.991 μs ± 39.609 μs  ┊ GC (mean ± σ):  0.26% ±  2.44%

  ▄██▅▂▄▄▃▁                          ▁▁▁▁                      ▂
  ██████████▅▄▃▃▄▁▃▄▁▄▅▄▁▁▄▄▄▃▅▄▅▅▅███████████▇▇▆▆▄▅▃▄▄▃▄▁▃▁▁▃ █
  577 μs        Histogram: log(frequency) by time       615 μs <

 Memory estimate: 78.23 KiB, allocs estimate: 2.

Uhm… I tried replacing the search with some different indexes to remove the call completely, and that reduced the running time to 40-100 microseconds here. But that may be a flawed evaluation; it appeared to me that what follows did not depend much on the resulting index.

Anyway, just for the record, here I don’t get Python to be faster:

In [5]: %timeit weno4_impl(x, xp, yp)                                                                                                  
289 µs ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

julia> include("./numba.jl")
  209.204 μs (2 allocations: 78.23 KiB)

I confirmed this with a simplified test, in which I removed all the parts that actually do the calculations:

Julia

function weno4_impl(xs::Vector{T}, xp::Vector{T}, fp::Vector{T}) where T
    fs = zero(xs)
    Ngrid = size(xp)[1]
    prevβidx = -1
    for (idx, x) in enumerate(xs)
        i = searchsortedlast(xp, x)
        if x < xp[1]
            fs[idx] = i
            continue
        elseif x > xp[end]
            fs[idx] = i
            continue
        end   

        if i == Ngrid
            i -= 1
        end
        if i == 1
            fs[idx] = i
            continue
        elseif i == Ngrid - 1
            fs[idx] = i
            continue
        end

        if i != prevβidx
            prevβidx = i
        end

        fs[idx] = i
    end
    return fs
end


const xp = [-0.92187417, -0.89761267, -0.87991064, -0.86419928, -0.61695843,
            -0.43402368, -0.41066999, -0.28437279, -0.03294275,  0.06117351,
             0.10350274,  0.15120579,  0.19502651,  0.27504179,  0.30483723,
             0.31266704,  0.57397092,  0.72808421,  0.75458105,  0.89136637]
const yp = [3.26564982, 3.13231235, 3.03788469, 2.95633998, 2.07793277,
            2.07177976, 2.11368258, 2.49333035, 3.80266509, 0.36498422,
            0.61108517, 0.87643985, 1.10453135, 1.4692657 , 1.58452945,
            1.61275516, 1.97720725, 1.63532625, 1.5387338 , 0.90130036]
const x = collect(LinRange(minimum(xp), maximum(xp), 10001))


using BenchmarkTools
@benchmark weno4_impl(x, xp, yp)  evals=1000
BenchmarkTools.Trial: 28 samples with 1000 evaluations.
 Range (min … max):  177.763 μs … 192.997 μs  ┊ GC (min … max): 1.61% … 3.68%
 Time  (median):     182.405 μs               ┊ GC (median):    3.34%
 Time  (mean ± σ):   183.199 μs ±   3.773 μs  ┊ GC (mean ± σ):  3.18% ± 0.92%

   ▃        ▃      █               ▃
  ▇█▁▁▇▁▇▁▇▁█▁▇▇▁▇▁█▁▇▇▁▁▁▇▁▇▁▇▇▇▁▁█▇▁▁▁▇▁▇▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▇ ▁
  178 μs           Histogram: frequency by time          193 μs <

 Memory estimate: 78.27 KiB, allocs estimate: 2.

Python

import numpy as np
from numba import njit

@njit(cache=True)
def weno4_impl(xs, xp, fp):
    Ngrid = xp.shape[0]
    fs = np.zeros_like(xs)
    prevBetaIdx = -1
    for idx, x in enumerate(xs):
        i = np.searchsorted(xp, x, side='right') - 1

        if x < xp[0]:
            fs[idx] = i
            continue
        elif x > xp[-1]:
            fs[idx] = i
            continue

        if i == Ngrid - 1:
            i -= 1
        
        if i == 0:
            fs[idx] = i
            continue
        elif i == Ngrid - 2:
            fs[idx] = i
            continue

        if i != prevBetaIdx:
            prevBetaIdx = i

        fs[idx] = i

    return fs

xp = np.array([-0.92187417, -0.89761267, -0.87991064, -0.86419928, -0.61695843,
               -0.43402368, -0.41066999, -0.28437279, -0.03294275,  0.06117351,
                0.10350274,  0.15120579,  0.19502651,  0.27504179,  0.30483723,
                0.31266704,  0.57397092,  0.72808421,  0.75458105,  0.89136637])
yp = np.array([3.26564982, 3.13231235, 3.03788469, 2.95633998, 2.07793277,
               2.07177976, 2.11368258, 2.49333035, 3.80266509, 0.36498422,
               0.61108517, 0.87643985, 1.10453135, 1.4692657 , 1.58452945,
               1.61275516, 1.97720725, 1.63532625, 1.5387338 , 0.90130036])
x = np.linspace(xp.min(), xp.max(), 10001)


%timeit -r 20  weno4_impl(x, xp, yp)
58.6 µs ± 4.9 µs per loop (mean ± std. dev. of 20 runs, 1 loop each)

Julia vs numba

  • searchsorted: 182.405 μs vs 58.6 µs
  • calculation: 123.5 μs vs 124.4 μs
  • total: 305.921 μs ± 3.122 μs vs 183 µs ± 1.58 µs

Is there anything special I should do to test the python version? I have installed numba, added import numpy as np to the top of the code, and the rest is just a copy/paste from what is there.

Here I get ~280 μs for numba and ~200 μs in Julia.

Cool! Code seems to be JET-clean.

Tested with

using JET
@report_opt weno4_impl(xi, xx, yy)

Thank you for the great discussion. This community is amazing!

Yes, in this case xs should always be sorted. Thank you @lmiq for identifying that the issue was coming from searchsortedlast, and for your suggestion! I adapted your first suggestion to fit in the if/else part on the edges (as noted by @lawless-m, this should have been at the beginning of the loop), and since the extreme cases are taken care of, there is no need for the second condition in the while.

Adding an extra @inbounds on the main loop shaves a bit more time, so now I’m down to 168.525 μs, which is almost twice as fast as the original! It is only slightly faster than numba for this number of points (10,001), but actually close to twice as fast as numba for fewer points (which is closer to my application).

Strange that searchsorted was to blame. I always thought that the Base version would be more optimised than a quick function I would write, but that is not the case here. The version used in Python comes from deep within numpy, and seems heavily optimised (although the readability really suffers, a shocking difference from the Julia Base code):

https://github.com/numpy/numpy/blob/84e0707afa587e7655410561324ac36085db2b95/numpy/core/src/multiarray/item_selection.c#L1813

The beautiful thing is that helping other people also helps yourself (at least in my case).

The problems I solve in my day job are not as rich as the ones people bring here. With the added luxury that if I can’t solve them, it’s not my problem :wink:
