Batch/batch in Flux?

I am building my own ODE layer and would like to make sure it works with batches as well as with individual vectors.
The input and output of my layer are N×M 2D arrays, so I will also need to reshape in order to couple it to, e.g., a Dense layer.

I am not sure how to make this work for batches.

The v0.1 Flux documentation (strangely, this is what a search for “Flux julia batch” returns first) hints at a Batch interface with a very useful Flux.Batch struct:

julia> xs = Batch([[1,2,3], [4,5,6]])
2-element Batch of Vector{Int64}:
 [1,2,3]
 [4,5,6]

I did not find this in the latest source, but utils.jl has a function batch that just stacks the inputs into a single array, with the batch index as the last (rightmost) dimension.

julia> xs = Flux.batch([[1,2,3], [4,5,6]])
3×2 Array{Int64,2}:
 1  4
 2  5
 3  6

This function just returns an Array{Int64,2} so my layer can’t identify this as a batch.

  • Has Batch been replaced by batch? Is batch the correct way to do batching?

  • How can I write a generic layer that works with and without batches?

  • How can I make sure that reshaping operations work transparently, i.e. if I couple a Dense layer to a MyLayer that accepts a 10x10 2D array as an input:
    m=Chain(Dense(100,100), x->reshape(x, (10,10)), MyLayer((10,10)))?
    m should work both on an Array with 100 points and on a batch of such vectors.

How can I write a generic layer that works with and without batches?

The simplest solution is to treat data without a batch dimension as a batch of size 1.
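
As a rough illustration of that idea, here is a minimal sketch in plain Julia (no Flux needed). The helper name as_batch and its sample_ndims argument are made up for this example and are not part of Flux; it assumes the batch index is the last dimension, as in the Flux.batch output above.

```julia
# Hypothetical helper (not part of Flux): view unbatched input as a batch of one.
# `sample_ndims` is the number of dimensions of a single sample.
function as_batch(x::AbstractArray, sample_ndims::Integer)
    # A single sample: append a trailing batch dimension of size 1.
    ndims(x) == sample_ndims && return reshape(x, size(x)..., 1)
    # Already batched: pass through unchanged.
    ndims(x) == sample_ndims + 1 && return x
    throw(DimensionMismatch(
        "expected $sample_ndims or $(sample_ndims + 1) dimensions, got $(ndims(x))"))
end

v = rand(Float32, 100)        # one sample
X = rand(Float32, 100, 32)    # 32 samples, batch along the last dimension

size(as_batch(v, 1))  # (100, 1)
size(as_batch(X, 1))  # (100, 32)
```

A layer written only for the batched case could call something like this on its input first, so the rest of its code never has to distinguish the two cases.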

How can I make sure that reshaping operations work transparently, i.e. if I couple a Dense layer to a MyLayer that accepts a 10x10 2D array as an input:
m=Chain(Dense(100,100), x->reshape(x, (10,10)), MyLayer((10,10))) ?
m should work both on an Array with 100 points and on a batch of such vectors.

If you treat the “unbatched” data as having a batch size of 1, then m=Chain(Dense(100,100), x->reshape(x, (10,10, :)), MyLayer((10,10))) should just work (assuming MyLayer works for batched data).
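
To see what the reshape with `:` does, here is a small sketch using only base Julia. The random arrays are stand-ins for the output of Dense(100,100), and to_grid is just a name chosen for this example.

```julia
# Stand-ins for the output of Dense(100, 100): features × batch
one_sample = rand(Float32, 100, 1)
batch      = rand(Float32, 100, 32)

# `:` lets Julia infer the batch size from the number of elements
to_grid(x) = reshape(x, (10, 10, :))

size(to_grid(one_sample))  # (10, 10, 1)
size(to_grid(batch))       # (10, 10, 32)

# A plain 100-element vector also ends up as a batch of one
size(to_grid(rand(Float32, 100)))  # (10, 10, 1)

# Column-major layout: sample k of the batch becomes slice [:, :, k]
to_grid(batch)[:, :, 3] == reshape(batch[:, 3], 10, 10)  # true
```

Because Julia arrays are column-major, each column of the input maps onto exactly one 10×10 slice, so samples are not mixed by the reshape.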

Thank you! I will change the code to work with batch size one.