Hello,
My friend is solving a (3+1)-dimensional nonlinear PDE with a two-step operator-splitting scheme. His code is in MATLAB, and I suggested he try Julia, since it gives you more of a toolkit for optimising code.
The first step is a classical fully implicit finite difference scheme in one of the spatial dimensions. The second step uses an explicit scheme for the Hopf-Lax formula in 2D (over a 1D control), whose operator is, as expected, nonlinear; this step involves the remaining two spatial dimensions. Together, these two steps make up one splitting time step, so we loop over the time steps with a for loop.
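As a rough illustration of how one splitting time step is organised, here is a minimal sketch on a small real 3D array. The specifics are assumptions, not the actual scheme: the implicit step is taken to be backward-Euler diffusion along dimension 2 with zero Dirichlet boundaries, and `explicit_step!` is a hypothetical pointwise nonlinear update standing in for the Hopf-Lax step, which is not reproduced here. The tridiagonal system is factorized once, outside the time loop.

```julia
using LinearAlgebra

# Hypothetical implicit step: backward Euler for u_t = u_xx along dimension 2,
# i.e. solve (I - r*D2) * u_new = u_old on every line u[i, :, k].
function implicit_step!(u, F)
    for k in axes(u, 3), i in axes(u, 1)
        line = view(u, i, :, k)
        line .= F \ line          # F is the factorization, reused at every time step
    end
    return u
end

# Hypothetical stand-in for the explicit Hopf-Lax step over the other two dimensions:
# a simple pointwise nonlinear update, NOT the actual operator.
explicit_step!(u, dt) = (u .-= dt .* u .^ 2; u)

n1, n2, n3 = 4, 20, 3
dt = 1e-3
r = dt * (n2 + 1)^2               # dt / dx^2 for grid spacing dx = 1/(n2 + 1)
# I - r*D2 with zero Dirichlet boundaries: 1 + 2r on the diagonal, -r off it
Aimp = Tridiagonal(fill(-r, n2 - 1), fill(1 + 2r, n2), fill(-r, n2 - 1))
F = lu(Aimp)                      # factorize once, outside the time loop

u = rand(n1, n2, n3)
nsteps = 10
for n in 1:nsteps                 # one iteration = one splitting time step
    implicit_step!(u, F)          # step 1: fully implicit, along dimension 2
    explicit_step!(u, dt)         # step 2: explicit and nonlinear (placeholder)
end
```

The order of the two calls inside the loop is what makes this a sequential (two-step) splitting; swapping in the real Hopf-Lax update only changes `explicit_step!`.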
In each of the two steps, we have to apply the relevant operation to a 4D array, which on its own takes a lot of memory. Moreover, in the second step we also store a 3D re-indexing array that is associated with the nonlinear operator and is needed to align the pre-computed 3D subarrays of the 4D array of interest.
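The post does not spell out how the re-indexing array is applied, so the following is only a sketch under an assumption: that `idx[i, j, k]` stores the index along dimension 1 to read for output position `(i, j, k)` of one time slice. Taking the slice with `view` avoids copying it out of the 4D array, and the gather writes into a buffer allocated once. `realign!`, `idx`, and `out` are hypothetical names.

```julia
# Hypothetical gather for one time slice: out[i, j, k] = ϕm[idx[i, j, k], j, k].
# What idx encodes is an assumption for illustration, not taken from the original code.
function realign!(out, ϕm, idx)
    for k in axes(out, 3), j in axes(out, 2), i in axes(out, 1)   # i innermost: column-major order
        out[i, j, k] = ϕm[idx[i, j, k], j, k]
    end
    return out
end

n1, n2, n3, nt = 5, 4, 3, 2
ϕ = rand(ComplexF64, n1, n2, n3, nt)
m = 2
ϕm = view(ϕ, :, :, :, m)      # refers to the slice in place instead of copying it
idx = [n1 + 1 - i for i in 1:n1, j in 1:n2, k in 1:n3]   # example map: reverse dimension 1
out = similar(ϕm)             # allocate once and reuse across time steps
realign!(out, ϕm, idx)
```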
After profiling the code, I found that the main bottlenecks were the linear solves with `\` and raising each element of a large complex array to a power. I would appreciate any tips for speeding these up. I've looked into StaticArrays, but it seems they are mostly useful for small arrays.
Here's an MWE (minimal in terms of code, but accurate in terms of the dimensions of the problem):
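```julia
using LinearAlgebra

M = 100
G = 50
J = 50
H = 300

# This was faster without sparsity of A, not sure why
A = Matrix(Tridiagonal(rand(G-1, G-1)))
B = rand(Complex{Float64}, 2*J+1, G-1, H+1, M+1)
ϕ = rand(Complex{Float64}, 2*J+1, G+1, H+1, M+1)

# Implicit step: for each time slice, solve A * x = ϕ + B along dimension 2,
# marching backwards from time index M+1 down to 1.
for mm = 1:M
    # Bring dimension 2 to the front and flatten the rest, so each column is one right-hand side.
    tempmatr = A \ (
        reshape(permutedims(ϕ[:, 2:G, :, M + 2 - mm], (2, 1, 3)), G - 1, :) +
        reshape(permutedims(B[:, :, :, M + 2 - mm], (2, 1, 3)), G - 1, :)
    )
    # Undo the reshape and permutation, and store the result in the previous time slice.
    ϕ[:, 2:G, :, M + 1 - mm] = permutedims(
        reshape(tempmatr, G - 1, 2 * J + 1, H + 1),
        (2, 1, 3)
    )
end
```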
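A variant of the loop above that may be worth benchmarking, assuming `A` is the same at every time step: factorize it once with `lu`, and solve along dimension 2 one `k`-slice at a time through views and preallocated buffers, rather than building full permuted and reshaped copies of each time slice in every iteration. `solve_dim2!`, `rhs`, `buf`, and `ϕ0` are hypothetical names, `A` is made diagonally dominant so the small check is well conditioned, and the sizes are shrunk so it runs quickly. This sketch does not establish whether the variant is faster at the real problem size.

```julia
using LinearAlgebra

# out[i, :, k] = A \ rhs[i, :, k] for all i and k, given a precomputed factorization F of A.
# Each k-slice is transposed into the scratch matrix buf (size (G-1) × (2J+1)),
# solved in place for all its columns at once, and transposed back into out.
function solve_dim2!(out, F, rhs, buf)
    for k in axes(rhs, 3)
        permutedims!(buf, view(rhs, :, :, k), (2, 1))
        ldiv!(F, buf)
        permutedims!(view(out, :, :, k), buf, (2, 1))
    end
    return out
end

M, G, J, H = 4, 8, 3, 5                                        # small sizes for a quick check
A = Matrix(Tridiagonal(rand(G-2), 4 .+ rand(G-1), rand(G-2)))  # diagonally dominant test matrix
B = rand(ComplexF64, 2J+1, G-1, H+1, M+1)
ϕ = rand(ComplexF64, 2J+1, G+1, H+1, M+1)
ϕ0 = copy(ϕ)                                                   # untouched copy, for comparison

F = lu(complex(A))                 # factorize once, outside the time loop
rhs = similar(ϕ, 2J+1, G-1, H+1)   # preallocated right-hand sides for one time slice
buf = similar(ϕ, G-1, 2J+1)        # preallocated scratch for one k-slice

for mm in 1:M
    t = M + 2 - mm
    rhs .= view(ϕ, :, 2:G, :, t) .+ view(B, :, :, :, t)
    solve_dim2!(view(ϕ, :, 2:G, :, t - 1), F, rhs, buf)
end
```

Converting `A` to complex before factorizing keeps the factorization and the complex right-hand sides in one element type, and each `ldiv!` call handles all `2J+1` right-hand sides of a slice at once.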
For completeness, here is the other bottleneck:
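```julia
γ = 0.2
# prodqS must be real-valued: convert to Array{Float64, 3} throws an InexactError
# for entries with a nonzero imaginary part.
prodqS = rand(Float64, 2*J+1, G+1, H+1)
theta = -γ^(1/γ) * ϕ.^(1/γ) + repeat(convert(Array{Float64, 3}, prodqS), outer=(1, 1, 1, M + 1))
```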
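For the second bottleneck, here is a sketch that computes the same `theta` with a single fused broadcast into a preallocated array. Broadcasting the 3D `prodqS` against the 4D `ϕ` extends it along the trailing dimension automatically, so the `repeat` copy is not needed, and `.=` writes the result in place. This assumes `prodqS` is real, as the `convert` to `Float64` implies; `c`, `p`, and the reduced sizes are illustrative. It removes temporaries but does not change the cost of the complex power itself.

```julia
γ = 0.2
J, G, H, M = 3, 6, 4, 2                        # small sizes for a quick check
ϕ = rand(ComplexF64, 2J+1, G+1, H+1, M+1)
prodqS = rand(Float64, 2J+1, G+1, H+1)         # assumed real-valued

c = γ^(1/γ)                                    # scalar coefficient, computed once
p = 1/γ
theta = similar(ϕ)                             # allocate once, reuse on later calls
# One fused broadcast: prodqS (3D) is extended along dimension 4 automatically,
# so no repeated copy of it is built, and .= writes straight into theta.
theta .= (-c) .* ϕ .^ p .+ prodqS
```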