Why does Flux.Chain create Exprs to get type stability?

Hi. While trying to understand how Flux.Chain evaluates its inputs, I came across the following two definitions:

@generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
  symbols = vcat(:x, [gensym() for _ in 1:N])
  calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
  Expr(:block, calls...)
end

and

function _applychain(layers::AbstractVector, x)  # type-unstable path, helps compile times
  for f in layers
    x = f(x)
  end
  x
end

The second one feels type-unstable, though I can’t give a concrete reason why, and I’m at a loss about the first one. Can anyone give insight into why the first one is better (unrolling?) and why the second one specifically is type-unstable?

Thanks.

Can be found here: https://github.com/FluxML/Flux.jl/blob/ccf87bb13f01ff0c0a1a08d900f0a9d8c9122da3/src/layers/basic.jl#L53

This was added in this PR, because it was faster (both on the first run, and once everything was compiled) than the previous version – times here. I think both were type-stable, but this one is more Zygote-friendly.

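A Vector of different layers has to have an abstract eltype, so inside the loop the compiler cannot infer the concrete type of f, and therefore cannot infer the type of x after each call. That will certainly make that method type-unstable. This is normally a bad thing, but sometimes it saves a lot of compile time, for a small runtime cost.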
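To make the difference concrete, here is a minimal sketch that does not need Flux at all. Scale and Shift are hypothetical stand-in layers I made up for illustration, and applychain_loop / applychain_unrolled are just copies of the two methods from the question under different names, so treat this as an approximation of what happens inside Chain rather than Flux's actual behaviour.

```julia
# Two stand-in "layers" with different concrete types (hypothetical, not Flux layers).
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a .* x

struct Shift
    b::Float64
end
(s::Shift)(x) = x .+ s.b

# Loop version, same shape as the AbstractVector method in the question.
function applychain_loop(layers::AbstractVector, x)
    for f in layers
        x = f(x)   # the type of f is unknown here, so the new type of x is unknown too
    end
    return x
end

# Unrolled version, same shape as the @generated method in the question.
# For N = 2 the generated body is roughly: y1 = layers[1](x); y2 = layers[2](y1)
@generated function applychain_unrolled(layers::Tuple{Vararg{Any,N}}, x) where {N}
    symbols = vcat(:x, [gensym() for _ in 1:N])
    calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
    return Expr(:block, calls...)
end

layers_vec = [Scale(2.0), Shift(1.0)]   # a Vector: one eltype for all elements
layers_tup = (Scale(2.0), Shift(1.0))   # a Tuple: each element's type is part of the tuple type
x = [1.0, 2.0]

println(eltype(layers_vec))   # Any
println(typeof(layers_tup))   # Tuple{Scale, Shift}

# Ask inference what each method returns for these argument types.
rt_loop = first(Base.return_types(applychain_loop, (typeof(layers_vec), Vector{Float64})))
rt_unrolled = first(Base.return_types(applychain_unrolled, (typeof(layers_tup), Vector{Float64})))
println(isconcretetype(rt_loop))      # expected false: the result type cannot be inferred
println(isconcretetype(rt_unrolled))  # expected true: every intermediate type is known
```

Both versions compute the same values. The difference is that with the Tuple, each layers[i] is indexed by a literal constant, so the compiler knows the exact layer type at every step and can specialise the whole chain. The price is that a new method body gets compiled for every distinct tuple of layer types, which is where the compile-time cost comes from. You can also run @code_warntype on both calls to see the Any in the loop version.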
