State of machine learning in Julia

It’s not the same or even a similar thing to the E-graph; rather, it’s similar to the interfaces the E-graphs act on. Maybe the easiest way to describe it is by saying what is the same or similar. Python bytecode is like “the Julia IR”. Of course, since Julia is an optimizing compiler, there isn’t a singular IR; instead there are stages: the untyped IR, the typed IR, and the LLVM IR. Cassette and IRTools, the tools on which Zygote.jl was built (some notable others built on them are AutoPreallocation.jl, SparsityDetection.jl, etc.), are probably the most similar to TorchDynamo in that they are tools that transform untyped syntactic IR into other untyped syntactic IR.
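To make that concrete, here’s a minimal sketch of what that style of tool looks like, using Cassette’s documented contextual-dispatch API (the context name TraceCtx and the function g are just illustrations I picked, not anything from Zygote itself):

using Cassette

# Define a context; overdubbing under it recursively rewrites the untyped
# IR of every function call so that our hooks run around each call.
Cassette.@context TraceCtx

# Log every call encountered during execution.
Cassette.prehook(::TraceCtx, f, args...) = println(f, args)

g(x) = x * 2 + 1

# Runs g(3) with every nested call traced. All of this happens on the
# untyped IR, before type inference or any compiler optimizations.
Cassette.overdub(TraceCtx(), g, 3)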

It turns out that for Julia this was a bad idea because (a) the meaning of code can depend (heavily) on types, and (b) this stage comes before compiler optimizations, so mixing compiler optimizations with automatic differentiation is impossible. Thus Julia v1.7 added an AbstractInterpreter interface to Julia Base itself for acting on typed IR, which is then used by packages like EscapeAnalysis.jl and Diffractor.jl to write compiler passes on typed IR. And of course LLVM IR has standard transformation techniques, along with Enzyme.jl, an AD written on LLVM IR.
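You can see the stage distinction yourself with the standard reflection macros (this is just illustrative; it is not the AbstractInterpreter API):

f(x) = 2x + 1

# Untyped (lowered) IR: the level Cassette/IRTools and Zygote operate on.
@code_lowered f(1.0)

# Typed IR after inference: the level that AbstractInterpreter-based
# passes like Diffractor.jl operate on.
@code_typed f(1.0)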

So TorchDynamo is probably most similar to Cassette/IRTools, but you could also say it’s like AbstractInterpreter in that it’s acting on “the true IR of Python”: Julia’s true IR is the typed IR, since that’s where it has all of its information, while Python’s is not typed. But this story is why Zygote has its compile-time issues and higher-order AD issues, and why all of the tooling is moving not just to a new AD tool but to an entirely different IR target and compiler tool stack (note this doesn’t imply that will happen to TorchDynamo, unless they start rewriting their AD to be source-to-source on Python bytecode, though there’s precedent for that in Tangent, which never found a nice home). Note that these tools aren’t just for AD. For example, there are PRs to Julia’s Base that automatically analyze loops and remove repeated allocations of immutable arrays, written using the AbstractInterpreter compiler plugin interface.

So that still doesn’t answer how the heck E-graphs come into the story, because I haven’t described how you write a compiler pass. It doesn’t matter what level of IR you’re on; a pass is basically just a function IR -> IR. So where their blog post says “just add code here”:

from typing import Callable

import torch
import torchdynamo

def custom_compiler(graph: torch.fx.GraphModule) -> Callable:
    # do cool compiler optimizations here
    return graph.forward

with torchdynamo.optimize(custom_compiler):
    # any PyTorch code
    # custom_compiler() is called to optimize extracted fragments
    # should reach a fixed point where nothing new is compiled
    ...

# Optionally:
with torchdynamo.run():
    # any PyTorch code
    # previously compiled artifacts are reused
    # provides a quiescence guarantee, without compiles
    ...

Well, that’s true in any of these systems, just like with macros. But if you’ve ever written a macro, you’ll know that walking expression graphs is a tedious process to get correct. Wouldn’t it be nice if compiler optimizations for mathematical ideas could be expressed mathematically, and the associated compiler pass could be generated? It turns out that all the Symbolics tooling really is, at its core, is tooling that performs rewrites on some IR. So Symbolics.jl has an IR that uses SymbolicUtils.jl’s rewriters and Metatheory.jl’s E-graphs to transform symbolic IR → symbolic IR, but what we have done is make those rewrite tools generic to the IR, and boom, now it’s a compiler optimization pass generator.
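For a feel of what a rewriter looks like, here’s a minimal sketch using SymbolicUtils.jl’s @rule interface (the rule itself is just a textbook identity picked for illustration):

using SymbolicUtils

@syms x::Real

# A rewrite rule is just a mathematical identity; ~x is a pattern variable.
r = @rule sin(~x)^2 + cos(~x)^2 => 1

r(sin(x)^2 + cos(x)^2)  # matches, returns 1
r(sin(x)^2 + cos(x))    # no match, returns nothing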

That means you can, say, define an E-graph pass that acts on Julia typed IR and spits out typed IR with the desired simplifications, described mathematically. This is what we mean by “democratization of writing compiler passes”: we are trying to use this to build a system so that people who want to add a new linear algebra simplification pass to the Julia typed IR do not need to learn all of the details of the AbstractInterpreter and the Julia typed IR definition; instead they just write a few mathematical equalities and boom, it generates a compiler pass which then generates the transformed IR. So think of the E-graphs as replacing the requirement that someone write a function like def custom_compiler(graph: torch.fx.GraphModule) -> Callable: that digs through some expression graph. Instead you just write the equalities, something like the sketch below.
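Here’s a minimal sketch using Metatheory.jl’s e-graph API; the linalg theory is a hypothetical linear algebra simplification set I made up for illustration, not an existing pass:

using Metatheory

# Each rule is a mathematical identity; A, B, C are pattern variables.
# --> is a directed rewrite, == is a bidirectional equality.
linalg = @theory A B C begin
    adjoint(A * B) --> adjoint(B) * adjoint(A)
    (A * B) * C == A * (B * C)
    inv(A) * (A * B) --> B
end

# Saturate an e-graph with the rules, then extract the smallest
# equivalent expression — a "generated compiler pass" in miniature.
eg = EGraph(:(inv(X) * (X * Y)))
saturate!(eg, linalg)
extract!(eg, astsize)  # returns :Y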

Man, this came out longer than expected. But since it describes why Zygote is being replaced with Diffractor and Enzyme, I guess it’s a useful description for many more reasons than the original question :sweat_smile:
