Batch/no batch in Flux?

How can I write a generic layer that works with and without batches?

The simplest solution is to treat data without a batch dimension as having a batch size of 1.
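Here is a minimal plain-Julia sketch of this idea that needs only Base. W, b, dense and as_batch are hypothetical names. dense is just a stand-in for a Flux Dense layer, so the example runs without Flux installed. Adding a trailing singleton dimension turns one sample into a batch of one, and then the same code path handles both cases.

```julia
# Plain-Julia sketch (no Flux needed). `W`, `b`, `dense` and `as_batch` are
# hypothetical stand-ins for illustration, not Flux API.

# A dense-style affine map: W * x .+ b works for a matrix whose columns are samples.
W = randn(Float32, 100, 100)
b = zeros(Float32, 100)
dense(x) = W * x .+ b

# Treat an unbatched sample as a batch of size 1 by adding a trailing
# singleton dimension: a length-100 vector becomes a 100×1 matrix.
as_batch(x::AbstractVector) = reshape(x, :, 1)
as_batch(x::AbstractArray) = x   # already batched, leave unchanged

x  = rand(Float32, 100)          # one sample
xb = rand(Float32, 100, 32)      # batch of 32 samples, one per column

size(as_batch(x))                # (100, 1)
size(dense(as_batch(x)))         # (100, 1)
size(dense(as_batch(xb)))        # (100, 32)
```

With this approach, downstream code only ever sees arrays whose last dimension is the batch dimension.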

How can I make sure that reshaping operations work transparently? For example, suppose I connect a Dense layer to a MyLayer that accepts a 10x10 2D array as input:
m = Chain(Dense(100, 100), x -> reshape(x, (10, 10)), MyLayer((10, 10)))
m should work both on a single vector of 100 points and on a batch of such vectors.

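If you treat the “unbatched” data as having a batch size of 1, then m = Chain(Dense(100, 100), x -> reshape(x, (10, 10, :)), MyLayer((10, 10))) should just work (assuming MyLayer works for batched data). The colon lets reshape infer the trailing batch dimension, so a single sample becomes a 10×10×1 array and a batch of N samples becomes a 10×10×N array.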
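The shapes in this pipeline can be checked with a plain-Julia sketch that does not need Flux. mychain is a hypothetical stand-in for Chain, and dense stands in for Dense(100, 100). This MyLayer is an invented example that scales each 10×10 slice by a weight matrix. Only the reshape(x, (10, 10, :)) step comes from the answer.

```julia
# Plain-Julia sketch (no Flux needed). `mychain`, `dense` and this `MyLayer`
# are hypothetical stand-ins for illustration, not Flux's definitions.

# Apply functions left to right, like a chain of layers.
mychain(fs...) = x -> foldl((acc, f) -> f(acc), fs; init = x)

W = randn(Float32, 100, 100)
b = zeros(Float32, 100)
dense(x) = W * x .+ b            # 100 -> 100, works on vectors or 100×N matrices

# Hypothetical layer that expects batched input of shape 10×10×N.
struct MyLayer
    w::Matrix{Float32}
end
MyLayer(dims::Tuple{Int,Int}) = MyLayer(randn(Float32, dims...))
# Scale every 10×10 slice elementwise by w (broadcasts over the batch dim).
(l::MyLayer)(x::AbstractArray{<:Real,3}) = l.w .* x

# The `:` lets reshape infer the batch dimension from the total length.
m = mychain(dense, x -> reshape(x, (10, 10, :)), MyLayer((10, 10)))

size(m(rand(Float32, 100)))      # (10, 10, 1): one sample as a batch of 1
size(m(rand(Float32, 100, 32)))  # (10, 10, 32): a batch of 32

# By contrast, reshape(rand(Float32, 100, 32), (10, 10)) throws a
# DimensionMismatch: 3200 elements cannot fill a 10×10 array.
```

Because the reshape always produces a 3-D array, MyLayer only needs a method for 10×10×N input. For a batch of one, dropdims(y; dims = 3) removes the singleton batch dimension if a plain 10×10 result is wanted at the end.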
