## Minimal Example
As a minimal example, suppose that I have the following two structs defined,

```julia
struct Foo{T, M<:AbstractMatrix{T}}
    A::M
    B::M
end

struct Bar{T, M<:AbstractMatrix{T}}
    C::M
    foo::Foo{T, M}
end
```
as well as a function that operates on them,

```julia
function update(bar::Bar)
    new_C = bar.foo.A * bar.C
    new_B = bar.foo.B * new_C
    return Bar(new_C, Foo(bar.foo.A, new_B))
end
```
If I have a batch of bars stored as `bars::Vector{<:Bar}`, I can broadcast the function `update` over this collection using `update.(bars)`. This will essentially iterate over the elements of `bars` and call `update` on each of them.
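For example, a minimal usage sketch on the CPU (sizes chosen arbitrarily):

```julia
# Build a batch of 100 independent Bars and update them all element-wise
bars = [Bar(rand(4, 4), Foo(rand(4, 4), rand(4, 4))) for _ in 1:100]
new_bars = update.(bars)  # one call to `update` per element
```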
## The Problem
This works fine on the CPU, but such iteration is incredibly slow on the GPU, since separate kernel launches are required for every element of the batch. Instead, we would like to use individual kernels to perform each of the multiplications inside `update` over the whole batch of matrices, e.g., using `CUDA.CUBLAS.gemm_strided_batched`.

In this case, I imagine the input to the function would be some custom type `Batched{Bar{T, CuMatrix{T}}}` which behaves like a `Vector{Bar{T, CuMatrix{T}}}` when it comes to broadcasting, but is backed by three `CuArray{T, 3}`s representing the `A`s, `B`s and `C`s as contiguous memory on the GPU.
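For concreteness, here is a sketch of the sort of layout I have in mind, specialised to `Bar` (the names and structure are made up; the real type would be a generic `Batched{ET}`):

```julia
using CUDA

# Hypothetical concrete layout for a batch of Bars: each leaf field of
# the element type is stacked along a trailing batch dimension
struct BatchedBar{T}
    As_data::CuArray{T, 3}  # As_data[:, :, i] is bars[i].foo.A
    Bs_data::CuArray{T, 3}  # Bs_data[:, :, i] is bars[i].foo.B
    Cs_data::CuArray{T, 3}  # Cs_data[:, :, i] is bars[i].C
end

# Indexing recovers an ordinary Bar backed by views into the 3D arrays
Base.getindex(b::BatchedBar, i::Int) =
    Bar(view(b.Cs_data, :, :, i),
        Foo(view(b.As_data, :, :, i), view(b.Bs_data, :, :, i)))
```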
I could obviously write a custom batched implementation for my specific `Foo`, `Bar` and `update`, but I need a solution that generalises more widely. Specifically, I would like to take advantage of Julia's broadcasting system to create a solution that works whenever:
- The structs only have fields that are array types or other such structs
- The function being applied only involves `getfield`, linear algebra operations, and the construction of new structs
Computationally (though perhaps not syntactically), an ideal broadcast would end up running these steps on the 3D CuArrays, say `As_data`, `Bs_data` and `Cs_data`, that back the batched type:

```julia
function update_batched(bar::Batched{Bar})
    # ...
    new_Cs_data = CUDA.CUBLAS.gemm_strided_batched('N', 'N', As_data, Cs_data)
    new_Bs_data = CUDA.CUBLAS.gemm_strided_batched('N', 'N', Bs_data, new_Cs_data)
    # return ...
end
```
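(Here `gemm_strided_batched('N', 'N', A, B)` computes `A[:, :, i] * B[:, :, i]` for every slice `i` in a single CUBLAS call; the `'N'` flags indicate that neither argument is transposed.)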
## Question
I appreciate that this is quite a large and broad question, so I'm not expecting a full solution. Instead, I would appreciate any pointers to packages/projects that have tackled this or similar problems already. Otherwise, high-level suggestions on the cleanest and least janky way of achieving this would be greatly appreciated.

In some sense this is similar to how NumPy handles broadcasting. When using, e.g., TensorFlow, I can call a function `update` on 3D arrays and it will act as if I am calling it on individual matrices, batching over the remaining dimensions automatically and efficiently. Can this behaviour be replicated in Julia in a more principled/idiomatic fashion?
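For the multiplication step in isolation, `NNlib.batched_mul` already gives this slice-wise behaviour in Julia, which is the kind of semantics I have in mind:

```julia
using NNlib: batched_mul

A = rand(4, 4, 100)  # a batch of 100 4×4 matrices
B = rand(4, 4, 100)

# Computes A[:, :, i] * B[:, :, i] for each i; for CuArrays this should
# lower to a strided-batched GEMM
C = batched_mul(A, B)
```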
## Initial Thoughts
You can probably ignore this section, but I thought it was worth also mentioning in this post some of the ideas that I have had.
### Idea 1: Fake AbstractMatrix type (essentially what NumPy does)
Suppose that, on top of our batched types, we have a type `BroadcastedMatrix` that is backed by a 3D `CuArray` but is a subtype of `AbstractMatrix`. This is philosophically wrong, but it allows us to instantiate, e.g., a `Foo(::BroadcastedMatrix, ::BroadcastedMatrix)`.

I've created some functions that recurse through a `Batched{ET}` and convert it to an `ET` in which all of the matrices are `BroadcastedMatrix`s. You can then just call `update` (without broadcasting) on this `ET` and define dispatch rules for how, e.g., `*(::BroadcastedMatrix, ::BroadcastedMatrix)` should behave.

This actually works surprisingly well, but I just feel that subtyping `AbstractMatrix` with something that is inherently a 3D array (even though it kind of acts like a matrix in this context) is a bad idea and will lead to undefined behaviour. That said, this solution feels like the closest to replicating exactly how NumPy broadcasting works.
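Roughly, the sketch I am working with looks like this (details simplified; `batched_mul` stands in for the CUBLAS call):

```julia
using NNlib: batched_mul

# Philosophically a 3D array, but pretends to be a matrix so that it can
# be stored in Foo/Bar and flow through `update` unchanged
struct BroadcastedMatrix{T, A<:AbstractArray{T, 3}} <: AbstractMatrix{T}
    data::A
end

# Expose only the leading two dimensions as the "matrix" shape
Base.size(m::BroadcastedMatrix) = size(m.data)[1:2]

# Rewrite matrix multiplication as a batched multiply over the last dim
Base.:*(x::BroadcastedMatrix, y::BroadcastedMatrix) =
    BroadcastedMatrix(batched_mul(x.data, y.data))
```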
### Idea 2: IR Manipulation
It roughly feels like, when we call `update.` on a batched `bar`, we would like the broadcast to be intercepted and converted into a function that looks something like this:
```julia
function batched_update(bar::Batched{Bar{T}}) where {T}
    # Extract the backing CuArrays, assuming a StructArrays-like layout
    A = bar.data.foo.data.A
    B = bar.data.foo.data.B
    C = bar.data.C
    # Perform the operations by calling CuBLAS wrappers
    new_C = batch_call(*, A, C)
    new_B = batch_call(*, B, new_C)
    # The Batched type would use something like StructArrays to represent
    # batches of composite structs
    return Batched{Bar{...}}(
        Batched{CuMatrix{T}}(new_C),
        Batched{Foo{...}}(Batched{CuMatrix{T}}(A), Batched{CuMatrix{T}}(new_B)),
    )
end
```
This transformation seems reasonably simple to describe and could potentially be generated through some IR manipulation of the original `update` function. I worry that this is massively overkill, though, and could lead to unexpected behaviour.
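A lighter-weight variant of the same idea might use contextual dispatch (e.g., Cassette.jl's `overdub`) rather than raw IR manipulation. A toy sketch, which only handles the multiplication step and assumes the raw 3D arrays are threaded through `update` directly (the struct constructors would need similar treatment, since their type bounds reject 3D arrays):

```julia
using Cassette
using NNlib: batched_mul

Cassette.@context BatchCtx

# Inside BatchCtx, every `*` on a pair of 3D arrays becomes a batched multiply
Cassette.overdub(::BatchCtx, ::typeof(*),
                 x::AbstractArray{<:Any, 3}, y::AbstractArray{<:Any, 3}) =
    batched_mul(x, y)

# Hypothetical entry point, assuming `bar3d` carries the raw 3D arrays:
# Cassette.overdub(BatchCtx(), update, bar3d)
```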