Type-inference fails then succeeds for exact same function

I’ve been facing multiple issues with type inference, especially when using ComponentArrays. In particular, it may succeed or fail, apparently at random, for the same function in the same program. What could be causing this? Does type inference depend on any sort of state that could be changing throughout the script?

This is easiest to reproduce using Lux Chains of Custom Layers. Take the following example, where DummyLayer is just some wrapper for an actual NN layer:

using Lux, ComponentArrays, Random, Cthulhu

# Thin wrapper around another Lux layer: the wrapped layer is stored in the
# `layer` field, and its type is carried by the type parameter `T`.
struct DummyLayer{T} <: Lux.AbstractExplicitLayer
    layer::T
end 

# Forward pass of the wrapper: call the wrapped layer with (X, ps, st), keep
# only the first element of its result, and return an empty tuple in the
# second (state) position.
function (layer::DummyLayer)(X,ps,st) 
    # The second returned value (presumably the inner layer's updated state)
    # is discarded here.
    fx, _ = layer.layer(X,ps,st)
    return fx, ()
end

# Delegate parameter and state initialization to the wrapped layer.
# Note: `rng::AbstractRNG` in the calls on the right-hand side is a type
# assertion on the value (already guaranteed by the method signature), not a
# declaration.
Lux.initialparameters(rng::AbstractRNG,layer::DummyLayer) = Lux.initialparameters(rng::AbstractRNG,layer.layer)
Lux.initialstates(rng::AbstractRNG,layer::DummyLayer) = Lux.initialstates(rng::AbstractRNG,layer.layer)

# Input size, a random input vector of that length, and the hidden width.
n = 4
X = rand(n)
width = 12 

# Inner network: two Dense layers mapping n => width => n.
C = Chain(Dense(n=>width),Dense(width=>n))

# Wrap the inner network and chain the same wrapper instance twice.
D = DummyLayer(C)
DD = Chain(D,D)
# Initialize parameters and states, then convert the parameters to a ComponentArray.
ps, st = Lux.setup(Random.default_rng(),DD)
psc = ps |> ComponentArray

# Run the model once, then inspect the inferred types of the same call.
DD(X,psc,st)
@code_warntype DD(X,psc,st)

The last line outputs the following (note the Tuple{Any,...} in %2):

code_warntype
MethodInstance for (::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing})(::Vector{Float64}, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}, ::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}})
  from (c::Chain)(x, ps, st::NamedTuple) @ Lux C:\Users\55619\.julia\packages\Lux\3Kn7l\src\layers\containers.jl:478
Arguments
  c::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing} 
  x::Vector{Float64}
  ps::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}
  st::Core.Const((layer_1 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))
Body::Tuple{Any, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
1 ─ %1 = Base.getproperty(c, :layers)::NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}
│   %2 = Lux.applychain(%1, x, ps, st)::Tuple{Any, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
└──      return %2

However, if I now redeclare the exact same function in the subsequent lines, the compiler is able to infer all the types and the call becomes type-stable (!):

# Verbatim redefinition of the DummyLayer call method defined earlier in this post.
# NOTE(review): redefining a method invalidates code compiled against the old
# definition, so the next call is inferred afresh; presumably this is why the
# inference result changes. Confirm against the linked issues.
function (layer::DummyLayer)(X,ps,st) 
    fx, _ = layer.layer(X,ps,st)
    return fx, ()
end

# Repeat the same call and type inspection, now after the method redefinition.
DD(X,psc,st)
@code_warntype DD(X,psc,st)
code_warntype
MethodInstance for (::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing})(::Vector{Float64}, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}, ::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}})
  from (c::Chain)(x, ps, st::NamedTuple) @ Lux C:\Users\55619\.julia\packages\Lux\3Kn7l\src\layers\containers.jl:478
Arguments
  c::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing} 
  x::Vector{Float64}
  ps::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}
  st::Core.Const((layer_1 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))
Body::Tuple{Vector{Float64}, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
1 ─ %1 = Base.getproperty(c, :layers)::NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}
│   %2 = Lux.applychain(%1, x, ps, st)::Tuple{Vector{Float64}, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
└──      return %2

(Julia 1.9, VSCode, Windows)

Why should inference work the second time but not the first? Did I perhaps give the compiler more information by redeclaring the function after running it?

Likewise, is it possible to bypass inference when it stops working? This has been a huge source of headaches; it would be nice to be able to just tell the compiler what the output of a function is going to be.

1 Like

I think this is a known issue; these two GitHub issues seem like they could be relevant:

2 Likes

Type assertions do exactly this when applied to the return value of a function. For example, write return expr::T instead of return expr.

1 Like