Help needed: Building an intuitive understanding of closures, scoping and type-stability

Even after quite some time, I still struggle with understanding how to create type-stable wrappers around simple functions. I’ve read the docs, I know that closures in Julia use reference semantics instead of value semantics (and that Jeff regrets this choice), but I just can’t wrap my head around it.

A couple of silly examples to try and illustrate my confusion. It makes intuitive sense that the type of a global variable may change under the compiler’s nose, so we have to set it as a constant to have a type-stable closure:

using InteractiveUtils  # provides @code_warntype outside the REPL

f(x, y) = x + y
@code_warntype f(2, 2)          # Type-stable
k = 2
f_k(x) = f(x, k)
@code_warntype f_k(2)           # Not type-stable
const k_const = k
f_k_const(x) = f(x, k_const)
@code_warntype f_k_const(2)     # Type-stable

However, sometimes this does not seem to be enough:

using Lux, Random

C = Dense(4 => 4, relu)                      # Simple neural network, with parameters and states
ps, st = Lux.setup(Random.default_rng(), C)
x = rand(4)

@code_warntype C(x, ps, st)                  # Type-stable
f(x, p) = C(x, p, st)                        # Don't want to think about states, give it a default value
@code_warntype f(x, ps)                      # Type-unstable
const st_const = st                          # Maybe if we make st a const, it will be type-stable?
g(x, p) = C(x, p, st_const)
@code_warntype g(x, ps)                      # No :(

The only way I’m able to make this work is by introducing local bindings with a let block:

i = let C = C, st = st
    (x, p) -> C(x, p, st)
end
@code_warntype i(x, ps)  # Type-stable

This is a silly amount of code for such a simple task, and it’s quite burdensome to write this down every time I want to build a closure.

Can anyone share a mental model that explains how these things work? It’s been hard for me to “think like a compiler”.

As far as I can tell, the let statements kind of enforce value semantics by means of copies, which is what I want basically 100% of the time. Can’t I achieve the same thing without the boilerplate?

Any insights would be highly appreciated.


Thank you for posting an executable code example.

Your second example has a variant of the same issue you resolved in the first example: the variable C is not constant.

For type inference, the compiler (and we, when reasoning about the code) needs to prove that, given a function and a set of argument types, the types of all the variables and the return type can be determined. Variables in the global scope frustrate type inference because the bindings they refer to can change at any time. Importantly, they can also change type, which would, in turn, affect type inference. A local scope created by a function or a let block prevents the binding, and therefore its type, from being changed from outside of the block. Because the compiler can then make proper assumptions about the types of the variables, it can do proper type inference.
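A small, Lux-free sketch of the point above, assuming Julia 1.x. Base.return_types reports what inference concludes for a method given argument types; add_global and add_local are hypothetical names used only for illustration.

```julia
# Non-constant global: the binding (and its type) may change at any time.
k = 2
add_global(x) = x + k

# Local binding: only this function body can assign `k_local`.
function add_local(x)
    k_local = 2
    return x + k_local
end

# Inference cannot assume anything about the global `k`...
rt_global = only(Base.return_types(add_global, (Int,)))  # Any
# ...but it can fully infer the version that only uses a local.
rt_local = only(Base.return_types(add_local, (Int,)))    # Int64
```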

Because of these issues with global scope, we strongly encourage people to work in local scopes rather than the global scope. Generally, this means putting most of your code into functions that do not use globals. For a situation like this, consider a factory function that constructs everything locally and returns the closure, such as the following.
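A Lux-free sketch of the same idea, generalized so the captured objects are passed in as arguments instead of being constructed inside. Affine and make_closure are hypothetical names; Affine only mimics the (x, ps, st) -> (y, st) calling convention of a Lux layer. Because the captures are function arguments that are never reassigned, the returned closure's type records their concrete types, so calls stay inferrable even when the closure is bound to a non-const global.

```julia
# Hypothetical stand-in for a Lux layer: called as layer(x, ps, st) -> (y, st).
struct Affine{T}
    w::T
end
(a::Affine)(x, ps, st) = (a.w .* x .+ ps, st)

# Function barrier: inside make_closure, `layer` and `st` are local
# arguments with concrete types, and they are never reassigned.
make_closure(layer, st) = (x, ps) -> layer(x, ps, st)

layer = Affine(2.0)
st = NamedTuple()
fx = make_closure(layer, st)  # non-const global, but its *type* carries the captures

y, st_out = fx([1.0, 2.0], 1.0)                             # y == [3.0, 5.0]
rt = only(Base.return_types(fx, (Vector{Float64}, Float64)))  # a concrete Tuple type
```

With Lux, the analogous call would be make_closure(C, st), which builds the same closure as the let block in the question without repeating the boilerplate at every use.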

julia> function make_h()
           C = Dense(4=>4, relu)
           ps, st = Lux.setup(Random.default_rng(),C)
           h(x) = C(x, ps, st)
           return h
       end
make_h (generic function with 1 method)

julia> const h = make_h()
(::var"#g#7"{NamedTuple{(), Tuple{}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}}) (generic function with 1 method)

julia> @code_warntype h(x)
MethodInstance for (::var"#g#7"{NamedTuple{(), Tuple{}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}})(::Vector{Float64})
  from (::var"#g#7")(x) @ Main REPL[35]:4
  #self#::var"#g#7"{NamedTuple{(), Tuple{}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}}
Body::Tuple{Vector{Float64}, NamedTuple{(), Tuple{}}}
1 ─ %1 = Core.getfield(#self#, :C)::Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}
│   %2 = Core.getfield(#self#, :ps)::NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}
│   %3 = Core.getfield(#self#, :st)::Core.Const(NamedTuple())
│   %4 = (%1)(x, %2, %3)::Tuple{Vector{Float64}, NamedTuple{(), Tuple{}}}
└──      return %4

A function defined at global scope that reads a global variable is not generally considered a closure; many languages that support the former lack the latter. In your examples, i is the only variable bound to a closure.
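One way to see the distinction is to inspect the types involved: a closure is an instance of a compiler-generated struct whose fields are the captured variables, while a function that merely reads a global has no fields at all. The names f_k and clo below are hypothetical.

```julia
k = 2
f_k(x) = x + k               # reads the global `k`; captures nothing

clo = let k = k              # new local `k`, assigned exactly once
    x -> x + k               # anonymous function capturing the local `k`
end

names_f = fieldnames(typeof(f_k))    # ()
names_clo = fieldnames(typeof(clo))  # (:k,)
types_clo = fieldtypes(typeof(clo))  # (Int64,)
```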

This doesn’t work because C isn’t const. It didn’t matter for the @code_warntype C(... line because, at a call site, the callable and the arguments are evaluated first, and dispatch and type inference of the called method use their runtime types; none of them needs to be bound to a const variable. Inside f, however, C is read from a non-const global, so its type is unknown when f is compiled.
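The same effect in plain Julia, as a hedged sketch with hypothetical names: a non-const global holding a function is fine when you call it directly, because dispatch uses the runtime type of the callable, but not when another function reads it from global scope.

```julia
double = x -> 2x             # non-const global bound to an anonymous function
call_double(x) = double(x)   # the body reads the global `double`

# Calling the function object directly: its type is known, so inference succeeds.
rt_direct = only(Base.return_types(double, (Int,)))         # Int64
# Reading it from a global inside another function: `double` could be rebound.
rt_indirect = only(Base.return_types(call_double, (Int,)))  # Any
```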

The wider compiler is pretty complex, but you can definitely follow the type inference part, and occasionally even do better than it. Unfortunately, you don’t have direct control over the type inference process, so when that happens, the most you can do is manually annotate some variables to give the compiler more hints.

The let block doesn’t give you value semantics; the closure still captures a variable. The let st = st part only creates a new st variable that is local to the let block and never reassigned. If you had reassigned it, the compiler could no longer infer it, even if all reassignments occurred before the capture and didn’t change the type; that is one example of how you could do type inference better than the compiler. The only way to “capture by value” is to store the value directly in a struct instance, which you can make callable.
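A minimal sketch of the callable-struct approach, with hypothetical names FixedState and toy_model (the latter only mimics a Lux layer's (x, ps, st) -> (y, st) signature). The struct's type parameters record the concrete types of the stored values, which is what makes calls inferrable; with Lux you would presumably store C and st the same way, as FixedState(C, st).

```julia
# Stores the model and its state by value in concretely typed fields.
struct FixedState{M,S}
    model::M
    st::S
end

# Make instances callable: the stored state is supplied automatically.
(w::FixedState)(x, ps) = w.model(x, ps, w.st)

# Hypothetical stand-in for a layer with signature (x, ps, st) -> (y, st).
toy_model(x, ps, st) = (ps .* x, st)

w = FixedState(toy_model, (; scale = 1.0))
y, st_out = w([1.0, 2.0], 3.0)                                # y == [3.0, 6.0]
rt = only(Base.return_types(w, (Vector{Float64}, Float64)))  # a concrete Tuple type
```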

If you want to make a global variable inferrable, the easiest way is to make it const, promising it’ll never be reassigned, so it’ll always have the same type. The second easiest way (available since Julia 1.8) is annotating the variable with a concrete type, much like how you annotate struct fields for type stability. Reassignment works if the values can be automatically converted to the annotated type, which isn’t necessarily concrete; if you have some unexpected type inference problems, check whether you unintentionally annotated an abstract type. There isn’t a provided way to automatically annotate the concrete type from the right-hand value, but there’s a macro with an atypical use of the local statement that accomplishes that. Unlike a const variable, whose value can often be inlined, a typed global variable’s value is retrieved through a couple of dereferences, but that bit of overhead is well worth the type stability.
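A sketch of the type-annotated global approach (this syntax requires Julia 1.8 or later); scale, loose and the function names are hypothetical. The last lines show the pitfall mentioned above: an abstract annotation such as Real keeps the binding typed but not concrete, so inference still cannot pin down the result.

```julia
scale::Float64 = 2.0          # typed global: may be reassigned, but always holds a Float64
scaled(x) = x * scale
rt_typed = only(Base.return_types(scaled, (Float64,)))   # Float64

scale = 3                     # converted to 3.0 on assignment
# scale = "three"             # would throw: a String cannot be converted to Float64

loose::Real = 2.0             # abstract annotation: any Real may be stored here
loosely(x) = x * loose
rt_loose = only(Base.return_types(loosely, (Float64,)))  # not a concrete type
```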


Wrap it into a layer from the Experimental Features page of the LuxDL docs. See the MNIST Classification using Neural ODEs tutorial in the LuxDL docs as well.

It is marked “experimental”, but at this point it is used in quite a few downstream packages, such as DiffEqFlux and DeepEquilibriumModels, so there are no plans to suddenly change its semantics.

It is in essence the same as a const global, so it comes with the same limitation: the type of st cannot change.