Hey all. I’m having a bit of a hard time understanding types and how to deal with them. Take this small example involving a Lux network:
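```julia
using Lux, Random, JET

input = 1
n = 12
act = tanh

# 1 → 12 → 12 → 1 network; the trailing `first` unwraps the 1-element output vector
model = Chain(Dense(input => n, act),
              Dense(n => n, act),
              Dense(n => 1), first)

rng = Random.default_rng()
p0, s0 = Lux.setup(rng, model)   # parameters and state (nested NamedTuples)
x0 = rand(Float32)

# Lux models return (output, state); `|> first` keeps only the output
u(x, p, st) = x * (model([x], p, st) |> first)
```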
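A possible cause worth ruling out, stated here as an assumption: `u` reads `model` as a non-`const` global, and Julia cannot infer the type of a non-`const` global binding from inside a function, because it could be reassigned to a value of any type. The sketch below reproduces that pattern without Lux; `g_model` is a hypothetical closure standing in for the network (it returns `(output, state)` like a Lux layer), and `u_global`, `u_arg`, `C_MODEL` and `u_const` are hypothetical names.

```julia
using Test  # provides @inferred

# Hypothetical stand-in for the Lux model: a callable taking (x, p, st)
# and returning (output, state), like a Lux layer does.
g_model = (x, p, st) -> (sum(p .* x), st)   # non-const global binding

# Reads the non-const global `g_model`: its type may change at any time,
# so inference gives up and the result is `Any` (same pattern as `u`).
u_global(x, p, st) = x * first(g_model([x], p, st))

# Remedy 1: pass the model in as an argument, so `u_arg` is compiled
# for the model's concrete type.
u_arg(m, x, p, st) = x * first(m([x], p, st))

# Remedy 2: bind the model to a `const` global, whose type is fixed.
const C_MODEL = (x, p, st) -> (sum(p .* x), st)
u_const(x, p, st) = x * first(C_MODEL([x], p, st))

p_demo  = Float32[1, 2, 3]
st_demo = NamedTuple()

@inferred u_arg(g_model, 1f0, p_demo, st_demo)   # passes: inferred Float32
@inferred u_const(1f0, p_demo, st_demo)          # passes: inferred Float32
# @inferred u_global(1f0, p_demo, st_demo)       # would error: inferred Any
```

Applied to the original snippet, this would mean writing `const model = Chain(...)`, or defining `u(model, x, p, st)` and passing `model` explicitly. Whether that alone makes `u` inferrable with Lux is an assumption worth confirming with `@code_warntype`.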
Now, if I run `@code_warntype` or `JET.@report_opt` on `model(x0, p0, s0)`, they tell me all is fine. However, I cannot make `u`, as simple as it is, type stable:
```
Arguments
  #self#::Core.Const(u)
  x::Float32
  p::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(), Tuple{}}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Body::Any
1 ─ %1 = Base.vect(x)::Vector{Float32}
│   %2 = Main.model(%1, p, st)::Any
│   %3 = (%2 |> Main.first)::Any
│   %4 = (x * %3)::Any
└──      return %4
```
Thus, every downstream calculation applied to `u` will also be typed as `Any`, which should make the code very slow (right?)
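On the performance side, a standard way to keep `Any` from spreading is a function barrier: move the heavy downstream work into a separate function and pass the uninferred value to it. Julia compiles that function for the concrete argument types it receives at runtime, so only the single call into it is dynamically dispatched. A minimal sketch, with a closure standing in for the uninferrable model call; `g_model`, `downstream` and `outer` are hypothetical names.

```julia
# Hypothetical stand-in for an uninferrable call: a non-const global callable
# returning (output, state), like a Lux model.
g_model = (x, p, st) -> (sum(p .* x), st)

# Heavy downstream work lives in its own function. Julia compiles a
# specialization for the concrete types it is called with (e.g. Float32).
function downstream(y, xs)
    acc = zero(y * first(xs))
    for x in xs
        acc += y * x^2
    end
    return acc
end

# Caller: `y` is inferred as `Any` here, but the cost is one dynamic
# dispatch into `downstream`, not one per loop iteration.
function outer(xs, p, st)
    y = first(g_model([first(xs)], p, st))
    return downstream(y, xs)
end

outer(Float32[1, 2, 3], Float32[1], NamedTuple())
```

How much an unstable call costs depends on how much work happens per call; measuring (for example with `BenchmarkTools.@btime`) is more reliable than guessing.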
The issue, it seems, is that Julia cannot infer the type of `model`’s output from the types of its inputs, probably because the parameters and state are these large nested NamedTuples. That’s fine by me, but I do know what that type is. I could, for example, determine it by running `model` once on a given set of inputs and then somehow tell the compiler that this will always be the output’s type. However, this doesn’t seem to be possible.
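For the "run it once, then tell the compiler" idea, Julia has type assertions: `expr::T` checks at runtime that the value has type `T` (throwing a `TypeError` otherwise), and inference treats the value as `T` from then on. A minimal sketch, assuming the output type can be captured from one sample call; `g_model`, `p_demo`, `st_demo`, `T_OUT` and `u_assert` are hypothetical names, with a closure standing in for the Lux model.

```julia
# Hypothetical stand-in for the Lux model, deliberately left as a
# non-const global so the call itself is not inferrable.
g_model = (x, p, st) -> (sum(p .* x), st)

p_demo  = Float32[0.5, -1.0, 2.0]
st_demo = NamedTuple()

# Step 1: run the model once on sample inputs to learn its output type.
const T_OUT = typeof(first(g_model([1f0], p_demo, st_demo)))   # Float32 here

# Step 2: assert that type at the call site. The call to `g_model` is still
# dynamically dispatched, but `y` and everything after it infer as T_OUT.
function u_assert(x, p, st)
    y = first(g_model([x], p, st))::T_OUT
    return x * y
end

u_assert(2f0, p_demo, st_demo)   # 2 * (0.5*2 - 1*2 + 2*2) = 6.0f0
```

If the asserted type is ever wrong, the call fails loudly instead of silently slowing down. Passing the type as an argument (a `::Type{T}` parameter) is an alternative to the global constant.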
My questions, then, are the following:
1: How does one deal with the return types of functions? What should I do when the compiler cannot infer the output type of a function I did not write?
2: Should I really be worried about this? Does this kind of type instability actually matter for performance? Are there any quick tricks to mitigate it?