Fastest possible trilinear interpolation with `Interpolations.jl`

I am working on a machine learning application in which two instances of our CFD code run simultaneously at different resolutions and pass data between each other. For this, I need the fastest possible interpolation. At the moment I use Interpolations.jl, but its performance is insufficient for my application and the interpolations are a huge bottleneck. Because Interpolations.jl has no built-in thread support, I launch the interpolation of each field on a separate thread. I am not sure whether I am using Interpolations.jl to its full potential, as the benchmarks on the site are very fast. How can I speed up the code below, which is a minimal example of my use case? The ratio of grid points per dimension is a factor of 2 or 3, so I could potentially exploit repeating patterns.

using Interpolations
using BenchmarkTools

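function interpolate_test!(a_hi, a_lo, x_hi, x_lo)
    # Build a trilinear interpolant on the low-resolution grid (including ghost cells)
    interp_a = interpolate((x_lo, x_lo, x_lo), a_lo, (Gridded(Linear()), Gridded(Linear()), Gridded(Linear())))
    # Evaluate it on the high-resolution grid and write the result in place
    a_hi[:, :, :] .= interp_a(x_hi, x_hi, x_hi)
end
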
n_hi = 256
n_lo = 128

dx_hi = 1//n_hi
dx_lo = 1//n_lo

x_hi = 1//2*dx_hi:dx_hi:1
x_lo = -1//2*dx_lo:dx_lo:1+1//2*dx_lo

a0_lo = rand(n_lo+2, n_lo+2, n_lo+2)
a0_hi = zeros(n_hi, n_hi, n_hi)

a1_lo = rand(n_lo+2, n_lo+2, n_lo+2)
a1_hi = zeros(n_hi, n_hi, n_hi)

a2_lo = rand(n_lo+2, n_lo+2, n_lo+2)
a2_hi = zeros(n_hi, n_hi, n_hi)

a3_lo = rand(n_lo+2, n_lo+2, n_lo+2)
a3_hi = zeros(n_hi, n_hi, n_hi)

# Interpolate the four fields concurrently, one task per field
@btime begin
    @sync begin
        Threads.@spawn interpolate_test!(a0_hi, a0_lo, x_hi, x_lo)
        Threads.@spawn interpolate_test!(a1_hi, a1_lo, x_hi, x_lo)
        Threads.@spawn interpolate_test!(a2_hi, a2_lo, x_hi, x_lo)
        Threads.@spawn interpolate_test!(a3_hi, a3_lo, x_hi, x_lo)
    end
end

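To make the "repeating patterns" idea concrete, here is a minimal sketch for the factor-2 case only. It assumes the grid layout of the example above: cell-centered points and one ghost cell on each side of the low-resolution array. Under that assumption every high-resolution point sits a quarter of a coarse cell away from its nearest coarse neighbor, so the 1D weights are always 0.25 and 0.75 and no interpolant object is needed. The names `stencil2` and `upsample2_trilinear!` are hypothetical helpers, not part of Interpolations.jl, and this has not been benchmarked against it.

```julia
# Hypothetical helper: for high-resolution index i on a factor-2, cell-centered
# grid, return the lower coarse index and the weights of the two coarse neighbors.
@inline function stencil2(i::Int)
    j0 = i ÷ 2 + 1
    return isodd(i) ? (j0, 0.25, 0.75) : (j0, 0.75, 0.25)
end

# Hypothetical helper: trilinear upsampling by a factor of 2 with fixed weights.
# a_lo must carry one ghost cell per side, i.e. size(a_lo) == size(a_hi) .÷ 2 .+ 2.
function upsample2_trilinear!(a_hi::AbstractArray{<:Real,3}, a_lo::AbstractArray{<:Real,3})
    @assert all(2 .* (size(a_lo) .- 2) .== size(a_hi))
    nx, ny, nz = size(a_hi)
    Threads.@threads for k in 1:nz          # threads split the slowest index
        k0, wk0, wk1 = stencil2(k)
        for j in 1:ny
            j0, wj0, wj1 = stencil2(j)
            @inbounds for i in 1:nx
                i0, wi0, wi1 = stencil2(i)
                a_hi[i, j, k] =
                    wk0 * (wj0 * (wi0 * a_lo[i0, j0, k0] + wi1 * a_lo[i0+1, j0, k0]) +
                           wj1 * (wi0 * a_lo[i0, j0+1, k0] + wi1 * a_lo[i0+1, j0+1, k0])) +
                    wk1 * (wj0 * (wi0 * a_lo[i0, j0, k0+1] + wi1 * a_lo[i0+1, j0, k0+1]) +
                           wj1 * (wi0 * a_lo[i0, j0+1, k0+1] + wi1 * a_lo[i0+1, j0+1, k0+1]))
            end
        end
    end
    return a_hi
end

# Small check on a linear field, which trilinear interpolation reproduces exactly
n_lo, n_hi = 4, 8
x_lo = [(j - 1.5) / n_lo for j in 1:n_lo+2]   # coarse cell centers incl. ghost cells
x_hi = [(i - 0.5) / n_hi for i in 1:n_hi]     # fine cell centers
a_lo = [x + 2y + 3z for x in x_lo, y in x_lo, z in x_lo]
a_hi = zeros(n_hi, n_hi, n_hi)
upsample2_trilinear!(a_hi, a_lo)
```

A factor-3 ratio would need its own set of three repeating weight pairs; the structure of the loop stays the same.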
If you only need up- or downsampling, you can use `upsample_trilinear` from NNlib or NNlibCUDA. It works on the GPU, and the CPU code is threaded.
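A minimal sketch of what that call could look like on the CPU, assuming NNlib is installed. `upsample_trilinear` works on 5D arrays laid out as (width, height, depth, channels, batch), so a single 3D field has to be reshaped first. The array size here is illustrative, and I have not checked how NNlib's corner alignment convention maps onto the cell-centered grid with ghost cells from the question, so verify that before relying on the values.

```julia
using NNlib

n_lo = 8                                     # illustrative size, not the one from the question
a_lo = rand(Float32, n_lo, n_lo, n_lo)

# Add singleton channel and batch dimensions: (W, H, D, C, N)
x = reshape(a_lo, n_lo, n_lo, n_lo, 1, 1)

# Upsample by a factor of 2 in each spatial dimension
y = upsample_trilinear(x, (2, 2, 2))

# Drop the singleton dimensions again to get a plain 3D field
a_hi = dropdims(y; dims=(4, 5))
```

Several fields could be stacked along the channel dimension and upsampled in one call instead of spawning one task per field.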
