Solve linear systems inside CUDA kernel function

Hello everyone, I was wondering how to solve a small linear system Ax = b inside a CUDA kernel function, where A is an MMatrix and x and b are MVectors. I want to do this with x = A\b, but it does not work inside a kernel. Is there a solution to this problem?

If you use static arrays, it won't be an issue. This is done in DiffEqGPU.jl.

Hi Chris, unfortunately I haven’t been able to solve this problem yet. I hope this code snippet illustrates the problem:

```julia
using CUDA, StaticArrays

const N = 20

function test()
    # global thread index (not used further in this minimal example)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # each thread builds its own A and b
    a = @MMatrix rand(Float64, N, N)
    b = @MVector rand(Float64, N)
    c = MVector{N, Float64}(undef)

    # This works
    c = a * b

    # But this does not (it fails inside the kernel)
    c = a \ b

    return
end

@cuda threads=10 test()
```

Also, is it possible to solve batches of small linear systems with CUBLAS or CUSOLVER?

I assume you want to use an SMatrix and SVector, not an MMatrix and MVector.

In my case, A and b will be constructed inside the kernel, so they have to be mutable. Moreover, using SMatrix and SVector does not help in the snippet above.

That’s very unlikely to work. You cannot dynamically allocate memory inside a GPU kernel (see also this recent post: Modifying a thread-local vector within CUDA Dynamic Parallelism - #2 by vchuravy).

What should work, though, is to allocate all CuArrays outside the kernel, then, inside the kernel, convert the relevant views of your arrays into SMatrix/SVectors and do the solve on StaticArrays only. (I don't have access to a GPU at the moment to check.)

Allocating memory with MMatrix and MVector works fine; I think the problem is that the \ operation performs some allocations. I've tried implementing Gaussian elimination to solve the linear equations. It works well on the GPU, but I'm worried about its performance.
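A minimal sketch of a non-allocating solver, assuming Float64 data stored in batch arrays allocated once outside the kernel (A as N×N×nbatch, b as N×nbatch), in the spirit of the preallocation suggestion above but working on the arrays directly instead of converting to SMatrix. This is not the poster's implementation: `gauss_solve_slice!` and `solve_kernel!` are hypothetical names, partial pivoting is my own choice, and the CUDA.jl wrapper appears only as comments and has not been run on a GPU. The CPU loop at the end stands in for one thread per batch item.

```julia
using LinearAlgebra, Random

# Solve A[:, :, i] * x = B[:, i] in place for batch item i with Gaussian
# elimination and partial pivoting. On return B[:, i] holds x and A[:, :, i]
# is overwritten. Only scalar indexing and local scalars are used, so the
# function does not allocate. No check for singular matrices is made.
function gauss_solve_slice!(A::AbstractArray{T,3}, B::AbstractMatrix{T},
                            i::Integer) where {T<:AbstractFloat}
    n = size(A, 1)
    @inbounds for k in 1:n
        # Partial pivoting: pick the row r >= k with the largest |A[r, k]|.
        p = k
        maxval = abs(A[k, k, i])
        for r in k+1:n
            v = abs(A[r, k, i])
            if v > maxval
                maxval, p = v, r
            end
        end
        # Swap rows k and p (entries left of column k are already zero).
        if p != k
            for c in k:n
                A[k, c, i], A[p, c, i] = A[p, c, i], A[k, c, i]
            end
            B[k, i], B[p, i] = B[p, i], B[k, i]
        end
        # Eliminate the entries below the pivot.
        piv = A[k, k, i]
        for r in k+1:n
            f = A[r, k, i] / piv
            for c in k+1:n
                A[r, c, i] -= f * A[k, c, i]
            end
            A[r, k, i] = zero(T)
            B[r, i] -= f * B[k, i]
        end
    end
    # Back substitution on the upper-triangular system.
    @inbounds for k in n:-1:1
        s = B[k, i]
        for c in k+1:n
            s -= A[k, c, i] * B[c, i]
        end
        B[k, i] = s / A[k, k, i]
    end
    return nothing
end

# Hypothetical CUDA.jl wrapper, shown as comments only (untested here):
# A_d and B_d would be CuArrays allocated once on the host before launch.
#
# function solve_kernel!(A, B)
#     i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
#     i <= size(A, 3) && gauss_solve_slice!(A, B, i)
#     return
# end
# @cuda threads=256 blocks=cld(size(A_d, 3), 256) solve_kernel!(A_d, B_d)

# CPU check: the loop over i stands in for one GPU thread per batch item.
Random.seed!(1)
n, nbatch = 20, 4
A = rand(n, n, nbatch)
for i in 1:nbatch, d in 1:n
    A[d, d, i] += n              # diagonally dominant, hence well conditioned
end
B = rand(n, nbatch)
A0, B0 = copy(A), copy(B)        # keep originals; the solver overwrites A and B
for i in 1:nbatch
    gauss_solve_slice!(A, B, i)
end
resid = maximum(norm(A0[:, :, i] * B[:, i] - B0[:, i]) for i in 1:nbatch)
```

One design caveat, offered tentatively: with Julia's column-major layout, consecutive batch items sit N² elements apart, so neighboring threads read memory with a large stride. Putting the batch index first (nbatch×N×N) would likely give coalesced reads, at the cost of changing the indexing.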

Hi, can you try for some N < 14? IIUC, there were some allocations here: StaticArrays.jl/src/solve.jl at master · JuliaArrays/StaticArrays.jl · GitHub

We can probably try to get that dispatch set up in LinearSolve.jl, but I'm not sure, since the previous approach may have been chosen for performance reasons.

Yes, you are right! For N ≤ 14 it works well. But in my case, the typical size is N = 20 to 200. I think my implementation of Gaussian elimination will be faster than A\b if N ≤ 14.
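For a rough sense of scale over N = 20 to 200, here is plain arithmetic, assuming Float64 entries (8 bytes) and the standard leading-order operation count for dense Gaussian elimination:

$$
\text{storage per system} \approx 8N^2\ \text{bytes}: \quad N=20 \Rightarrow 3{,}200\ \text{B}, \qquad N=200 \Rightarrow 320{,}000\ \text{B} \approx 312.5\ \text{KiB}
$$

$$
\text{flops per system} \approx \tfrac{2}{3}N^3: \quad N=20 \Rightarrow \approx 5.3\times10^{3}, \qquad N=200 \Rightarrow \approx 5.3\times10^{6}
$$

At the upper end, one matrix per thread is hundreds of kilobytes, far beyond what fits in registers, so it would likely live in local or global memory. In that regime the batched CUBLAS/CUSOLVER routines asked about earlier may be a better fit than a per-thread solve.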