Good question. I wasn’t around when Flux was originally written, but I assume it has to do with how matrix multiplication works in Julia and Flux’s overall goal of looking like ordinary Julia code.
Imagine you have a chain
julia> model = Chain(Dense(2, 10), RNN(10, 1))
Chain is just a convenient way to pass information forward from one layer to the next. The Dense layer is very simple under the hood; for an input vector x it performs this operation (you can find this for 0.11.6 here):
W, b, σ = a.W, a.b, a.σ
σ.(W*x .+ b)
Now if we use 3d-arrays, this doesn’t work, because W (which is a matrix) and x (which is now a 3d-array) don’t have default behavior for multiplication. In fact, there are many ways to define this product, so it doesn’t make sense for * to have a default. To make this work with 3d-arrays in the most “julian” way (at least by my reckoning), we would dispatch on the number of dimensions the array has and specialize. Every layer (e.g. dense, recurrent, conv, etc.) would then need a special method for these 3d-arrays, adding to the code complexity of Flux. Each of those methods would look almost exactly (I think) like what I wrote above, but now repeated for each layer in the network individually. In fact, what I wrote above should be just as fast as anything you would write here, because for-loops are fast in Julia, so there isn’t much optimization to be had, and as far as I know there aren’t any specialized BLAS operations for this.
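To make the "dispatch on the number of dimensions" idea concrete, here is a minimal sketch in plain Julia. TinyDense is a hypothetical stand-in and not Flux's actual Dense type, and the features × batch × time layout of the 3d-array is an assumption made for illustration.

```julia
# Hypothetical stand-in for a dense layer; this is NOT Flux's Dense type.
struct TinyDense{M<:AbstractMatrix, V<:AbstractVector, F}
    W::M
    b::V
    σ::F
end

# Vector or matrix input (features × batch): ordinary matrix multiplication.
(a::TinyDense)(x::AbstractVecOrMat) = a.σ.(a.W * x .+ a.b)

# Extra method for 3-d input. Assumed layout: features × batch × time.
function (a::TinyDense)(x::AbstractArray{<:Any,3})
    out = similar(x, size(a.W, 1), size(x, 2), size(x, 3))
    for t in axes(x, 3)
        # Reuse the matrix method for each time slice.
        out[:, :, t] = a(x[:, :, t])
    end
    return out
end

layer = TinyDense(randn(10, 2), zeros(10), tanh)
x3 = randn(2, 5, 7)   # 2 features, batch of 5, 7 time steps
y3 = layer(x3)
size(y3)              # (10, 5, 7)
```

The second method is nothing more than a loop over time slices, and a similar method would have to be written for every layer type.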
Instead, it makes more sense to use the broadcast or map functionality. This cuts down on specialized code in Flux and makes running a model feel very julian/intuitive. There is no confusion about what the broadcast is doing, whereas in the 3d-array version I have to check which dimension is the time index.
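As a rough illustration of the broadcast/map style, the sketch below uses a hypothetical stateless function named model in place of a real Flux model, and represents a sequence as a vector holding one features × batch matrix per time step. Both choices are assumptions for the example.

```julia
# Hypothetical stand-in for a model: any callable on a features × batch matrix.
W, b = randn(10, 2), zeros(10)
model(x) = tanh.(W * x .+ b)

# A sequence: one (features × batch) matrix per time step.
xs = [randn(2, 5) for t in 1:7]

ys = model.(xs)           # broadcast: apply the model to each time step
ys_map = map(model, xs)   # same result for this stateless stand-in

length(ys)    # 7
size(ys[1])   # (10, 5)
```

Here the time index is simply the index into xs, so there is nothing to look up about the array layout.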
It may seem more convenient to have support for 3d-arrays because this is what TensorFlow/PyTorch do. But they have to do this for efficiency: Python for-loops can be slow, so they just do the for-loop in C instead. This isn’t a bad thing, and it is actually quite clever. It is just that Julia doesn’t have that constraint, so why design as if it did?