Performance of a key step in Strang Splitting method

Hello,

My friend is working on a problem solving a 3+1-dimensional non-linear PDE using a two-step operator splitting scheme. His code is in MATLAB, and I suggested he try out Julia, as it gives you more of a toolkit for optimising code.

The first step is a classical fully implicit finite difference scheme acting along one of the spatial dimensions. The second step uses an explicit scheme for the Hopf-Lax formula in 2D (over 1D control), whose operator is non-linear as expected; this involves the remaining two spatial dimensions. Together, these two steps make up one splitting time step, so we have to use a for loop over them.
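Schematically, the time stepping looks like this (a toy sketch just to show the loop structure; the two stub functions are hypothetical stand-ins, not the actual sub-steps):

# Hypothetical stubs standing in for the two sub-steps described above
implicit_step!(ϕ, m) = ϕ   # implicit finite-difference solve along one spatial dimension
hopflax_step!(ϕ, m) = ϕ    # explicit 2D Hopf-Lax update over the other two dimensions

ϕ = zeros(ComplexF64, 8, 8, 8, 10)   # toy sizes
for m = 1:9
    implicit_step!(ϕ, m)
    hopflax_step!(ϕ, m)   # both sub-steps together form one splitting time step
end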

In each of the two steps, we have to apply the relevant operations over a 4D array, which on its own requires a lot of memory. Moreover, in the second step we also store a re-indexing 3D array that is related to the non-linear operator and is necessary to align the pre-computed 3D subarrays of the 4D array of interest.

After profiling the code, I found that the key steps causing the slowdown were the \ solve and raising each element of a large complex array to a power. I would appreciate any tips for speeding these up. I've looked into StaticArrays, but it seems they are mainly useful for small arrays.

Here's an MWE (minimal in terms of code, but accurate in terms of the dimensions of the problem):

using LinearAlgebra
M = 100
G = 50
J = 50
H = 300

# Keeping A dense was faster than a sparse/Tridiagonal A; not sure why
A = Matrix(Tridiagonal(rand(G-1, G-1)))
B = rand(Complex{Float64}, 2*J+1, G-1, H+1, M+1)
ϕ = rand(Complex{Float64}, 2*J+1, G+1, H+1, M+1)

for mm = 1:M
    tempmatr = A \ (
        reshape(permutedims(ϕ[:, 2:G, :, M + 2 - mm], (2, 1, 3)), G - 1, :) +
        reshape(permutedims(B[:, :, :, M + 2 - mm], (2, 1, 3)), G - 1, :)
    )
    ϕ[:, 2:G, :, M + 1 - mm] = permutedims(
        reshape(tempmatr, G - 1, 2 * J + 1, H + 1),
        (2, 1, 3)
    )
end

For completeness, here is the other bottleneck:

γ = 0.2
prodqS = rand(Complex{Float64}, 2*J+1, G+1, H+1)
theta = -γ^(1/γ) * ϕ.^(1/γ) + repeat(prodqS, outer=(1, 1, 1, M + 1))

How are you doing the benchmarking? Have you read the performance tips in the Julia manual?

In particular, everything needs to be inside a function, not at global scope.
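A tiny, self-contained illustration (nothing to do with your PDE; x here is just a made-up global):

using BenchmarkTools
x = rand(10^6)               # non-const global

f_global() = sum(abs2, x)    # reads an untyped global: type-unstable, slow
f_arg(y)   = sum(abs2, y)    # takes the array as an argument: type-stable, fast

@btime f_global()
@btime f_arg($x)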

In addition to the general advice (read the performance tips, don’t use globals in performance-critical code), I wanted to mention:

Whenever you do repeated solves like this with the same matrix, you should precompute the appropriate matrix factorization (e.g. the LU factorization for a general nonsymmetric matrix): do LU = lu(A) and then use LU \ ... instead of A \ ....
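For example, with toy sizes:

using LinearAlgebra
A = rand(100, 100)
F = lu(A)              # factor once: O(n^3)
for k = 1:1000
    b = rand(100)
    x = F \ b          # each solve is just forward/back substitution: O(n^2)
end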

(Also, as mentioned in the performance tips, it is often a good idea to preallocate array outputs. Also note that slices make copies unless you use @views, which is likewise mentioned in the performance tips. Did I mention that you should read the performance tips?)
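For example:

x = rand(10^6)
sum(x[2:end-1])          # the slice allocates a copy of the data
sum(@view x[2:end-1])    # a view reuses the original memory, no copy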

If γ = 0.2, then it will be much faster to do ϕ.^5 than ϕ.^(1/γ) (which is ϕ.^5.0), because exponentiating a complex number to an integer power uses a different, cheaper algorithm. Note also that ... + repeat(X, ...) can typically be replaced by ... .+ X using broadcasting.
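Applied to the theta line in your MWE, those two changes would look something like this (untested sketch, using your variable names):

# ϕ.^5 takes the fast integer-power path, and broadcasting prodqS across
# the fourth dimension replaces the allocating repeat(...) call entirely:
theta = -γ^(1/γ) .* ϕ.^5 .+ prodqS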

In general, the code using extensive reshape and permutedims etcetera seems very “Matlaby.” e.g. it would probably be better to rearrange your arrays so that the permutedims calls are not needed.

Putting your first block in a function, I get

julia> using BenchmarkTools

julia> @btime f1($M, $G, $J, $H, $A, $B, $ϕ)
  13.653 s (5400 allocations: 19.98 GiB)

Applying some of @stevengj’s tips,

function f2(M, G, J, H, A, B, ϕ)
    tmp = similar(ϕ, G-1, (2*J+1)*(H + 1))
    Â = factorize(A)
    for mm = 1:M
        m_idx = M + 2 - mm
        @inbounds for hh = 1:H+1
            for jj = 1:2*J+1
                for gg = 1:G-1
                    tmp[gg, jj + (hh - 1)*(2*J + 1)] = ϕ[jj, gg+1, hh, m_idx] +
                                                       B[jj, gg,   hh, m_idx]
                end
            end
        end
        ldiv!(Â, tmp)
        @inbounds for hh = 1:H+1
            for jj = 1:2*J+1
                for gg = 1:G-1
                    ϕ[jj, gg + 1, hh, m_idx] = tmp[gg, jj + (hh - 1)*(2*J + 1)]
                end
            end
        end
    end
end

…which gives

julia> @btime f2($M, $G, $J, $H, $A, $B, $ϕ)
  3.211 s (14 allocations: 22.73 MiB)

We can shave a bit more off with f3, which is f2 with the for loops multithreaded via Threads.@threads:

julia> @btime f3($M, $G, $J, $H, $A, $B, $ϕ)
  2.805 s (16464 allocations: 24.35 MiB)

Once more, with feeling (and threads/SIMD/cache locality):

using LoopVectorization

function f4(M, G, J, H, A, B, ϕ)
    # set up thread-local storage
    Â = factorize(A)
    Âs = [deepcopy(Â) for i = 1:Threads.nthreads()]
    tmp = [similar(ϕ, G-1) for i = 1:Threads.nthreads()]
    # reinterpret arrays to get around LV.jl's limitations for complex numbers
    Bf = reinterpret(reshape, Float64, B)
    ϕf = reinterpret(reshape, Float64, ϕ)
    tmpf = reinterpret.(reshape, Float64, tmp)

    @inbounds for mm = 1:M
        m_idx = M + 2 - mm
        Threads.@threads for hh = 1:H+1
            tmpl = tmp[Threads.threadid()]
            tmpfl = tmpf[Threads.threadid()]
            Âl = Âs[Threads.threadid()]
            for jj = 1:2*J + 1
                @avx for gg = 1:G-1
                    tmpfl[1, gg] = ϕf[1, jj, gg+1, hh, m_idx] +
                                   Bf[1, jj, gg,   hh, m_idx]
                    tmpfl[2, gg] = ϕf[2, jj, gg+1, hh, m_idx] +
                                   Bf[2, jj, gg,   hh, m_idx]
                end
                ldiv!(Âl, tmpl)
                @avx for gg = 1:G-1
                    ϕf[1, jj, gg + 1, hh, m_idx] = tmpfl[1, gg]
                    ϕf[2, jj, gg + 1, hh, m_idx] = tmpfl[2, gg]
                end
            end
        end
    end
end
julia> @btime f4($M, $G, $J, $H, $A, $B, $ϕ)
  335.336 ms (8911 allocations: 915.33 KiB)

Thank you so much for your assistance with this. I am once again astonished by the Julia community's ability to make good on the claim that Julia is the language of the future, and by its willingness to help. I have taken all of the suggestions so far into account (some of which I am frustrated I did not think of myself), but I am struggling slightly with the reinterpretation step in f4. The reinterpret function only takes two arguments, a type and an object, so I am confused about what reshape is doing there; indeed, this code errors in my tests for that reason. Secondly, and related: I see that the complex array is reinterpreted as Float64, but is there no need to re-reinterpret it after the calculation? Is this done somewhere I'm not seeing, or is it not necessary because the results end up the same?

Sorry, should have clarified that f4 uses Julia v1.6’s new three-arg reinterpret:

help?> reinterpret
[...]
reinterpret(reshape, T, A::AbstractArray{S}) -> B

  Change the type-interpretation of A while consuming or adding a "channel dimension."

reinterpret doesn't allocate a new array; it just tells Julia to treat a block of memory as if it contains elements of a different type. Complex{Float64} is laid out in memory as (re::Float64, im::Float64), so we can reinterpret complex numbers as pairs of floats. Changes made to the reinterpreted array are reflected in the original array because they refer to the same location in memory. This lets us get around LoopVectorization.jl's complex number limitations (mentioned on this page). For example,

julia> v = rand(Complex{Float64}, 3)
3-element Vector{ComplexF64}:
 0.49472217571377874 + 0.893983220751845im
  0.8870555454462015 + 0.4894258584197706im
   0.137878217676755 + 0.4586371384468253im

julia> vf = reinterpret(reshape, Float64, v)
2×3 reinterpret(reshape, Float64, ::Vector{ComplexF64}) with eltype Float64:
 0.494722  0.887056  0.137878
 0.893983  0.489426  0.458637

julia> vf[:, 1] .= 0
2-element view(reinterpret(reshape, Float64, ::Vector{ComplexF64}), 1:2, 1) with eltype Float64:
 0.0
 0.0

julia> v
3-element Vector{ComplexF64}:
                0.0 + 0.0im
 0.8870555454462015 + 0.4894258584197706im
  0.137878217676755 + 0.4586371384468253im

Man, I had no idea Julia was capable of this kind of magic. I'd only heard of reinterpreting from the fast inverse square root in Quake III. Just to confirm: are you getting exactly the same results for ϕ in all of these cases? I ask because I don't get the same ϕ from f1 and f2 (I've only run my reconstructions of those two, as I don't have 1.6 installed). Do you mind sharing your f1 and f3 for completeness' sake*?

Oh, and as far as Discourse goes, I'm not sure which post to mark as the "answer", as I'd like to thank everyone for helping out.

*and also because I am not so experienced and haven't seen the @threads macro used before, so I might mess it up…

Yeah, we can implement a one-for-one translation of the C++ code for the fast inverse square root in Julia.

inline constexpr float Q_rsqrt( float number ) noexcept
{
	float const x2 = number * 0.5F;
	float const threehalfs = 1.5F;
	auto i = std::bit_cast<std::uint32_t>(number);
	i  = 0x5f3759df - ( i >> 1 );
	number = std::bit_cast<float>(i);
	number  *= threehalfs - ( x2 * number * number );
	return number;
}

And the Julia translation:

function Q_rsqrt(number::Float32)
    x2 = 0.5f0*number
    threehalfs = 1.5f0
    i = reinterpret(UInt32, number)
    i = 0x5f3759df - ( i >> 1 )
    number = reinterpret(Float32, i)
    number *= threehalfs - (x2 * number * number)
    return number
end
julia> Q_rsqrt(1.2f0)
0.9128525f0

julia> 1/sqrt(1.2f0)
0.9128709f0

I had an indexing typo in f2: I was using M + 2 - mm instead of M + 1 - mm for the write-back index into ϕ. With that fixed, the results match. All versions are given here: https://gist.github.com/stillyslalom/2e7e0a5483a847ba3206f66dc3a740fa

edit: I also added f5 to the gist, which threads the ldiv! but doesn't require LoopVectorization or v1.6. It comes in at 493 ms, a significant improvement over f3.

(Is it any faster than @fastmath inv(sqrt(x)), which is much more accurate? I’m not seeing any difference on my machine, even though LLVM does not seem to emit the rsqrtss instruction. It should also be possible to invoke rsqrtss directly with llvmcall in Julia to call the llvm.x86.sse.rsqrt.ps intrinsic.)
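Something like this sketch, perhaps (x86-only and untested; a tuple of VecElements maps to an LLVM vector):

# Call the 4-wide SSE reciprocal-square-root intrinsic directly (≈12-bit accuracy)
const Float32x4 = NTuple{4,VecElement{Float32}}

rsqrt_ps(x::Float32x4) =
    ccall("llvm.x86.sse.rsqrt.ps", llvmcall, Float32x4, (Float32x4,), x)

x = Float32x4(map(VecElement, (1.2f0, 2.0f0, 3.0f0, 4.0f0)))
rsqrt_ps(x)   # each lane ≈ 1/sqrt of the corresponding input lane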

No. Most modern CPUs have had fast hardware inverse square root instructions for a while now, which has made the Quake code obsolete.
