Hello,
Flux.jl-based models are usually evaluated batch-wise for performance reasons, e.g.
xs_batch = rand(10, 64); # a batch of 64 ten-dimensional samples
ys_batch = model(xs_batch)
@assert size(ys_batch) == (1, 64)
For a clean pseudo-code-style implementation, I would like to define xs as an AbstractVector of samples that lie contiguously in memory, and then simply write model.(xs). The overloaded implementation should automatically batch the samples, compute the results batch by batch, and then unbatch them again to return an AbstractVector{Output_t}.
# I want this
xs = [rand(10) for _ in 1:64]; # or better `rand(10, 64) |> unbatch`
ys = model.(xs) # automatically batches
@assert length(ys) == 64
I was able to overload the broadcasting rule for my model (see below), and model.(xs) works.
But when I define another function f(x) = 2*model(x) and then compute f.(xs), my custom broadcasting rule no longer gets used…
f(x) = 2*model(x)
f.(xs) # this doesn't use my custom batching anymore :(
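As far as I understand, the reason is that dot-call syntax dispatches on the outermost function: f.(xs) lowers (roughly) to the call below, so the method I defined for typeof(model) is never considered.
# What `f.(xs)` expands to -- dispatch happens on `typeof(f)`,
# and the `model` hidden inside `f` is invisible to it:
Base.materialize(Base.broadcasted(f, xs))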
Here’s an MWE including the overloaded broadcasting for my model.
using Flux
using MLUtils
using Transducers
const Input_t = Vector{Float32}
const Output_t = Vector{Float32}

data = rand(Float32, 3, 20) |> unbatch;  # vector of 20 per-sample Vector{Float32}s
model = Chain(Dense(3, 100), Dense(100, 100), Dense(100, 2));

Broadcast.broadcasted(model_::typeof(model), xs::AbstractVector{Input_t}) = begin
    batchsize = 16
    # Evaluate one chunk: stack the samples into a matrix, run the model,
    # and split the result back into per-sample vectors.
    model_batched(xs::AbstractVector{Input_t})::AbstractVector{Output_t} = begin
        @info length(xs)
        return model_(xs |> batch) |> unbatch
    end
    return collect(xs |> Partition(batchsize; flush=true) |> MapCat(model_batched)) # or tcollect/dcollect
end
model.(data) # works (!), prints the @info
f(x) = 2*model(x)
f.(data) # also "works", but doesn't use the provided broadcasting implementation
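A quick sanity check (assuming I'm reading the dispatch correctly) confirms which method gets picked in each case:
# The overload only matches when the broadcasted function is `model` itself:
which(Broadcast.broadcasted, Tuple{typeof(model), typeof(data)})  # -> my overload above
which(Broadcast.broadcasted, Tuple{typeof(f), typeof(data)})      # -> Base's generic fallback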
Any pointers on how to make f use the broadcasting rule as well would be appreciated!
Cheers