Julia "inference lattice" vs "type lattice" (from the TPU paper)

I’m reading Keno’s and Elliot’s TPU paper now.

I came across this quote:

Of importance to the current work is that the Julia inference lattice is significantly finer than its type lattice, allowing increased inference precision crucial for the success of this work.

Can anyone explain what that means (what are the inference lattice and the type lattice in this case?), what it implies (what is the increased inference precision it refers to?), and why that’s important (to the TPU stuff, I guess?)?

:smiley: haha thanks! :slight_smile:

@Keno is the best person to answer this question, as he did the nitty-gritty inference work for that paper. I will take a shot at answering this though. I know you already know many of the things I’m about to type up, but for the benefit of others that are reading along, I’m going to explain things from the basics a little more completely than I otherwise might.

Julia lets us write extremely dynamic code (variable types can change, methods can be overwritten, you can run eval(), etc.) while still getting good performance, because the just-in-time compiler can look at the next “chunk” of code to run and turn that chunk into “static” machine code just before we run it. Much of Julia’s dynamism is based on types (e.g. which chunk of machine code gets called when I say dot(a, b) depends heavily on the types of a and b), so being able to take the state of the program at time point t and fully determine what the types of variables will be at time point t + N is very advantageous to us: it allows us to group chunks of the program together into “static subregions” that can then be lowered to machine code and run without ever needing to “come up for air”, look around at the state of the world, and then decide what to run next. In essence, the more we can prove about the program ahead of time, the larger the static sub-chunks we can run.
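To make the “which chunk gets called depends on the types” point concrete, here is a minimal sketch. The functions combine and caller are hypothetical stand-ins I made up for illustration (they are not from the paper or XLA.jl), and Base.return_types is an introspection helper whose exact output can vary between Julia versions:

```julia
# Hypothetical stand-in for dot(a, b): two methods, selected by argument types.
combine(a::Number, b::Number) = a * b
combine(a::AbstractVector, b::AbstractVector) = sum(a .* b)

# A caller that does not annotate its argument types at all.
caller(a, b) = combine(a, b) + 1

# Ask inference what `caller` returns for given argument types.
scalar_rt = first(Base.return_types(caller, (Float64, Float64)))
vector_rt = first(Base.return_types(caller, (Vector{Float64}, Vector{Float64})))

# With no type information, inference cannot pick a single `combine` method.
unknown_rt = first(Base.return_types(caller, (Any, Any)))

println("scalars: ", scalar_rt)
println("vectors: ", vector_rt)
println("unknown: ", unknown_rt)
```

When the argument types are concrete, the inferred result is concrete too and the whole call chain can be compiled as one static chunk. With Any arguments the result is not concrete, so the choice of method has to be made at run time.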

In a typical Julia program, the penalty for that kind of breaking out is minuscule. You run a chunk of lowered code, then you run a different function to figure out what chunk to run next, and you go and run that chunk. Ignoring the time we spend figuring out what to run next, the overhead is negligible. However, when running on an accelerated device, there is suddenly a whole new source of overhead. When we compile Julia programs to run on the TPU, it is critical that we maximize the amount of time the TPU is busy without breaking out to Julia, because the Julia compiler itself is running on the host CPU, and the TPU is physically distant; we are controlling it through a network connection. So context switches that would normally take nanoseconds instead take seven orders of magnitude longer, which isn’t something that makes people happy.

And so, Keno took a wrench to the compiler to convince it to try extra extra hard to infer larger chunks of a Julia program as a single static chunk. How big, you ask? Well, in the TPU paper, we are able to compile the entire VGG19 backward and forward pass into a single static chunk, so that the entire thing can run without ever once breaking out. This requires a custom version of Julia (see the kf/tpu3 branch), both to raise the limits on how much time/effort Julia will put into type inference and to tweak some internal workings of the compiler so that it does not stack-overflow when juggling so much code at once. I don’t know how many times the current master of Julia would typically “break out” of static execution in this program (with TPUs we’ve made it all-or-nothing; if you try to run something that can’t be completely inferred, XLA.jl will yell at you and you have to break your function down into smaller pieces, which isn’t very user friendly but that’s how it is at the moment) but I can tell you it’s more than zero. :wink:

And so, I believe the type lattice refers to the hierarchy and relationship of types as statically defined and declared within your program, i.e. the information that is available pre-type-inference, whereas the inference lattice refers to the hierarchy and relationship of types as known post-type-inference. Naturally the information we get from type inference increases the precision of our knowledge about what types things are, and as explained above, this is crucial to getting TPUs to be able to run significant chunks of Julia code.


Does that mean there is no memory allocation?

Can this same work be used to statically compile julia programs for distribution?

Does that mean there is no memory allocation?

While in the case of TPUs it’s true that we don’t do typical memory allocation, that is orthogonal to what I’m talking about above with respect to “static subregions” of a program. A program can be fully type-inferable (such that I know the types of every variable throughout the entire program, and thus know exactly what machine code the entire program should be lowered to) and yet still allocate memory. The one does not impact the other. Similarly, I can have a program that allocates no memory, and yet is not type-inferable. Example:

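function randtype()
    # The return type depends on a runtime random value
    if randn() > 0.5
        return "hello"
    else
        return 1.0
    end
end
# y is a Vector{String}, Vector{Float64} or Vector{Any}, depending on the draws
y = [randtype(), randtype()]
# which sum() method applies is only known once y exists
x = sum(y)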

In the above program, there are not necessarily any memory allocations (the compiler may be smart enough to realize that all memory sizes are known at compile-time, and may not need to allocate any memory). However, exactly which sum() method is called (sum(::Vector{String}), sum(::Vector{Float64}) or sum(::Vector{Any})) is unknowable without actually constructing y, then looking at its type. And so Julia cannot determine ahead of time what machine code to lower this down to; it must look at the type of y and then choose which method of sum() to call. This is what I mean when I say the program is not “static” but “dynamic”.
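You can ask the compiler what it managed to infer for a function like this. The sketch below redefines randtype so it is self-contained, and uses Base.return_types, an introspection helper whose printed output may differ slightly between Julia versions:

```julia
function randtype()
    if randn() > 0.5
        return "hello"
    else
        return 1.0
    end
end

# What inference knows before running anything: a Union, not one concrete type.
inferred = first(Base.return_types(randtype, ()))
println("inferred return type: ", inferred)
println("concrete? ", isconcretetype(inferred))

# What actually happens at run time: the array type varies from draw to draw.
samples = [[randtype(), randtype()] for _ in 1:200]
seen = unique(typeof.(samples))
println("array types observed: ", seen)
```

Since the array type is only settled once the values exist, the call to sum() on it has to be dispatched dynamically.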

Can this same work be used to statically compile julia programs for distribution?

While more complete type inference does mean that Julia programs may require the runtime/compiler less, I think it would still be challenging, in general, to build a meaningful program as a completely static subregion. Machine learning models are something of a nice case for this, because you tend to be mashing up tensors of Float64s in fairly standardized ways, and so there’s not all that much dynamism happening anyway. There is some, and the work that Keno did to get this working should not be disregarded, but there is still an awful lot of stuff that the world at large does with Julia that cannot be completely type-inferred. So while this does help the compiler identify larger static subregions of Julia code, I still think statically compiled Julia programs are going to need to embed the language runtime and compiler to do most interesting work.

It is also worth noting that the kf/tpu3 branch can take significantly longer to compile large chunks of Julia code, because it tries much harder to do this inference. Compiling the VGG19 model can take upwards of two minutes, which is mildly excruciating.


This was a great explanation, thanks!

AFAIK (correct me if I’m wrong, compiler folks), all elements of the front-end type lattice are elements of inference’s type lattice, but not vice versa. These extra elements in the inference lattice allow the compiler’s abstract interpretation process to represent/capture more precise information (like constant information) than would be possible with the normal type lattice, thus enabling inference to produce better results than it could otherwise. AFAICT, this is what the quote in the OP was referring to:

Julia inference lattice is significantly finer than its type lattice, allowing increased inference precision crucial for the success of this work.

To give a concrete example of elements which are in the inference lattice but not the front-end type lattice, one can look at Const and PartialTuple. For example, Core.Compiler.Const(1) <: Integer is an error (and doesn’t make sense), but Core.Compiler.:⊑(Core.Compiler.Const(1), Integer) is true.
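One way to see the effect of those extra lattice elements without touching compiler internals is to compare what inference returns with and without a known constant. This is only a sketch: pick, pick_unknown and pick_second are hypothetical names, and it relies on constant propagation behaving as it does in recent Julia releases:

```julia
# Index into a heterogeneous tuple.
pick(t, i) = t[i]

# Only the *types* are known: the index is some Int, so the result could be
# either field of the tuple.
pick_unknown = first(Base.return_types(pick, (Tuple{Int, String}, Int)))

# Here the index is a literal constant, which inference can track as a value
# rather than just as "some Int".
pick_second() = pick((1, "a"), 2)
pick_const = first(Base.return_types(pick_second, ()))

println("index known only as Int: ", pick_unknown)
println("index known as constant: ", pick_const)
```

No ordinary Julia type says “the Int that is exactly 2”, so the sharper answer in the second case is only possible because inference carries constant information alongside the types.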

For curious folks, I highly recommend checking out typelattice.jl in Julia’s compiler.


How come TPU.jl doesn’t use Cassette to implement the compiler passes?