My problem is that I want to solve many small linear systems of size 3x3, and to do it in parallel, using LU factorization.
using LinearAlgebra
using TimerOutputs
using Random
reset_timer!()
function f_lu(N)
Random.seed!(1)
A = randn(3,3,N)
b = randn(3,N)
@timeit "Inverse LAPACK lu" begin
for n = 1:N
@views LAPACK.gesv!(A[:,:,n], b[:,n])
end
end
b
end
function f_lu_inbounds(N)
Random.seed!(1)
A = randn(3,3,N)
b = randn(3,N)
@timeit "Inverse LAPACK lu inbounds" begin
@inbounds for n = 1:N
@views LAPACK.gesv!(A[:,:,n], b[:,n])
end
end
b
end
function f_lu_threads(N)
Random.seed!(1)
A = randn(3,3,N)
b = randn(3,N)
@timeit "Inverse LAPACK lu threads" begin
Threads.@threads for n = 1:N
@views LAPACK.gesv!(A[:,:,n], b[:,n])
end
end
b
end
function f_lu_julia(N)
Random.seed!(1)
A = randn(3,3,N)
b = randn(3,N)
x = zeros(3,N)
@timeit "Inverse lu julia" begin
for n = 1:N
@views ldiv!(x[:,n], lu!(A[:,:,n]), b[:,n])
end
end
x
end
N = 4000*10^4
fs = [f_lu_inbounds, f_lu, f_lu_julia, f_lu_threads]
res = map(f ->f(N), fs)
for n = 2:length(fs)
@assert res[n] β res[1]
end
print_timer()
yields
julia> include("test_lapack.jl")
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Time Allocations
ββββββββββββββββββββββ βββββββββββββββββββββββ
Tot / % measured: 291s / 37.1% 45.3GiB / 60.5%
Section ncalls time %tot avg alloc %tot avg
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Inverse lu julia 1 34.8s 32.2% 34.8s 14.9GiB 54.3% 14.9GiB
Inverse LAPACK lu threads 1 28.1s 26.0% 28.1s 4.17GiB 15.2% 4.17GiB
Inverse LAPACK lu 1 25.4s 23.5% 25.4s 4.17GiB 15.2% 4.17GiB
Inverse LAPACK lu inbounds 1 19.7s 18.3% 19.7s 4.17GiB 15.2% 4.17GiB
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
julia> Threads.nthreads()
2
So first I was surprised by calling Lapack routing directly saved 30% in time. Secondly, even though I used two threads to solve the same problem, the computational time increased instead of decreasing, so probably I have misunderstood how threads work.
How would I do to speed up this type of calculations?