Hello,
Flux.jl-based models are usually evaluated in a batch-wise fashion for performance reasons, i.e.

```julia
xs_batch = rand(10, 64); # 64 ten-dimensional features
ys_batch = model(xs_batch)
@assert size(ys_batch) == (1, 64)
```
For a clean pseudo-code implementation, I would like to define `xs` as an `AbstractVector` of samples that lie contiguously in memory, and then write `model.(xs)`. The overloaded implementation should automatically batch the samples, compute the results batch-by-batch, and then unbatch them again to return an `AbstractVector{Output_t}`.
```julia
# I want this
xs = [rand(10) for _ in 1:64]; # or better `rand(10, 64) |> unbatch`
ys = model.(xs) # automatically batches
@assert length(ys) == 64
```
I was able to redefine the broadcasting rule for my model (see below), and `model.(xs)` works. But when I define another function `f(x) = 2*model(x)` and then compute `f.(xs)`, the redefined broadcasting rule no longer gets used:

```julia
f(x) = 2*model(x)
f.(xs) # this doesn't use my custom batching anymore :(
```
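My understanding of why this happens (a Flux-free toy that reproduces it; `Model`, `m`, `hits`, and `h` are made-up names, not from my actual code): a dotted call lowers to `Broadcast.broadcasted(f, xs)`, which dispatches on the type of the *outermost* callable, so a method defined for `typeof(model)` is never consulted when `model` only appears inside `f`'s body:

```julia
# Flux-free toy: broadcast dispatch only sees the outermost callable.
struct Model end
(::Model)(x) = 2x
const m = Model()

const hits = Ref(0)  # counts how often the custom rule fires
function Broadcast.broadcasted(model::Model, xs::AbstractVector)
    hits[] += 1
    map(model, xs)  # materialize is a no-op on a plain Array
end

m.([1, 2, 3])    # custom rule fires: hits[] == 1
h(x) = m(x) + 1
h.([1, 2, 3])    # lowers to broadcasted(h, xs): custom rule NOT used
@assert hits[] == 1
```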
Here’s an MWE including the redefined broadcasting for my model:

```julia
using Flux
using MLUtils
using Transducers

Input_t = Vector{Float32}
Output_t = Vector{Float32}

data = rand(Float32, 3, 20) |> unbatch; # makes a vector of slices
model = Chain(Dense(3, 100), Dense(100, 100), Dense(100, 2));

Broadcast.broadcasted(model_::typeof(model), xs::AbstractVector{Input_t}) = begin
    batchsize = 16
    model_batched(xs::AbstractVector{Input_t})::AbstractVector{Output_t} = begin
        @info length(xs)
        return model_(xs |> batch) |> unbatch
    end
    return collect(xs |> Partition(batchsize; flush=true) |> MapCat(model_batched)) # or tcollect/dcollect
end

model.(data) # works (!), prints the @info
f(x) = 2*model(x)
f.(data) # also "works", but doesn't use the provided broadcasting implementation
```
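I know I could sidestep broadcasting entirely with an explicit higher-order helper (a sketch below; `batchmap` and `double_all` are made-up names, not an existing API), which composes with any wrapper function since nothing depends on `typeof(f)` — but I'd prefer to keep the dotted syntax:

```julia
# Explicit alternative: apply a batch-level function chunk-by-chunk
# and concatenate the per-batch results. `f_batched` maps a vector
# of samples to a vector of outputs (like `model_batched` in the MWE).
function batchmap(f_batched, xs::AbstractVector; batchsize=16)
    chunks = Iterators.partition(xs, batchsize)
    reduce(vcat, [f_batched(collect(chunk)) for chunk in chunks])
end

# usage with a stand-in for `model_batched`:
double_all(batch) = [2x for x in batch]
batchmap(double_all, collect(1:5); batchsize=2)  # == [2, 4, 6, 8, 10]
```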
Any pointers on how to make `f` also use the broadcasting rule would be appreciated!
Cheers