Broadcasting Flux.jl model over vectors

Flux.jl-based models are usually evaluated batch-wise for performance reasons, e.g.

xs_batch = rand(10, 64);  # 64 ten-dimensional features
ys_batch = model(xs_batch)
@assert size(ys_batch) == (1, 64)

To write an implementation that reads as cleanly as pseudo-code, I would like to define xs as an AbstractVector of samples that lie contiguously in memory, and then write model.(xs). The overloaded implementation should then automatically batch the samples, compute the results batch by batch, and unbatch them again to return an AbstractVector{Output_t}.

# I want this
xs = [rand(10) for _ in 1:64];  # or better `rand(10, 64) |> unbatch`
ys = model.(xs)  # automatically batches
@assert length(ys) == 64

I was able to redefine the broadcasting rule for my model (see below), and model.(xs) works.
But when defining another function f(x) = 2*model(x) and then computing f.(xs), the redefined broadcasting rule doesn’t get used anymore…

f(x) = 2*model(x)
f.(xs)  # this doesn't use my custom batching anymore :(

Here’s a MWE including the redefined broadcasting for my model.

using Flux
using MLUtils
using Transducers

Input_t = Vector{Float32}
Output_t = Vector{Float32}
data = rand(Float32, 3, 20) |> unbatch;  # makes vector of slices
model = Chain(Dense(3, 100), Dense(100, 100), Dense(100, 2));

Broadcast.broadcasted(model_::typeof(model), xs::AbstractVector{Input_t}) = begin
    batchsize = 16
    # Run the model on one chunk: stack the samples, evaluate, split again.
    model_batched(xs::AbstractVector{Input_t})::AbstractVector{Output_t} = begin
        @info length(xs)
        return model_(xs |> batch) |> unbatch
    end
    return collect(xs |> Partition(batchsize; flush=true) |> MapCat(model_batched))  # or tcollect/dcollect
end

model.(data)  # works (!), prints the @info
f(x) = 2*model(x)
f.(data)  # also "works", but doesn't use the provided broadcasting implementation

Any pointers on how to make f also use the broadcasting rule would be appreciated!


To make f use the broadcasting rule, f would need to have the same type as the model, because your rule dispatches on typeof(model). If you want this to work for arbitrary functions, you would probably have to dispatch on any callable combined with your specific data type. This could be done, though I don’t think it is a good idea, since it might affect other code in unexpected ways.
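A minimal sketch of why the rule is skipped, using only Base Julia. The Doubler type and the rule_hits counter are hypothetical stand-ins (not part of Flux); the point is only that the custom method is selected by the type of the function being broadcast, so a wrapper function with a different type falls back to the generic element-by-element broadcast.

```julia
# Hypothetical stand-in for a model: a callable struct that doubles its input.
struct Doubler end
(::Doubler)(x::Real) = 2x

const rule_hits = Ref(0)  # counts how often the custom rule below runs

# Custom rule: only matches when the broadcast function itself is a Doubler.
function Broadcast.broadcasted(d::Doubler, xs::AbstractVector{<:Real})
    rule_hits[] += 1
    return [d(x) for x in xs]
end

d = Doubler()
h(x) = d(x) + 1           # a different function type that merely calls d

direct = d.([1, 2, 3])    # dispatches to the custom rule
hits_after_direct = rule_hits[]
wrapped = h.([1, 2, 3])   # generic broadcast: h is called once per element
hits_after_wrapped = rule_hits[]
```

Inside h, the call d(x) is an ordinary scalar call, so the custom rule never sees it.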

It seems like unbatch copies the original data, so if you want to avoid that copy you might want to write your own implementation of it.
But then you might also just create a wrapper type for the data and dispatch the broadcast on that type instead of on the model type. That seems easier if you don’t want to keep the model type constant.
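A rough sketch of the wrapper-type idea, again in plain Julia. The Samples type, its batchsize field, and toy_model are hypothetical names made up for illustration; the sketch also assumes that the broadcast function accepts a features × batch matrix (as Flux layers do) and that there is at least one sample.

```julia
# Hypothetical wrapper: a matrix with one sample per column plus a chunk size.
struct Samples{M<:AbstractMatrix}
    data::M
    batchsize::Int
end

# Any function broadcast over a Samples is applied chunk by chunk (no copy of
# the input, thanks to `view`), and the chunk results are concatenated again.
function Broadcast.broadcasted(f, s::Samples)
    n = size(s.data, 2)
    outs = [f(view(s.data, :, i:min(i + s.batchsize - 1, n))) for i in 1:s.batchsize:n]
    return Samples(reduce(hcat, outs), s.batchsize)
end

W = [1.0 0.0 1.0; 0.0 2.0 0.0]          # toy weights standing in for a model
batch_sizes = Int[]                      # records the chunk sizes the model sees
toy_model(X) = (push!(batch_sizes, size(X, 2)); W * X)
f(X) = 2 * toy_model(X)

s = Samples(reshape(collect(1.0:30.0), 3, 10), 4)
out = f.(s)                              # f is called on chunks of 4, 4 and 2 columns
```

Because the method is attached to a type defined here, it does not change how broadcasting behaves for ordinary arrays, which sidesteps the concern about affecting other code.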

Though I can’t really say I see how this becomes much cleaner than simply being explicit about having a matrix with the batch along one of its dimensions. Having to set up this broadcast dispatch in some reasonable way just seems a bit messier.

If I understand your use case correctly, this would fit very well the idea behind my library JJ.jl (inspired by APL/J). Using it, you could define your model/function for a vector input and apply it over a batch by giving it rank of one, i.e.,

using JJ

model = Dense(10, 2)
f(x::AbstractVector) = 2 .* model(x) 

batch = randn(10, 64)
rank"f 1"(batch) |> size  # would give (2, 64)

This indeed gives very clean code for more complicated models as well, e.g., see my example of a simple transformer layer.

Internally, this builds on JuliennedArrays which you could also use directly, if a more explicit implementation is desired, i.e.,

using JuliennedArrays
Align(f.(Slices(batch, 1)), 1)

where Slices corresponds to your unbatch and Align reconstructs an array from an array of arrays. In any case, some other libraries such as SplitApplyCombine or the functions eachslice and stack in Julia 1.9 (nightly) provide similar functionalities.
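For comparison, a small sketch of the same slice-apply-reassemble pattern using only Base functions. The weight matrix W is a made-up stand-in for a model, and reduce(hcat, ...) is used so the snippet also runs on Julia versions older than 1.9.

```julia
W = [1.0 0.0 2.0; 0.0 1.0 -1.0]          # hypothetical 2x3 weights standing in for a model
f(x::AbstractVector) = 2 .* (W * x)

batch = reshape(collect(1.0:12.0), 3, 4)  # 4 samples with 3 features each

ys = [f(col) for col in eachcol(batch)]   # "unbatch": eachcol yields views, not copies
out = reduce(hcat, ys)                    # "batch" again; `stack(ys)` should do the same on 1.9+
```

Note that f is still evaluated once per sample here, so this keeps the code tidy but does not give the performance benefit of mini-batching.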

Regarding why I believe this would make the code cleaner, consider testing a property of the model for each sample, e.g. checking that the prediction error is less than 0.5.
We can define a function

function test_abs_error_less_than(x, y, abs_err)
    y_pred = model(x)
    return abs(y - y_pred) <= abs_err
end

Then, I would like to execute

using Statistics  # provides mean
percent_passed = mean(test_abs_error_less_than.(xs, ys, 0.5))

This is imo extremely concise and skips any fumbling with defining some dataloader over the xs and so on.
Note that often we cannot compute model(xs_full_batch) because we run out of GPU memory.

Before playing around with this, I thought that broadcasting the test_abs_error_less_than function essentially defined a function like this

function test_abs_error_less_than.(xs, ys, abs_err)
    ys_pred = model.(xs)
    return abs.(ys .- ys_pred) .<= abs_err
end

i.e. with all functions transformed to their broadcasted equivalent. But I understand now that this is not the case? This would go against my observations on a simplified test:

using Test
g(x) = 2*sin(x)
xs = [1., 2., 3.]
@test_throws MethodError g(xs)
g.(xs)  # this works

I assume that the call g.(xs) implicitly defines a function like g.(xs) = 2 .* sin.(xs). But this seems to be different from my model example above.

Thanks, this looks very interesting! Perhaps the JuliennedArrays can take care of the extra copies when calling batch and unbatch. JJ.jl also looks very interesting, although I am not very familiar with the J language, so I will have to look into it a bit.

The eachslice and stack functions are actually a good alternative to the batch and unbatch functions from MLUtils. In my version of Julia (1.6), batch is significantly faster than stack, though. I couldn’t find any details on the changes you mentioned for Julia 1.9.

The changes for 1.9 are discussed here: Add `stack(array_of_arrays)` by mcabbott · Pull Request #43334 · JuliaLang/julia · GitHub

I guess your mental model of broadcasting might work, but it is not complete, i.e., it explains broadcasting in terms of more broadcasting. I would rather think of it as a for-loop, i.e.,

A = randn(2, 3)
g(x) = 2*sin(x)

[g(A[idx]) for idx in CartesianIndices(A)]  # roughly same as g.(A)
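To make the for-loop picture a bit more concrete, here is a small sketch of the two steps that, as far as I understand the lowering, the dot syntax expands to. Nothing inside g is rewritten; g is simply called as-is on every element.

```julia
g(x) = 2 * sin(x)
xs = [1.0, 2.0, 3.0]

lazy = Broadcast.broadcasted(g, xs)       # lazy description of the broadcast, nothing computed yet
explicit = Broadcast.materialize(lazy)    # runs the loop, calling g on each element

same_as_dot = explicit == g.(xs)
same_as_map = explicit == map(g, xs)
```

This is also where a custom Broadcast.broadcasted method hooks in: it replaces the first step, but only for the function type it was defined on.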

In any case, I had misunderstood your use case. model.(data) would handle all inputs sequentially and not take advantage of mini-batching. To that end, I would probably just define a small function instead of trying to overload the broadcasting syntax:

function minibatched(model, xs::AbstractVector{Input_t}; batchsize = 16)
    # Split into chunks, run the model on each stacked chunk, and flatten the results.
    xs |> Partition(batchsize; flush = true) |> MapCat(xs -> model(xs |> batch) |> unbatch) |> collect
end

minibatched(model, data)  # instead of model.(data)
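If you would rather avoid the Transducers and MLUtils dependencies, a similar helper can be sketched with Base functions only. The name minibatched_sketch and the toy model are hypothetical, and the sketch assumes a non-empty input and a model that maps a features × batch matrix to an outputs × batch matrix.

```julia
function minibatched_sketch(model, xs::AbstractVector{<:AbstractVector}; batchsize = 16)
    mapreduce(vcat, Iterators.partition(xs, batchsize)) do chunk
        Y = model(reduce(hcat, chunk))     # stack the chunk into a features x batch matrix
        [Y[:, j] for j in axes(Y, 2)]      # split the result back into one vector per sample
    end
end

W = [1.0 2.0 3.0; 4.0 5.0 6.0]             # toy weights standing in for a model
seen = Int[]                                # records the batch sizes the model receives
toy_model(X) = (push!(seen, size(X, 2)); W * X)

data = [fill(Float64(i), 3) for i in 1:20]
ys = minibatched_sketch(toy_model, data; batchsize = 8)
```

With 20 samples and a batch size of 8, the model is called three times, the last time on the remaining 4 samples.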

If you urgently need a more concise syntax, you could probably overload one of the currently unused symbols that get parsed as infix operators (see julia - User-defined infix operator - Stack Overflow):

⋄(m, xs) = minibatched(m, xs)
model ⋄ data