Telling Julia the type of a function's output (and other adventures in type stability)

Hey all. I’m having a bit of a hard time understanding types and how to deal with them. Take this small example involving a Lux network:

using Lux, Random, JET

input = 1
n = 12
act = tanh
model = 
    Chain(Dense(input => n, act),
          Dense(n => n, act),
          Dense(n => 1), first)
rng = Random.default_rng()
p0, s0 = Lux.setup(rng,model)
x0 = rand(Float32)

# A Lux model returns (output, state); first extracts the output
u(x,p,st) = x*(model([x],p,st)|>first)

Now, if I run @code_warntype or JET.@report_opt on model(x0,p0,s0), they tell me all is fine. However, I cannot make u, as simple as it is, type stable. This is what @code_warntype u(x0,p0,s0) reports:

Arguments
  #self#::Core.Const(u)
  x::Float32
  p::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(), Tuple{}}}}
  st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Body::Any
1 ─ %1 = Base.vect(x)::Vector{Float32}
│   %2 = Main.model(%1, p, st)::Any
│   %3 = (%2 |> Main.first)::Any
│   %4 = (x * %3)::Any
└──      return %4

Thus every calculation downstream of u will also be inferred as Any, which should make the code very slow (right?)

The issue, it seems, is that Julia cannot infer the type of model's output from the types of its inputs, probably because the parameter and state objects are these large nested structures. That’s fine by me, but I do know what that type is. I could, for example, determine it by running model once on a given set of inputs and then somehow tell the compiler that this is always going to be the output’s type. However, this doesn’t seem to be possible.
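For reference, the idea in the paragraph above corresponds to a type assertion: writing `expr::T` tells the compiler that `expr` has type `T` (and throws a `TypeError` at run time if it does not), so the code after it can be inferred as `T`. The sketch below uses plain Julia rather than Lux; `g`, `h_plain` and `h_asserted` are hypothetical names, and the asserted `Float64` is an assumption you would replace with the type you actually observed.

```julia
# A non-const global: the compiler cannot know what g will refer to when h_* is called.
g = sin

# Without an assertion, the result of g(x) (and therefore of h_plain) is inferred as Any.
h_plain(x) = 2 * g(x)

# With an assertion, g(x) is checked at run time and everything after it is inferred as Float64.
h_asserted(x) = 2 * (g(x)::Float64)

println(Base.return_types(h_plain, (Float64,)))     # reports Any
println(Base.return_types(h_asserted, (Float64,)))  # reports Float64
```

Applied to the original code this would read u(x,p,st) = x*((model([x],p,st) |> first)::Float32), assuming the network really returns a Float32. Note that the call to the global itself is still dynamically dispatched; only the code after the assertion benefits.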

My questions, then, are the following:
1: How does one deal with the types of function outputs? What should I do when the compiler can’t infer the output’s type for a function I did not write?
2: Should I really be stressed over this? Does this kind of type stability actually matter for performance? Are there any quick mitigating tricks?

Non-const globals are bad for performance. Try

const model = 
    Chain(Dense(input => n, act),
          Dense(n => n, act),
          Dense(n => 1), first)
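To see the effect of `const` without Lux, the sketch below compares a function that reads an ordinary global with one that reads a `const` global, using `Base.return_types` to show what the compiler infers. The names are hypothetical stand-ins for `model`.

```julia
# An ordinary global: its binding may be reassigned at any time.
model_nonconst = cos
# A const global: the binding is fixed, so the compiler knows its type.
const model_const = cos

call_nonconst(x) = model_nonconst(x)
call_const(x) = model_const(x)

println(Base.return_types(call_nonconst, (Float64,)))  # reports Any
println(Base.return_types(call_const, (Float64,)))     # reports Float64
```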

Not really. The problem is that, since model is a non-const global variable, you can rebind it to anything else at any time, and the function u has to account for that. Therefore Julia cannot predict what calling it might return in the future, although it can certainly deduce the type of the Chain(...) object that model currently points to.
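The reassignment problem can be seen directly in plain Julia: a function that reads a non-const global sees whatever object the global currently refers to, even one of a completely different type. `current_model` and `predict` are hypothetical names for this sketch.

```julia
current_model = x -> 2x        # non-const global, initially returns a number
predict(x) = current_model(x)  # looks up current_model on every call

predict(3)                     # 6

current_model = x -> "x = $x"  # rebind the global to a function returning a String
predict(3)                     # "x = 3": the result type changed without redefining predict
```

Since nothing prevents the second assignment, the only safe return type the compiler can infer for predict is Any.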

The comment above fixes the issue. Another approach is to use a let block as

let const_model = model
    # Define u as global so that it is available outside this let scope
    global u(x,p,st) = x*(const_model([x],p,st) |> first)
end

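which binds the local name const_model to whatever object the variable model holds at the moment the let block is evaluated (no copy is made). That local binding cannot be reassigned from outside the let block.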
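A plain-Julia sketch of the same let pattern, with hypothetical names: the local name captured inside the let keeps referring to the object the global held when the block ran, so rebinding the global afterwards does not affect the function.

```julia
shape = cos                  # non-const global, playing the role of `model`

let shape_local = shape      # local binding to the object `shape` refers to right now
    # `global` makes `evaluate` available outside the let block
    global evaluate(x) = shape_local(x)
end

shape = sin                  # rebinding the global afterwards...
evaluate(0.0)                # ...does not change evaluate: this is still cos(0.0) == 1.0
```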


Oh, I see. So this principle applies even when model is a function.

Just to be clear, then: the issue here is that the name model was bound to the function returned by Chain. When declaring a function the normal way, foo(x) = x, making it const is not necessary, correct? Even though that definition also lives in the global scope?

Correct, function names are always constant and cannot be changed:

julia> my_func(x) = x
my_func (generic function with 1 method)

julia> my_func = 10
ERROR: invalid redefinition of constant my_func
Stacktrace:
 [1] top-level scope
   @ REPL[2]:1
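One way to check which names are constant bindings is `isconst(module, :name)`. In the sketch below (hypothetical names), the name introduced by a function definition is constant, while a name introduced by an ordinary assignment, like `model = Chain(...)`, is not.

```julia
double(x) = 2x       # a function definition creates a constant binding
alias = double       # an ordinary assignment creates a non-constant binding

println(isconst(@__MODULE__, :double))  # true
println(isconst(@__MODULE__, :alias))   # false
```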

I don’t know the internals of Lux, but most likely the Chain(...) call returns an object, not an actual function. But it doesn’t matter: functions are also objects in Julia, and you can bind them to different names (but you cannot change a function’s own name).

julia> f = sin
sin (generic function with 14 methods)

julia> f(π/2)
1.0

julia> f = cos
cos (generic function with 14 methods)

julia> f(π/2)
6.123233995736766e-17

julia> sin = cos
ERROR: cannot assign a value to imported variable Base.sin from module Main
Stacktrace:
 [1] top-level scope
   @ REPL[10]:1
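The "object, not a function" point can be illustrated with a callable struct (a functor): a method is defined on the struct type, so an instance can be called like a function while remaining an ordinary value. `Scale` and `scaler` are hypothetical; this sketches the general idea only and is not how Lux actually defines Chain.

```julia
# A struct holding a parameter, made callable by defining a method on its type.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

const scaler = Scale(2.0)   # const binding, so the compiler knows scaler is a Scale

scaler(3.0)                 # 6.0
scaler isa Function         # false: callable, but not a subtype of Function
sin isa Function            # true: functions are objects too
```

Either way, what matters for inference is whether the global name is const, not whether the object behind it is a Function.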

Fantastic, thank you very much.