@Keno is the best person to answer this question, as he did the nitty-gritty inference work for that paper. I will take a shot at answering this though. I know you already know many of the things I’m about to type up, but for the benefit of others that are reading along, I’m going to explain things from the basics a little more completely than I otherwise might.
Julia gives us the benefit of being able to write extremely dynamic code (variable types can change, methods can be overwritten, you can run `eval()`, etc.) while still getting good performance, because the just-in-time compiler is able to look at the next "chunk" of code to run and turn that chunk into "static" machine code just before we run it. Because much of Julia's dynamism is based on types (e.g. which chunk of machine code gets called when I say `dot(a, b)` depends heavily on the types of `a` and `b`), being able to take the state of the program at time point `t` and fully qualify what the types of variables will be at time point `t + N` is very advantageous to us: it allows us to group chunks of the program together into "static subregions" that can then be lowered to machine code and run without ever needing to "come up for air", look around at the state of the world, and then decide what to run next. In essence: the more we can prove about the program ahead of time, the larger the static sub-chunks we can run.
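To make "proving things ahead of time" concrete, here is a small sketch (my own example, not from the paper): a type-stable function infers to a single concrete return type, so it can become one static chunk, while a function whose result depends on a runtime *value* only infers to a `Union`, which is where the runtime has to start checking things.

```julia
# Type-stable: the return type is fully determined by the argument types,
# so the compiler can lower calls to this into one static chunk.
stable(x::Float64) = 2x + 1.0

# Type-unstable: the return type depends on the runtime value of x,
# so inference can only conclude Union{Int64, Float64}.
unstable(x::Float64) = x > 0 ? 1 : 1.0

# Base.return_types reports what inference was able to prove:
Base.return_types(stable, (Float64,))    # [Float64]
Base.return_types(unstable, (Float64,))  # [Union{Float64, Int64}]
```

You can also run `@code_warntype unstable(1.0)` at the REPL, which highlights the problematic `Union` return type.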
In a typical Julia program, the penalty for that kind of breaking out is very small, almost minuscule. You run a chunk of lowered code, then you run a different function to figure out what chunk to run next, and you go and run that chunk. Ignoring the time we spend figuring out what to run next, the overhead is negligible. However, when running on an accelerated device, there is suddenly a whole new source of overhead: when we compile Julia programs to run on the TPU, it is critical that we maximize the amount of time the TPU stays busy without breaking out to Julia, because the Julia compiler itself runs on the host CPU, and the TPU is physically distant; we control it through a network connection. So context switches that would normally take nanoseconds instead take seven orders of magnitude longer, which doesn't make anyone happy.
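As an illustrative sketch of that "run a chunk, break out, decide, run the next chunk" cycle (my example, not from the post): an `Any`-typed container forces dynamic dispatch at every element, while a concretely typed container lets a single static chunk handle the whole loop. On the CPU the difference is real but modest, which is exactly the "negligible overhead" point above.

```julia
# Summing a vector. With eltype Any, each `s += x` must break out to
# dynamic dispatch to find the right `+` method for the element's type.
function mysum(xs)
    s = 0.0
    for x in xs
        s += x   # static if eltype(xs) is concrete; dynamic if Any
    end
    return s
end

dynamic_xs = Any[1, 2.0, 3.0]        # heterogeneous, eltype Any
static_xs  = Float64[1.0, 2.0, 3.0]  # concrete eltype

mysum(dynamic_xs)  # 6.0
mysum(static_xs)   # 6.0 (same answer, compiled as one static chunk)
```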
And so, Keno took a wrench to the compiler to convince it to try extra, extra hard to infer larger chunks of a Julia program as a single static chunk. How big, you ask? Well, in the TPU paper we are able to compile the entire VGG19 backward and forward pass into a single static chunk, so that the entire thing can run without ever once breaking out. This requires a custom version of Julia (see the `kf/tpu3` branch), both to raise the limits on how much time/effort Julia will put into type inference, and to tweak some internal workings of the compiler so it doesn't stack-overflow when juggling so much code at once. I don't know how many times the current `master` of Julia would typically "break out" of static execution in this program (with TPUs we've made it all-or-nothing: if you try to run something that can't be completely inferred, `XLA.jl` will yell at you and you have to break your function down into smaller pieces, which isn't very user friendly, but that's how it is at the moment), but I can tell you it's more than zero.
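A cheap way to get a similar all-or-nothing discipline on stock Julia (this is not `XLA.jl`'s actual mechanism, just a stand-in using the standard `Test` library) is `Test.@inferred`, which errors whenever a call's return type can't be pinned down to a single concrete type:

```julia
using Test

f_ok(x::Float64)  = x * 2            # infers to Float64
f_bad(x::Float64) = x > 0 ? 1 : 1.0  # infers to Union{Int64, Float64}

@inferred f_ok(1.0)    # passes: one concrete return type
# @inferred f_bad(1.0) # throws an error: inference gave a Union
```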
And so, I believe the type lattice refers to the hierarchy and relationships of types as statically defined and declared within your program (i.e. the information available pre-type-inference), whereas the inference lattice refers to the hierarchy and relationships of types as known post-type-inference. Naturally, the information we get from type inference increases the precision of our knowledge about what types things are, and as explained above, this is crucial to getting TPUs to run significant chunks of Julia code.
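A small illustration of the gap between the two (my own sketch of the distinction): `typejoin` answers using only the declared type hierarchy, while inference can produce strictly more precise answers, such as small `Union`s that no single declared type expresses.

```julia
# Declared type lattice: the closest common supertype of Int64 and
# Float64 in the type tree is Real.
typejoin(Int64, Float64)      # Real

# Inference lattice: inference tracks finer elements, e.g. small Unions,
# which are more precise than the typejoin answer Real.
g(x::Int) = x > 0 ? 1 : 1.0
Base.return_types(g, (Int,))  # [Union{Float64, Int64}]
```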