I’ve been facing multiple issues with type inference, especially when using ComponentArrays. In particular, inference may succeed or fail, apparently at random, for the same function in the same program. What could be causing this? Does type inference depend on any sort of state that could be changing throughout the script?
This is easiest to reproduce using Lux `Chain`s of custom layers. Take the following example, where `DummyLayer` is just a wrapper around an actual NN layer:
using Lux, ComponentArrays, Random, Cthulhu
"""
    DummyLayer(layer)

Wrapper type holding a single inner layer in its `layer` field. The field's type is a
type parameter, so the concrete type of the wrapped layer is part of the wrapper's type.
"""
struct DummyLayer{L} <: Lux.AbstractExplicitLayer
    layer::L
end
# Call method of the wrapper: invoke the wrapped layer with the same `(X, ps, st)`
# arguments, keep the first element of its result as `fx`, and discard the second.
# Returns `fx` together with an empty tuple `()`, so the second output of the wrapped
# layer is not passed on to the caller.
function (layer::DummyLayer)(X,ps,st)
fx, _ = layer.layer(X,ps,st)
return fx, ()
end
# Parameter initialisation for the wrapper is forwarded to the wrapped layer. `rng` is
# already constrained by the method signature, so it is passed on without re-asserting
# its type at the call site.
Lux.initialparameters(rng::AbstractRNG, layer::DummyLayer) =
    Lux.initialparameters(rng, layer.layer)
# State initialisation for the wrapper is forwarded to the wrapped layer. `rng` is
# already constrained by the method signature, so it is passed on without re-asserting
# its type at the call site.
Lux.initialstates(rng::AbstractRNG, layer::DummyLayer) =
    Lux.initialstates(rng, layer.layer)
# Sizes used below: input length and the width of the intermediate Dense layer.
n = 4
width = 12
X = rand(n)

# Inner two-layer network, its wrapper, and a chain applying that wrapper twice.
C = Chain(Dense(n => width), Dense(width => n))
D = DummyLayer(C)
DD = Chain(D, D)

# Set up parameters and states, then convert the parameters to a ComponentArray.
ps, st = Lux.setup(Random.default_rng(), DD)
psc = ComponentArray(ps)

# Run the model once, then inspect the inferred types for the same call.
DD(X, psc, st)
@code_warntype DD(X, psc, st)
The last line outputs the following (note the `Tuple{Any, ...}` in `%2`):
code_warntype
MethodInstance for (::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing})(::Vector{Float64}, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}, ::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}})
from (c::Chain)(x, ps, st::NamedTuple) @ Lux C:\Users\55619\.julia\packages\Lux\3Kn7l\src\layers\containers.jl:478
Arguments
c::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing}
x::Vector{Float64}
ps::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}
st::Core.Const((layer_1 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))
Body::Tuple{Any, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
1 ─ %1 = Base.getproperty(c, :layers)::NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}
│ %2 = Lux.applychain(%1, x, ps, st)::Tuple{Any, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
└── return %2
However, if I now redeclare the exact same function in the subsequent lines, the compiler is able to infer all the types and the call becomes type-stable (!):
# Deliberate redefinition of the `DummyLayer` call method; the code is the same as the
# definition given earlier in the script. It calls the wrapped layer with `(X, ps, st)`,
# keeps the first element of the result as `fx`, and returns `fx` with an empty tuple.
function (layer::DummyLayer)(X,ps,st)
fx, _ = layer.layer(X,ps,st)
return fx, ()
end
# Same call and same inspection as before, now run after the method redefinition.
DD(X, psc, st)
@code_warntype DD(X, psc, st)
code_warntype
MethodInstance for (::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing})(::Vector{Float64}, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}, ::NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}})
from (c::Chain)(x, ps, st::NamedTuple) @ Lux C:\Users\55619\.julia\packages\Lux\3Kn7l\src\layers\containers.jl:478
Arguments
c::Chain{NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}, Nothing}
x::Vector{Float64}
ps::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:112, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))), layer_2 = ViewAxis(113:224, Axis(layer_1 = ViewAxis(1:60, Axis(weight = ViewAxis(1:48, ShapedAxis((12, 4), NamedTuple())), bias = ViewAxis(49:60, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(61:112, Axis(weight = ViewAxis(1:48, ShapedAxis((4, 12), NamedTuple())), bias = ViewAxis(49:52, ShapedAxis((4, 1), NamedTuple())))))))}}}
st::Core.Const((layer_1 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))
Body::Tuple{Vector{Float64}, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
1 ─ %1 = Base.getproperty(c, :layers)::NamedTuple{(:layer_1, :layer_2), Tuple{DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}, DummyLayer{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}}}}
│ %2 = Lux.applychain(%1, x, ps, st)::Tuple{Vector{Float64}, NamedTuple{(:layer_1, :layer_2), Tuple{Tuple{}, Tuple{}}}}
└── return %2
(Julia 1.9, VSCode, Windows)
Why should inference work the second time but not the first? Did I perhaps give the compiler more information by redeclaring the function after running it?
Likewise, is it possible to bypass inference when it stops working? This has been a huge source of headaches; it would be nice to be able to just tell the compiler what the output of a function is going to be.