Custom (NumPy style) broadcasting rule that avoids iterating over elements (for GPU-acceleration)

Minimal Example

As a minimal example, suppose that I have the following two structs defined,

struct Foo{T, M<:AbstractMatrix{T}}
    A::M
    B::M
end

struct Bar{T, M<:AbstractMatrix{T}}
    C::M
    foo::Foo{T, M}
end

as well as a function that operates on them,

function update(bar::Bar)
    new_C = bar.foo.A * bar.C
    new_B = bar.foo.B * new_C
    return Bar(new_C, Foo(bar.foo.A, new_B))
end

If I have a batch of bars stored as bars::Vector{<:Bar}, I can broadcast the function update over this collection using update.(bars).

This will essentially iterate over the elements of bars and perform update on each of them.

The Problem

This works fine on the CPU, but such iteration is incredibly slow on the GPU, since separate kernel launches are required for each element of the batch. Instead, we would like each of the multiplications inside update to be performed over the whole batch of matrices by a single kernel, e.g. using CUDA.CUBLAS.gemm_strided_batched.
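For reference, a single such call on the backing 3D arrays looks roughly like this (assuming I have the wrapper's signature right; 'N' means no transpose, and the array names are just for illustration):

using CUDA

# One strided-batched GEMM multiplies all slices at once:
# for each i, D[:, :, i] == A3[:, :, i] * B3[:, :, i]
A3 = CUDA.rand(Float32, 4, 4, 128)
B3 = CUDA.rand(Float32, 4, 4, 128)
D = CUDA.CUBLAS.gemm_strided_batched('N', 'N', A3, B3)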

In this case, I imagine the input to the function would be some custom type Batched{Bar{T, CuMatrix{T}}} which behaves like a Vector{Bar{T, CuMatrix{T}}} when it comes to broadcasting but is backed by three CuArray{T, 3} representing the As, Bs and Cs as contiguous memory on the GPU.
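Here is a rough, non-generic sketch of what such a container might look like for Bar specifically (the name BatchedBars and its layout are placeholders; the real type would be parameterised by the element type):

using CUDA

# Struct-of-arrays layout: each logical matrix field is a slice of a contiguous 3D CuArray
struct BatchedBars{T}
    As::CuArray{T, 3}  # As[:, :, i] is bars[i].foo.A
    Bs::CuArray{T, 3}  # Bs[:, :, i] is bars[i].foo.B
    Cs::CuArray{T, 3}  # Cs[:, :, i] is bars[i].C
end

Base.length(b::BatchedBars) = size(b.Cs, 3)

# Materialise a single element as an ordinary Bar backed by views of the big arrays
Base.getindex(b::BatchedBars, i::Int) =
    Bar(view(b.Cs, :, :, i), Foo(view(b.As, :, :, i), view(b.Bs, :, :, i)))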

I could obviously write a custom batched implementation for my specific Foo, Bar and update, but I need a solution that generalises more widely. Specifically, I would like to take advantage of Julia’s broadcasting system to create a solution that works whenever:

  • The structs only have fields that are array types or other such structs
  • The function being applied only involves getfield, linear algebra operations, and creating new structs

Computationally (though perhaps not syntactically), an ideal broadcast would end up running these steps on the 3D CuArrays, say As_data, Bs_data and Cs_data, that back the batched type:

function update_batched(bar::Batched{Bar})
    # ...
    # 'N', 'N' are the no-transpose flags the CUBLAS wrapper expects
    new_Cs_data = CUDA.CUBLAS.gemm_strided_batched('N', 'N', As_data, Cs_data)
    new_Bs_data = CUDA.CUBLAS.gemm_strided_batched('N', 'N', Bs_data, new_Cs_data)
    # return ...
end

Question

I appreciate that this is quite a large and broad question, so I’m not expecting a full solution.

Instead, I would appreciate any pointers to packages/projects that have tackled this or similar problems already.

Otherwise, some high-level suggestions as to the cleanest and least janky way of achieving this would be greatly appreciated.

In some sense this is similar to how NumPy handles broadcasting. When using, e.g. TensorFlow, I can call a function update on 3D arrays and it will act like I’m calling it on individual matrices, batching the remaining dimensions automatically in an efficient fashion. Can this behaviour be replicated in Julia in a more principled/idiomatic fashion?

Initial Thoughts

You can probably ignore this section, but I thought it was worth also mentioning somewhere in this post the ideas that I have.

Idea 1: Fake AbstractMatrix type (essentially what NumPy does)

Suppose that, on top of our batched types, we have a type BroadcastedMatrix that is backed by a 3D CuArray but is a subtype of AbstractMatrix. This is philosophically wrong, but it allows us to instantiate, e.g., a Foo(::BroadcastedMatrix, ::BroadcastedMatrix).

I’ve created some functions that recurse through a Batched{ET} and convert it to an ET in which all of the matrices inside are BroadcastedMatrixs. You can then just call update (without broadcasting) on this ET and define dispatch rules for how, e.g., *(::BroadcastedMatrix, ::BroadcastedMatrix) should behave.
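To make that concrete, here is a minimal sketch of the kind of type and dispatch rule I mean (incomplete, and again assuming the CUBLAS wrapper signature is as shown):

using CUDA

# A "matrix" that secretly carries a whole batch: data[:, :, i] is the i-th matrix
struct BroadcastedMatrix{T} <: AbstractMatrix{T}
    data::CuArray{T, 3}
end

# Pretend to be a single matrix of the per-element size
Base.size(m::BroadcastedMatrix) = (size(m.data, 1), size(m.data, 2))

# "Matrix" multiplication becomes one strided-batched GEMM over the whole batch
Base.:*(x::BroadcastedMatrix{T}, y::BroadcastedMatrix{T}) where {T} =
    BroadcastedMatrix(CUDA.CUBLAS.gemm_strided_batched('N', 'N', x.data, y.data))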

This actually works surprisingly well, but I just feel that subtyping from AbstractMatrix with something that inherently is a 3D array (even though it kind of acts like a matrix in this context) is a bad idea and will lead to undefined behaviour.

This solution feels like the closest to replicating exactly how NumPy broadcasting works.

Idea 2: IR Manipulation

It roughly feels like, when we call update. on a batched bar, we would like the broadcast to be intercepted and converted into a function that looks something like this:

function batched_update(bar::Batched{Bar{T}}) where {T}
    # Extract backing CuArrays assuming a structure like StructArrays
    A = bar.data.foo.data.A
    B = bar.data.foo.data.B
    C = bar.data.C

    # Perform operations by calling CuBLAS wrappers
    new_C = batch_call(*, A, C)
    new_B = batch_call(*, B, new_C)

    # The Batched type would use StructArrays for creating composite structs
    return Batched{Bar{...}}(
        Batched{CuMatrix{T}}(new_C),
        Batched{Foo{...}}(Batched{CuMatrix{T}}(A), Batched{CuMatrix{T}}(new_B))
    )
end

This transformation seems reasonably simple to describe and could potentially be generated through some IR manipulation of the original update function. I worry that this is massively overkill though and could lead to unexpected behaviour.
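For what it’s worth, the batch_call helper assumed in the sketch above could just be a small set of dispatch rules from per-matrix operations to their batched CUBLAS counterparts, e.g. (purely illustrative, and again assuming the wrapper signature):

using CUDA

# Hypothetical helper: dispatch a per-matrix operation to its batched CUBLAS
# counterpart when given the 3D backing arrays
batch_call(::typeof(*), A::CuArray{T, 3}, B::CuArray{T, 3}) where {T} =
    CUDA.CUBLAS.gemm_strided_batched('N', 'N', A, B)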


The easiest way would be to use Reactant.jl to optimize with MLIR, the same way it’s done in JAX, PyTorch, etc. Another option would be to use StructArrays.jl to convert the inner representation.
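For the StructArrays route, something like this would give you one array per field (rough sketch; the unwrap keyword tells StructArrays to also split the nested Foo):

using StructArrays

# One vector per (possibly nested) field instead of one struct per element
bars_soa = StructArray(bars; unwrap = t -> t <: Foo)

bars_soa.C      # all the C matrices
bars_soa.foo.A  # all the A matrices of the nested Foos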


Actually, it seems like it’s not able to fuse anything because of the representation, which is funny to see:

struct Foo{T1,T2}
    A::T1
    B::T2
end

struct Bar{T1,T2}
    C::T1
    foo::T2
end

function update(bar)
    A = bar.C*bar.foo.A
    B = bar.C*bar.foo.B
    return Bar(bar.C, Foo(A,B))
end
function updates(bars)
    return update.(bars)
end

A = rand(100,100,100)
B = rand(100,100,100)
C = rand(100,100,100)
bars = [Bar(C[:,:,i], Foo(A[:,:,i],B[:,:,i])) for i in 1:100]

using Reactant
Reactant.set_default_backend("gpu")
bars_r = Reactant.to_rarray(bars)
up_c = @compile updates(bars_r)
@code_xla updates(bars_r)

This leads to an insane number of calls, and with a batch size of 1000 it just takes forever to run.

The problem in your code snippet above is that you have a bunch of different allocations for all of those. If you instead constructed bars with @view, it would at least make it possible to do batched calls (though not guarantee that it’s done).
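For example (a sketch of the same construction as above, but aliasing the big arrays instead of copying slices out of them):

# Each element aliases the big arrays instead of owning its own copy of a slice
bars = [Bar(@view(C[:, :, i]), Foo(@view(A[:, :, i]), @view(B[:, :, i]))) for i in 1:100]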

Doesn’t help. Why would keeping views help? Isn’t the memory completely relocated when going to the GPU? Knowing that all of those come from a single big CPU array isn’t that valuable, I think.

No, to_rarray preserves aliasing, so it should retain the fact that there is one big GPU array with many little views.

The dot_general batch fusion hasn’t landed yet: feat: general concat reshape to batching pass by avik-pal · Pull Request #983 · EnzymeAD/Enzyme-JAX · GitHub

Thanks all for the suggestions! I wasn’t aware of Reactant but it looks very powerful.

@avikpal Can I check my understanding of this PR?

I’ve used XLA in the past through JAX to try to fuse a sequence of batched linear operations (vmaping over a function like update). In a low-dimension/large-batch (memory-bound) setting, I found that XLA produced code that was close to the theoretically optimal throughput when using only matmuls, but was far slower when other linear algebra operations such as Cholesky factorisations were introduced.

Is the point of this PR just to expose the Julia dot broadcasting syntax to XLA, rather than to make any improvement to how XLA handles such operations? I.e., would the introduction of this PR make sequential batched linear algebra any faster than when I tried XLA through JAX?

The PR (and some similar functionality we have for Cholesky, though that’s not fully merged yet either) adds general optimizations in EnzymeJAX that apply to general MLIR/XLA. They will be on by default in Reactant, and you can also use them from EnzymeJAX in JAX if you like.

I have a much simpler but related question (see Does not work Reactant arrays · Issue #52 · JuliaArrays/ArraysOfArrays.jl · GitHub).

Basically, this package implements an efficient (in both performance and memory) physical representation of ragged arrays:

julia> dump(VectorOfVectors([[1,2,3], [4,5], [7]]))
VectorOfVectors{Int64, Vector{Int64}, Vector{Int64}, Vector{Tuple{}}}
  data: Array{Int64}((6,)) [1, 2, 3, 4, 5, 7]
  elem_ptr: Array{Int64}((4,)) [1, 4, 6, 7]

but currently one can’t move them to Reactant arrays. What’s the recommended way to make that happen? And how does one “teach” Reactant (or whatever else) how to operate on these efficiently?