The `Flux.stack` function is defined here: https://github.com/FluxML/Flux.jl/blob/b78a27b01c9629099adb059a98657b995760b617/src/utils.jl#L476, and it is very simple: `stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)`.
However, its implementation seems to contradict the Flux "Performance Tips" page (Performance Tips · Flux), specifically: "When doing this kind of concatenation use `reduce(hcat, xs)` rather than `hcat(xs...)`. This will avoid the splatting penalty, and will hit the optimised `reduce` method."
Is there a reason why `Flux.stack` uses splatting rather than `reduce`?
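For concreteness, here is a minimal sketch of the two forms I am comparing. It does not load Flux: `unsqueeze_sketch`, `stack_splat`, and `stack_reduce` are hypothetical names I made up for illustration, and `unsqueeze_sketch` is only my assumption of what `unsqueeze` does (insert a singleton dimension at `dim`). One thing I noticed while writing it: as far as I know, the optimised `reduce` methods in Base are specialised for `hcat` and `vcat`, so a `reduce` over a generic `cat(...; dims=dim)` closure would concatenate pairwise instead.

```julia
# Hypothetical stand-in for Flux's unsqueeze (an assumption, not the library source):
# insert a singleton dimension at position `dim`.
unsqueeze_sketch(x::AbstractArray, dim::Integer) =
    reshape(x, (size(x)[1:dim-1]..., 1, size(x)[dim:end]...))

# Splatting form, mirroring the one-line definition quoted above.
stack_splat(xs, dim) = cat(unsqueeze_sketch.(xs, dim)...; dims=dim)

# reduce-based form for an arbitrary `dim`. The closure is not `hcat` or `vcat`,
# so this folds pairwise rather than dispatching to a specialised method.
stack_reduce(xs, dim) =
    reduce((a, b) -> cat(a, b; dims=dim), unsqueeze_sketch.(xs, dim))

xs = [rand(3) for _ in 1:4]

size(stack_splat(xs, 1))    # (4, 3)
size(stack_splat(xs, 2))    # (3, 4)

# Both forms give the same array.
stack_splat(xs, 2) == stack_reduce(xs, 2)

# For vectors stacked along dim 2, the form recommended by the performance tips
# produces the same result.
stack_splat(xs, 2) == reduce(hcat, xs)
```

So for `dim == 2` on vectors the recommended `reduce(hcat, xs)` is a drop-in replacement, but I am not sure there is an equally cheap `reduce` form for a general `dim`.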