Why does Flux.Chain create Exprs to get type stability?

Hi. While trying to understand how Flux.Chain evaluates its inputs, I came across the following two definitions:

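```julia
@generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
  # The body runs at compile time, where N (the number of layers) is known from the tuple type.
  symbols = vcat(:x, [gensym() for _ in 1:N])
  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
  Expr(:block, calls...)
end
```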


```julia
function _applychain(layers::AbstractVector, x)  # type-unstable path, helps compile times
  for f in layers
    x = f(x)
  end
  x
end
```

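To make the unrolling concrete, here is a sketch that builds the same `Expr` outside the generator so it can be printed, and then reuses it in a `@generated` function. `chain_body`, `unrolled_chain`, and the toy layers are hypothetical names chosen for illustration, not Flux API.

```julia
# Hypothetical helper mirroring the body of the @generated method above:
# given the number of layers N, return the Expr that gets spliced in.
function chain_body(N::Int)
    symbols = vcat(:x, [gensym() for _ in 1:N])
    calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
    return Expr(:block, calls...)
end

body = chain_body(3)
println(body)   # three straight-line assignments, no loop

# The same pattern as a @generated function (renamed to avoid clashing with Flux).
@generated function unrolled_chain(layers::Tuple{Vararg{Any,N}}, x) where {N}
    chain_body(N)
end

# Toy "layers": plain functions standing in for Flux layers.
layers = (x -> 2x, x -> x + 1, sqrt)
y = unrolled_chain(layers, 4.0)   # sqrt(2 * 4.0 + 1)
```

The printed block is straight-line code with one assignment per layer, so each `layers[i](...)` call is compiled against a concrete layer type and there is no single loop variable whose type has to cover every layer.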
The second one looks type-unstable, though I can't point to a concrete reason why, and I'm at a loss about the first one. Can anyone explain why the first one is better (unrolling?), and why the second one, specifically, is type-unstable?


The code can be found here: https://github.com/FluxML/Flux.jl/blob/ccf87bb13f01ff0c0a1a08d900f0a9d8c9122da3/src/layers/basic.jl#L53

This was added in this PR because it was faster than the previous version, both on the first run and once everything was compiled (times here). I think both versions were type-stable, but this one is more Zygote-friendly.
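For comparison, recursing over the tuple is another common way to write a type-stable chain. The sketch below is only an illustration, not necessarily the version the PR replaced; `recursive_chain`, `unrolled_chain`, and the toy layers are hypothetical names, and `Base.return_types` is used to check what inference concludes for each.

```julia
# Hypothetical recursive alternative: peel one layer off the tuple per call.
# Each call sees a shorter, concretely typed tuple, so inference can follow it.
recursive_chain(::Tuple{}, x) = x
recursive_chain(layers::Tuple, x) = recursive_chain(Base.tail(layers), first(layers)(x))

# Unrolled version, same idea as Flux's @generated _applychain (renamed here).
@generated function unrolled_chain(layers::Tuple{Vararg{Any,N}}, x) where {N}
    symbols = vcat(:x, [gensym() for _ in 1:N])
    calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
    Expr(:block, calls...)
end

layers = (x -> 2x, x -> x + 1, sqrt)
T = typeof(layers)

# Both should infer a concrete return type for a Float64 input.
rt_recursive = Base.return_types(recursive_chain, (T, Float64))
rt_unrolled = Base.return_types(unrolled_chain, (T, Float64))
println(rt_recursive, " ", rt_unrolled)
```

Both approaches let the compiler see each layer's concrete type; the differences that motivated the PR (compile time, run time, and how Zygote handles each form) are not measured by this snippet.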

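A Vector of different layers has to have an abstract eltype, which will certainly make that method type-unstable: the compiler cannot know which concrete layer `f` is at each step, so every `f(x)` is dispatched at runtime and the type of `x` is inferred as `Any`. This is normally a bad thing, but sometimes it saves a lot of compile time at a small runtime cost, because the method is specialized only on the vector's type rather than on every combination of layer types.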
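A minimal sketch of that trade-off, assuming plain functions as stand-in layers (`loop_chain`, `v1`, and `v2` are hypothetical names): the vector's eltype is abstract and inference gives up on the result type, but two different vectors of layers share a single type, whereas two different tuples of layers do not.

```julia
# Stand-in "layers": each anonymous function has its own concrete type,
# so a vector literal has to widen its eltype to the abstract type Function.
v1 = [x -> 2x, x -> x + 1, sqrt]
v2 = [abs, x -> x - 1]

# Same loop as the AbstractVector method of _applychain (renamed here).
function loop_chain(layers::AbstractVector, x)
    for f in layers
        x = f(x)   # f is only known to be a Function, so this is a dynamic call
    end
    x
end

println(eltype(v1), " ", isconcretetype(eltype(v1)))
println(Base.return_types(loop_chain, (typeof(v1), Float64)))  # inferred result is Any

# Both vectors have the same type, so they share one compiled method...
println(typeof(v1) == typeof(v2))
# ...while tuples of different layers are different types, each needing its own specialization.
println(typeof((x -> 2x, sqrt)) == typeof((abs,)))

y = loop_chain(v1, 4.0)
```

`@code_warntype loop_chain(v1, 4.0)` (from InteractiveUtils) highlights the `Any` inside the loop. The result is still correct; only the per-call dispatch costs extra at runtime.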
