I am writing a Julia port of the Weno4Interpolation Python package, which implements the WENO4 interpolation algorithm of Janett et al. (2019) for higher-order interpolation.
The code itself is not complex, but it has many statements, mostly simple algebraic expressions. This is the type of problem where I expected Julia to shine, but I was surprised to find that the Python version (using numba) consistently outperforms Julia, finishing in about 2/3 of the time on my system. My Julia and numba implementations are nearly identical line by line.
Since it is so simple, I'm at a loss as to how to optimise it. I've tried the usual tricks (additional type annotations, @inbounds, etc.), but I just can't optimise the Julia version any further. Any suggestions on how this code could be improved?
This is the Julia version:
"""
    weno4_impl(xs, xp, fp)

Interpolate the samples `fp`, given at the abscissae `xp`, onto the query points
`xs` using the WENO4 scheme of Janett et al. (2019). Returns an array with the
element type and shape of `xs`.

- Points with `x < xp[1]` get `fp[1]`, and points with `x > xp[end]` get
  `fp[end]`. No stencil is evaluated for out-of-range points.
- In the first interval only the right quadratic `q3` is used. In the last
  interval only the left quadratic `q2` is used. Every interior interval blends
  the two with nonlinear weights built from the smoothness indicators `β2`, `β3`.

`xp` is assumed to be sorted ascending, because it is searched with
`searchsortedlast`. Throws `ArgumentError` unless `length(xp) == length(fp) >= 3`
and all three arrays use 1-based indexing. These checks are what make the
`@inbounds` loop below safe.
"""
function weno4_impl(xs, xp, fp)
    Base.require_one_based_indexing(xs, xp, fp)
    Ngrid = length(xp)
    Ngrid >= 3 || throw(ArgumentError("WENO4 interpolation needs at least 3 grid points, got $Ngrid"))
    length(fp) == Ngrid || throw(ArgumentError("xp and fp must have the same length"))

    ε = 1e-6
    # `eltype` rather than `typeof(xs[1])`, so that an empty `xs` yields an empty
    # result instead of a BoundsError.
    fs = zeros(eltype(xs), size(xs))
    left = fp[1]
    right = fp[end]
    # The smoothness indicators depend only on the interval index `i`. They are
    # cached and recomputed only when `i` changes. Their zero is typed from the
    # inputs so the loop stays type-stable; a Float64 literal would give them a
    # Union type for Float32 data.
    prevβidx = -1
    β2 = β3 = zero(float(promote_type(eltype(xp), eltype(fp))))
    @inbounds for (idx, x) in enumerate(xs)
        if x < xp[1]
            fs[idx] = left
            continue
        elseif x > xp[end]
            fs[idx] = right
            # Without this `continue`, the point falls through and is overwritten
            # by the q2 extrapolation of the last interval.
            continue
        end
        # Interval index, clamped to 1:Ngrid-1 so every stencil below stays inside
        # xp/fp. x == xp[end] maps to the last interval. The lower bound guards
        # against searchsortedlast returning 0: it orders with `isless`, so e.g.
        # x == -0.0 sorts before xp[1] == 0.0 even though `x < xp[1]` is false.
        i = clamp(searchsortedlast(xp, x), 1, Ngrid - 1)
        if i == 1
            # First interval: only the right stencil {i, i+1, i+2} is available.
            xi = xp[i]
            xip = xp[i+1]
            xipp = xp[i+2]
            hi = xip - xi
            hip = xipp - xip
            yi = fp[i]
            yip = fp[i+1]
            yipp = fp[i+2]
            q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
            q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
            q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)
            fs[idx] = q3
            continue
        elseif i == Ngrid - 1
            # Last interval: only the left stencil {i-1, i, i+1} is available.
            xim = xp[i-1]
            xi = xp[i]
            xip = xp[i+1]
            him = xi - xim
            hi = xip - xi
            yim = fp[i-1]
            yi = fp[i]
            yip = fp[i+1]
            q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
            q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
            q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)
            fs[idx] = q2
            continue
        end
        # Interior interval: evaluate both quadratics on the 4-point stencil.
        xim = xp[i-1]
        xi = xp[i]
        xip = xp[i+1]
        xipp = xp[i+2]
        him = xi - xim
        hi = xip - xi
        hip = xipp - xip
        yim = fp[i-1]
        yi = fp[i]
        yip = fp[i+1]
        yipp = fp[i+2]
        q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
        q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
        q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)
        q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
        q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
        q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)
        if i != prevβidx
            H = him + hi + hip
            yyim = - ((2*him + hi)*H + him*(him + hi)) / (him*(him + hi)*H) * yim
            yyim += ((him + hi)*H) / (him*hi*(hi + hip)) * yi
            yyim -= (him*H) / ((him + hi)*hi*hip) * yip
            yyim += (him*(him + hi)) / ((hi + hip)*hip*H) * yipp
            yyi = - (hi*(hi + hip)) / (him*(him + hi)*H) * yim
            yyi += (hi*(hi + hip) - him*(2*hi + hip)) / (him*hi*(hi + hip)) * yi
            yyi += (him*(hi + hip)) / ((him + hi)*hi*hip) * yip
            yyi -= (him*hi) / ((hi + hip)*hip*H) * yipp
            yyip = (hi*hip) / (him*(him + hi)*H) * yim
            yyip -= (hip*(him + hi)) / (him*hi*(hi + hip)) * yi
            yyip += ((him + 2*hi)*hip - (him + hi)*hi) / ((him + hi)*hi*hip) * yip
            yyip += ((him + hi)*hi) / ((hi + hip)*hip*H) * yipp
            yyipp = - ((hi + hip)*hip) / (him*(him + hi)*H) * yim
            yyipp += (hip*H) / (him*hi*(hi + hip)) * yi
            yyipp -= ((hi + hip) * H) / ((him + hi)*hi*hip) * yip
            yyipp += ((2*hip + hi)*H + hip*(hi + hip)) / ((hi + hip)*hip*H) * yipp
            β2 = (hi + hip)^2 * (abs(yyip - yyi) / hi - abs(yyi - yyim) / him)^2
            β3 = (him + hi)^2 * (abs(yyipp - yyip) / hip - abs(yyip - yyi) / hi)^2
            prevβidx = i
        end
        # Linear weights γ, turned into nonlinear weights ω via the indicators β.
        γ2 = - (x - xipp) / (xipp - xim)
        γ3 = (x - xim) / (xipp - xim)
        α2 = γ2 / (ε + β2)
        α3 = γ3 / (ε + β3)
        ω2 = α2 / (α2 + α3)
        ω3 = α3 / (α2 + α3)
        fs[idx] = ω2*q2 + ω3*q3
    end
    return fs
end
And the Python version:
from numba import njit
@njit(cache=True)
def weno4_impl(xs, xp, fp):
    """Interpolate the samples ``fp`` at abscissae ``xp`` onto ``xs`` (WENO4).

    Implements the WENO4 interpolation of Janett et al. (2019).

    Query points below ``xp[0]`` take ``fp[0]``, and points above ``xp[-1]``
    take ``fp[-1]``. The first interval uses only the right quadratic ``q3``
    and the last interval uses only the left quadratic ``q2``. Interior
    intervals blend both with nonlinear weights built from the smoothness
    indicators ``beta2``/``beta3``.

    Parameters
    ----------
    xs : 1-D array
        Query points. The result is allocated with ``np.zeros_like(xs)`` and so
        shares its dtype. An integer ``xs`` therefore cannot hold fractional
        results.
    xp : 1-D array
        Grid abscissae, assumed sorted ascending (searched with
        ``np.searchsorted``).
    fp : 1-D array
        Samples at ``xp``. Must have the same length as ``xp``.

    Returns
    -------
    fs : array
        Interpolated values, one per entry of ``xs``.

    Raises
    ------
    ValueError
        If ``xp`` has fewer than 3 points or ``fp`` differs in length from ``xp``.
    """
    Ngrid = xp.shape[0]
    # numba does not bounds-check array accesses by default. An undersized grid
    # would silently read past the end of xp/fp, so fail loudly instead.
    if Ngrid < 3:
        raise ValueError("WENO4 interpolation needs at least 3 grid points")
    if fp.shape[0] != Ngrid:
        raise ValueError("xp and fp must have the same length")
    Eps = 1e-6
    fs = np.zeros_like(xs)
    left = fp[0]
    right = fp[-1]
    # The smoothness indicators depend only on the interval index, so they are
    # cached and recomputed only when it changes. They are bound up front so they
    # are defined on every path. The first interior point always recomputes them,
    # because prevBetaIdx starts at -1.
    prevBetaIdx = -1
    beta2 = 0.0
    beta3 = 0.0
    for idx, x in enumerate(xs):
        if x < xp[0]:
            fs[idx] = left
            continue
        elif x > xp[-1]:
            fs[idx] = right
            continue
        # Only in-range points need the interval search.
        i = np.searchsorted(xp, x, side='right') - 1
        if i == Ngrid - 1:
            # x == xp[-1] belongs to the last interval.
            i -= 1
        if i == 0:
            # First interval: only the right stencil {i, i+1, i+2} is available.
            xi = xp[i]
            xip = xp[i+1]
            xipp = xp[i+2]
            hi = xip - xi
            hip = xipp - xip
            yi = fp[i]
            yip = fp[i+1]
            yipp = fp[i+2]
            q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
            q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
            q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)
            fs[idx] = q3
            continue
        elif i == Ngrid - 2:
            # Last interval: only the left stencil {i-1, i, i+1} is available.
            xim = xp[i-1]
            xi = xp[i]
            xip = xp[i+1]
            him = xi - xim
            hi = xip - xi
            yim = fp[i-1]
            yi = fp[i]
            yip = fp[i+1]
            q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
            q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
            q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)
            fs[idx] = q2
            continue
        # Interior interval: evaluate both quadratics on the 4-point stencil.
        xim = xp[i-1]
        xi = xp[i]
        xip = xp[i+1]
        xipp = xp[i+2]
        him = xi - xim
        hi = xip - xi
        hip = xipp - xip
        yim = fp[i-1]
        yi = fp[i]
        yip = fp[i+1]
        yipp = fp[i+2]
        q2 = yim * ((x - xi) * (x - xip)) / (him * (him + hi))
        q2 -= yi * ((x - xim) * (x - xip)) / (him * hi)
        q2 += yip * ((x - xim) * (x - xi)) / ((him + hi) * hi)
        q3 = yi * ((x - xip) * (x - xipp)) / (hi * (hi + hip))
        q3 -= yip * ((x - xi) * (x - xipp)) / (hi * hip)
        q3 += yipp * ((x - xi) * (x - xip)) / ((hi + hip) * hip)
        if i != prevBetaIdx:
            H = him + hi + hip
            yyim = - ((2*him + hi)*H + him*(him + hi)) / (him*(him + hi)*H) * yim
            yyim += ((him + hi)*H) / (him*hi*(hi + hip)) * yi
            yyim -= (him*H) / ((him + hi)*hi*hip) * yip
            yyim += (him*(him + hi)) / ((hi + hip)*hip*H) * yipp
            yyi = - (hi*(hi + hip)) / (him*(him + hi)*H) * yim
            yyi += (hi*(hi + hip) - him*(2*hi + hip)) / (him*hi*(hi + hip)) * yi
            yyi += (him*(hi + hip)) / ((him + hi)*hi*hip) * yip
            yyi -= (him*hi) / ((hi + hip)*hip*H) * yipp
            yyip = (hi*hip) / (him*(him + hi)*H) * yim
            yyip -= (hip*(him + hi)) / (him*hi*(hi + hip)) * yi
            yyip += ((him + 2*hi)*hip - (him + hi)*hi) / ((him + hi)*hi*hip) * yip
            yyip += ((him + hi)*hi) / ((hi + hip)*hip*H) * yipp
            yyipp = - ((hi + hip)*hip) / (him*(him + hi)*H) * yim
            yyipp += (hip*H) / (him*hi*(hi + hip)) * yi
            yyipp -= ((hi + hip) * H) / ((him + hi)*hi*hip) * yip
            yyipp += ((2*hip + hi)*H + hip*(hi + hip)) / ((hi + hip)*hip*H) * yipp
            beta2 = (hi + hip)**2 * (abs(yyip - yyi) / hi - abs(yyi - yyim) / him)**2
            beta3 = (him + hi)**2 * (abs(yyipp - yyip) / hip - abs(yyip - yyi) / hi)**2
            prevBetaIdx = i
        # Linear weights gamma, turned into nonlinear weights omega via beta.
        gamma2 = - (x - xipp) / (xipp - xim)
        gamma3 = (x - xim) / (xipp - xim)
        alpha2 = gamma2 / (Eps + beta2)
        alpha3 = gamma3 / (Eps + beta3)
        omega2 = alpha2 / (alpha2 + alpha3)
        omega3 = alpha3 / (alpha2 + alpha3)
        fs[idx] = omega2 * q2 + omega3 * q3
    return fs
I’ve benchmarked the Julia version in the REPL with the following data and calls:
using BenchmarkTools

# Tabulated grid: abscissae `xx` and values `yy`, 20 samples each.
const xx = [
    -0.92187417, -0.89761267, -0.87991064, -0.86419928, -0.61695843,
    -0.43402368, -0.41066999, -0.28437279, -0.03294275,  0.06117351,
     0.10350274,  0.15120579,  0.19502651,  0.27504179,  0.30483723,
     0.31266704,  0.57397092,  0.72808421,  0.75458105,  0.89136637,
]
const yy = [
    3.26564982, 3.13231235, 3.03788469, 2.95633998, 2.07793277,
    2.07177976, 2.11368258, 2.49333035, 3.80266509, 0.36498422,
    0.61108517, 0.87643985, 1.10453135, 1.4692657,  1.58452945,
    1.61275516, 1.97720725, 1.63532625, 1.5387338,  0.90130036,
]
# 10,001 equally spaced query points from minimum(xx) to maximum(xx), endpoints included.
const xi = collect(LinRange(extrema(xx)..., 10001))

# Interpolating the arguments with `$` is the calling style BenchmarkTools recommends.
@benchmark weno4_impl($xi, $xx, $yy)
And the Python version in IPython with:
# IPython session transcript: `%timeit` below is an IPython magic, not plain Python.
# NOTE(review): assumes numpy was imported as `np` and `weno4_impl` defined earlier
# in the session; neither step is shown here.
# Grid abscissae (20 samples).
xp = np.array([-0.92187417, -0.89761267, -0.87991064, -0.86419928, -0.61695843,
-0.43402368, -0.41066999, -0.28437279, -0.03294275, 0.06117351,
0.10350274, 0.15120579, 0.19502651, 0.27504179, 0.30483723,
0.31266704, 0.57397092, 0.72808421, 0.75458105, 0.89136637])
# Grid values at xp (20 samples).
yp = np.array([3.26564982, 3.13231235, 3.03788469, 2.95633998, 2.07793277,
2.07177976, 2.11368258, 2.49333035, 3.80266509, 0.36498422,
0.61108517, 0.87643985, 1.10453135, 1.4692657 , 1.58452945,
1.61275516, 1.97720725, 1.63532625, 1.5387338 , 0.90130036])
# 10,001 equally spaced query points spanning [xp.min(), xp.max()] (linspace includes both endpoints).
x = np.linspace(xp.min(), xp.max(), 10001)
%timeit weno4_impl(x, xp, yp)
For this example with 10,001 points I get about 300 μs in Julia and 180 μs in Python.
A few details about the system:
Julia: 1.6.4
Python: 3.9.6, numpy 1.20.3, numba 0.54.1
Darwin 20.6.0; all tests were run single-threaded
Any help would be most appreciated!