What happened to XLA.jl

Dear All,

I would like to ask what happened to XLA.jl. When I checked the repository, it was marked as archived. Is the project of building a Julia-to-XLA interface dead?

Best,
Jan

Not sure. While you can still install the FluxML XLA.jl, this other XLA.jl is not archived, and I could also install it (not both at the same time, though; I don’t think it’s possible, nor needed here, to install two packages with the same name):

pkg> add https://github.com/JuliaTPU/XLA.jl

The code there is a bit older. Both depend on TensorFlow.jl, and while TensorFlow.jl is not archived, its repository notes:

Tensorflow.jl is in minimal maintenance mode
While it works, it is not receiving new features, and is bound to an old version, 1.13.1, of libtensorflow.

The general answer is the same, but there is some nuance depending on what you’re thinking of when you say “Julia-to-XLA interface”:

  1. A way to run Julia code + models on TPUs. This would require wrapping libtpu and updating JuliaTPU/XLA.jl to support a recent version of Julia. The latter may require adding features to the Julia compiler itself to support TPU compilation workflows.
  2. A way to get optimized linear algebra/ML code on any device (really CPU/GPU). This would require wrapping the XLA support libraries from TensorFlow (see EXLA in the elixir-nx/nx repository on GitHub for how this is done in another language) to replace the current PyCall path in FluxML/XLA.jl. It would also likely require a similar rethinking of the partial evaluation pipeline used in that package to make use of newer tools the Julia compiler provides (one keyword here is AbstractInterpreter).

In both cases, the biggest missing piece is someone willing to roll up their sleeves and actively work on the problem. Despite similar calls to action from me and others before, thus far we have only heard crickets. That’s understandable, as I assume the intersection of people who want to use XLA and/or TPUs, want to use Julia, and would be willing to take on this project is basically the empty set. I would be very happy to be proven wrong on that 🙂

Thank you for your reply. I was thinking about point 2; for my type of problem (basically solving a system of functional equations), TPUs don’t make much sense. However, being a poor macroeconomist, I will probably stick to JAX. 😅 I was originally working in Julia, but a year ago I was lured away by JAX. XLA compilation (kernel fusion is a really big deal in my application) + vmap + nested autodiff out of the box is a really hard-to-beat proposition. I am afraid that so far the Julia ecosystem doesn’t have anything that comes close, at least for a not-so-skilled user like me…

Shameless self-plug, showing that JAX, while being much less flexible (loops get unrolled, no if-else branches allowed, no Vector, only fixed-length arrays), sometimes isn’t even faster than Julia.
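
To illustrate the control-flow restriction mentioned here, a minimal JAX sketch (the ReLU-style function is a hypothetical stand-in): under `jax.jit`, a Python `if` on a traced value cannot be evaluated, so data-dependent branches have to be written with `jnp.where` or `jax.lax.cond` instead.

```python
import jax
import jax.numpy as jnp


def relu_python_if(x):
    # A Python `if` needs a concrete bool; under jit, x is an abstract tracer
    if x > 0:
        return x
    return 0.0 * x


@jax.jit
def relu_where(x):
    # Both branches are evaluated elementwise; jnp.where selects per element
    return jnp.where(x > 0, x, 0.0)


@jax.jit
def relu_cond(x):
    # lax.cond takes a scalar predicate and two branch functions
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)


def python_if_fails_under_jit():
    try:
        jax.jit(relu_python_if)(1.0)
    except jax.errors.ConcretizationTypeError:
        return True
    return False
```

`jnp.where` works elementwise on arrays of any shape, while `lax.cond` needs a scalar predicate; both branches must return values of the same shape and dtype.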

As @jling mentioned, it’s not hard to beat the performance of many of the kernel fusions JAX does. However, it is also true that doing so often means giving up AD or GPU support.

Given that you mention nested AD as a selling point, can you elaborate on some of your macroeconomic use cases which use it? I’d also be curious how important (or not) GPU support is for your work. Getting all those lined up is a major pain point in the Julia ecosystem right now, and having a concrete example in a niche which many Julians have experience in (econ) but the Python ecosystem (at least the corner JAX lives in) has not explicitly targeted would be very interesting. Selfishly, I would hope it gets people motivated to make nested AD (+ maybe GPU) work well instead of just thinking it’s not worthwhile because the only use cases are deep learning models.

Thank you for your interest!

Small disclaimer: what I work on is really a niche application, even within my field.

First, regarding those nested derivatives:

Sometimes I work on continuous-time models, where the functional equation to solve is a second-order HJB equation coupled with an associated system of functional Kuhn-Tucker conditions. The loss function is then the squared residual of those equations, so it includes derivatives of the value-function neural network (up to second order), and I then need to differentiate this whole thing w.r.t. the network parameters. Unfortunately, some details of those problems (state-constraint boundary conditions) prevent me from using NeuralPDE.jl. The second case is the so-called generalized Euler equations in optimal taxation problems. These form a discrete-time system of functional equations, but they also include derivatives of some of the unknown functions involved… which again doesn’t fit any pre-packaged Julia solver.
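
A minimal JAX sketch of the kind of loss described here, under loudly stated assumptions: a tiny one-dimensional value network and a hypothetical HJB-style residual `RHO*V - payoff(x) - drift(x)*V' - 0.5*SIGMA**2*V''`, where `payoff`, `drift`, `RHO` and `SIGMA` are placeholders, not the poster's model. It only shows the structure: the loss already contains first and second input derivatives of the network, and `jax.grad` then differentiates it w.r.t. the parameters.

```python
import jax
import jax.numpy as jnp

# Hypothetical model constants and primitives (placeholders, not the poster's model)
RHO, SIGMA = 0.05, 0.2


def payoff(x):
    return jnp.log1p(x ** 2)


def drift(x):
    return -0.1 * x


def init_params(key, width=16):
    # Tiny MLP for a scalar value function V(x; theta)
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.5 * jax.random.normal(k1, (width,)),
        "b1": jnp.zeros(width),
        "w2": 0.5 * jax.random.normal(k2, (width,)),
        "b2": jnp.zeros(()),
    }


def value(params, x):
    h = jnp.tanh(params["w1"] * x + params["b1"])
    return jnp.dot(params["w2"], h) + params["b2"]


v_x = jax.grad(value, argnums=1)   # dV/dx
v_xx = jax.grad(v_x, argnums=1)    # d2V/dx2 (grad of grad)


def residual(params, x):
    # HJB-style residual built from the network and its input derivatives
    return (RHO * value(params, x) - payoff(x)
            - drift(x) * v_x(params, x)
            - 0.5 * SIGMA ** 2 * v_xx(params, x))


def loss(params, xs):
    # Mean squared residual over collocation points xs
    r = jax.vmap(residual, in_axes=(None, 0))(params, xs)
    return jnp.mean(r ** 2)


# Gradient w.r.t. the parameters of a loss that already contains V' and V''
loss_grad = jax.jit(jax.grad(loss))
```

Usage would look like `loss_grad(init_params(jax.random.PRNGKey(0)), jnp.linspace(-1.0, 1.0, 64))`, which returns a dict of gradients with the same structure as the parameters.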

So I need to write those loss functions and training loops myself, using NN and autodiff primitives. In JAX, this is rather straightforward because those operators are composable: I can compose arbitrarily many grad operators, jit-compile the result, then vmap it, pmap it, and differentiate again… In Julia, I had to deal with all kinds of ‘cannot differentiate through foreign call expressions’ errors. To make it work, I had to define specialized adjoint rules so that Zygote could differentiate through ForwardDiff derivatives… It can be done, but it is really annoying and requires boilerplate code…
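
The composition described here, as a minimal sketch on a hypothetical scalar function f(x) = x sin x: two nested `jax.grad` calls give f'', `jax.vmap` batches it, `jax.jit` compiles it, and the result can itself be differentiated again.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function: f(x) = x * sin(x)
    return x * jnp.sin(x)


d2f = jax.grad(jax.grad(f))                      # f''(x) via composed grads
d2f_batched = jax.jit(jax.vmap(d2f))             # ...batched and compiled
d3f_batched = jax.jit(jax.vmap(jax.grad(d2f)))   # ...and differentiated again
```

Analytically f''(x) = 2 cos x - x sin x and f'''(x) = -3 sin x - x cos x, which `d2f_batched(xs)` and `d3f_batched(xs)` should reproduce up to float32 rounding.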

Even when I got it to work, the resulting code was much slower than the JAX code. The key problem in my type of application (continuous time, but especially discrete time, which I work with most of the time) is that the neural networks used tend to be fairly small: 16 or 32 neurons, maybe up to 500 in large models, but quite often in the lower part of this range. However, one evaluation of the loss function may require a LOT of network evaluations. Those equations involve expectations of t+1 objects, so you need to integrate over future uncertainty (in some of my applications I needed a couple of hundred integration nodes, and for each of them one needs to evaluate a policy/value network), and the loss function is also evaluated at many points in the state space.
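
A rough sketch of why a single loss evaluation multiplies network calls, assuming (hypothetically) an AR(1) law of motion x' = rho*x + sigma*eps with standard normal eps, integrated with probabilists' Gauss-Hermite quadrature from NumPy (`np.polynomial.hermite_e.hermegauss`). With `n_states` collocation points and `n_nodes` quadrature nodes, `value_fn` is evaluated `n_states * n_nodes` times.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Probabilists' Gauss-Hermite rule: sum_i w_i g(z_i) ~ integral g(z) exp(-z^2/2) dz
_nodes, _weights = np.polynomial.hermite_e.hermegauss(10)
NODES = jnp.asarray(_nodes)
WEIGHTS = jnp.asarray(_weights / np.sqrt(2.0 * np.pi))  # now approximates E[g(Z)], Z ~ N(0, 1)


def expected_next_value(value_fn, xs, rho=0.9, sigma=0.1):
    # Hypothetical law of motion: x' = rho * x + sigma * eps, eps ~ N(0, 1)
    x_next = rho * xs[:, None] + sigma * NODES[None, :]         # (n_states, n_nodes)
    # One value_fn call per (state, node) pair: n_states * n_nodes evaluations
    v_next = jax.vmap(value_fn)(x_next.ravel()).reshape(x_next.shape)
    return v_next @ WEIGHTS                                      # (n_states,)
```

In the setting described above, `value_fn` would be the policy/value network with its parameters closed over. As a sanity check, `value_fn = lambda x: x ** 2` should give the closed form rho^2 x^2 + sigma^2, since a 10-node Gauss-Hermite rule is exact for polynomials of that degree.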

While I was working on those problems in Julia, I concluded that GPUs aren’t really helpful, because those matrix-vector multiplies simply aren’t big enough to offset the GPU overhead… When I switched to JAX and compiled the whole code, voilà: even for relatively small problems (32 neurons), I get an order-of-magnitude speed-up on GPU relative to CPU, because XLA was able to fuse those matvec operations into matmuls. The same goes for the simulation part of my algorithm, which requires many evaluations of policy-function networks; XLA fusion again does a lot of magic here.
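
A minimal sketch, with hypothetical sizes (32 neurons, 4 inputs, 1000 states), of the matvec-to-matmul effect described here: `jax.vmap` lifts a single-state layer evaluation `W @ x` to the whole batch, and the result matches one hand-written matrix product. Inspecting `jax.make_jaxpr(jax.vmap(layer))(X)` should show a single batched `dot_general` rather than 1000 separate matvecs; how XLA then fuses the surrounding operations on a particular GPU is not something this snippet demonstrates.

```python
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (32, 4))     # one small dense layer: 32 neurons, 4 inputs
X = jax.random.normal(key_x, (1000, 4))   # 1000 states at which the layer is evaluated


def layer(x):
    # Single-state evaluation: a small matrix-vector product
    return jnp.tanh(W @ x)


# vmap lifts the matvec to a batched operation over the leading axis of X
batched = jax.jit(jax.vmap(layer))


def layer_matmul(X):
    # Equivalent hand-written matmul form
    return jnp.tanh(X @ W.T)
```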

While I believe that for this type of application Julia should have an edge down the road (Diffractor/Enzyme), using JAX now is a clear choice for me. The difference in ergonomics is orders of magnitude, both for nested derivatives and, more importantly, for GPU/distributed deployment. And yes, I forgot to mention the whole vmap business; so far, I haven’t found anything nearly as convenient in Julia. Sure, broadcasting is nice, but the ability to specify the vectorization dimensions is irreplaceable for me.
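
A minimal sketch of specifying vectorization dimensions with `jax.vmap`'s `in_axes`, using a hypothetical scalar `policy` of parameters, a state and a shock: `None` means "do not map this argument" and `0` means "map over its leading axis". Nesting two `vmap`s produces the full state-by-shock grid of evaluations.

```python
import jax
import jax.numpy as jnp


def policy(theta, state, shock):
    # Hypothetical scalar policy, written for a single state and a single shock
    return jnp.tanh(theta[0] * state + theta[1] * shock)


theta = jnp.array([0.5, -1.0])
states = jnp.linspace(-1.0, 1.0, 7)   # 7 states
shocks = jnp.linspace(-0.2, 0.2, 3)   # 3 integration nodes

# Inner vmap: map over shocks only (theta and state held fixed)
over_shocks = jax.vmap(policy, in_axes=(None, None, 0))
# Outer vmap: map over states only, giving a (7, 3) grid of evaluations
grid = jax.vmap(over_shocks, in_axes=(None, 0, None))(theta, states, shocks)
```

In this simple case the same grid could be written with broadcasting, `jnp.tanh(theta[0] * states[:, None] + theta[1] * shocks[None, :])`; the appeal of `in_axes` is that `policy` itself stays written for one state and one shock.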

JAX has some rough edges, but trying to code more complex problems in Julia usually ends with me getting hurt 😃 and returning to JAX, even when it comes to NeuralODEs.

I would say that, to a very large degree, it is a problem with my coding skills. I am sure that for really professional users like you or @ChrisRackauckas, Julia is an immensely productive language, but for ‘educated laymen’ like me, JAX offers an order-of-magnitude better tradeoff.

Sorry for the long and rather unstructured post.

Related issues

FWIW, you are likely a more professional user than I am, as someone who doesn’t write Julia for work/research 😉. Thanks for providing such a detailed post. I think examples like this are important for understanding where the ecosystem, and possibly the language, are still lacking: not just on a technical level of checking feature boxes, but also when it comes to user experience.

Apologies for reviving an older thread, but I’d certainly be interested in contributing to such a project if it comes about. I share @Honza9723’s opinion that JAX’s vmap syntax is fantastic, so I’ve been looking into how much of JAX can be brought to Julia without XLA. There are fundamental differences in how JAX and Julia work, but I’d love to master and contribute to both tools’ ecosystems.

I stumbled upon this thread by chance and just want to say that I’m in a very similar situation myself. I invested quite some time learning Julia with the idea that I could write efficient code without worrying about broadcasting. But I soon discovered that this did not work out with Zygote, and I’m now either writing vectorized functions (keeping to a declarative approach, as I would with NumPy) or writing derivatives by hand with ChainRules, which is never a pleasant experience.

I have little knowledge of scientific computing, and I may be wrong; besides, I was told that projects such as Diffractor aim to transform loops into vectorized operations. In the meantime, however, I’m disappointed by the switch from Python to Julia, even though I’m convinced that Julia is 100% better in principle for scientific computing.

It would be good if you could share more about the work you do. Reading @Honza9723’s post was incredibly helpful, and I think we need more of these experience reports so that ML/DL stuff can be a “squeak(ier) wheel” in the Julia community.

For now, what I’ll say is that there does appear to be a disconnect between the idiomatic way of writing fast Julia code and what is required to be compatible with AD + GPU. Think favouring loops over vectorization, mutation over allocation, etc. JAX et al. don’t have this problem because they do everything vectorized + out of place, and they dictate what ML code looks like. Is it any surprise then that if both optimize for idiomatic code, your usual DL model doesn’t get a lot of performance goodies on the Julia side?
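
The "out of place" style referred to here, as a minimal JAX sketch: JAX arrays are immutable, so NumPy-style item assignment such as `x[0] = 1.0` raises a `TypeError`, and the functional `.at[...]` API returns a new array instead. Under `jit`, XLA may still be able to perform such updates in place, but that is a compiler optimization rather than something the code controls. `scatter_loss` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(5)

# Functional ("out-of-place") updates: each returns a new array, x is unchanged
y = x.at[2].set(1.0)
z = y.at[1:3].add(10.0)


def in_place_assignment_fails(arr):
    # JAX arrays are immutable, so NumPy-style item assignment raises TypeError
    try:
        arr[0] = 1.0
    except TypeError:
        return True
    return False


def scatter_loss(v):
    # The "write" into a buffer is just another differentiable operation
    buf = jnp.zeros(3).at[1].set(v ** 2)
    return jnp.sum(buf)
```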

So given all this, I think there is a need for some alignment on what folks in the community consider “scientific computing” to be. My feeling is that ML has traditionally not fit perfectly into that box, and that may be why it hasn’t received the same level of attention it does over in Python land.

Quick addendum on AD: both of the experimental ADs in development (Diffractor and Enzyme) try to bridge the gap in some way. Enzyme makes scalar code fast, lets you write loops + mutation, and even supports GPU kernels. However, it still requires writing a lot of custom rules (currently in another AD, in the future with its own rule system) to make code that frequently bounces between device and host (i.e. most NNs) work. Nested differentiation also seems like a question mark. Diffractor should in theory smoke JAX at higher-order differentiation in general, but it doesn’t provide a powerful array/linear algebra optimizer like XLA (which is how JAX can be remotely performance-competitive in the first place) and lacks mutation support (which you’d need to write efficient code without such an optimizer). It also faces the small, small issue of not having any dev roadmap or timelines, but given the previous points I don’t think you need to think about it too much anyhow.

I don’t mean to pick on you, but it seems like most of us (myself included) would be interested in contributing to but not leading such a project :wink: . And that’s not to disparage the crowd, because some of the contribution offers have been very generous (e.g. hardware access). So there remains the question of how to get the ball rolling, but maybe there’s a chance.

As I noted above, Zygote and now Diffractor basically assume you’re writing code like you would in JAX. The missing part is how to make that code fast. For vmap specifically, there have been a couple of attempts to write something similar but more idiomatic for Julia (personally, I find things like tuple in_axes can make the vmap syntax quite unpleasant), but there’s a much bigger ocean of ops to boil out there and it would likely require more resources than one person working off the side of their desk.

More generally, I would be curious to know what concrete steps could be taken to make the “if you need it, build it” ethos that pervades the rest of the Julia ecosystem more viable for the ML space. If there are common bottlenecks potential library authors/contributors are running into, those should get quite a bit more attention than they are now (i.e. zero).

Sure. First of all, sorry for my earlier post, which was indeed unhelpful. I’ll try to be more specific about how I use Julia. A disclaimer, though: I’m from a math background and don’t understand the subtleties of the languages or libraries I use very well. Also, although the question of what would be a perfect programming language for scientific computing and ML interests me, I have little perspective on it.

I had not been satisfied with Python (and PyTorch/JAX) for some time, and made the switch to Julia when I started a project on polynomials. With those, using broadcasted operations is a mess, and it’s just so much easier to write loops explicitly (there are lots of recursive operations involved, e.g. evaluation with Clenshaw’s algorithm). So I learned Julia, hoping to stop writing vectorized code and just write loops explicitly, in an imperative coding style. At first, this was quite a relief, and I became more productive: instead of overthinking how to vectorize a loop, you just write it down, and voilà, it works.

However, I soon realized that Zygote did not like that. So now I either write reverse rules explicitly with ChainRules, or use tools like Tullio.jl, which take care of all the optimization for me, even providing a gradient with forward rules. This package is really amazing.
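
For comparison, a minimal JAX sketch (not the poster's Julia code) of a loop-based Clenshaw evaluation of a Chebyshev series sum_k c_k T_k(x). JAX differentiates it without custom rules because the plain Python loop is unrolled during tracing; the trade-off is that the traced program grows with the number of coefficients.

```python
import jax
import jax.numpy as jnp


def clenshaw_chebyshev(coeffs, x):
    # Evaluate sum_k c_k T_k(x) with Clenshaw's backward recurrence:
    #   b_k = c_k + 2 x b_{k+1} - b_{k+2},  f(x) = c_0 + x b_1 - b_2
    b_next = jnp.zeros_like(x)    # b_{k+1}
    b_next2 = jnp.zeros_like(x)   # b_{k+2}
    # Plain Python loop over k = n, ..., 1; JAX unrolls it when tracing
    for k in range(len(coeffs) - 1, 0, -1):
        b_next, b_next2 = coeffs[k] + 2.0 * x * b_next - b_next2, b_next
    return coeffs[0] + x * b_next - b_next2
```

For coefficients `[1.0, 2.0, 3.0]` the series is 1 + 2x + 3(2x^2 - 1), so `clenshaw_chebyshev(c, 0.5)` should give 0.5 and `jax.grad(clenshaw_chebyshev, argnums=1)(c, 0.5)` should give 2 + 12x = 8.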

That’s where I realized the value of a declarative style of coding, with tools like einsum (the family Tullio belongs to) or the more recent einops. With those, you understand in a single line what the code is doing, and you can be confident that the implementation is efficient.
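
A rough JAX analogue of the einsum style described here (this is `jnp.einsum`, not Tullio.jl): a batch of quadratic forms x_b^T A x_b written as a single contraction, with `jax.grad` differentiating straight through it. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def batched_quadratic_forms(A, X):
    # q_b = x_b^T A x_b for every row x_b of X, as one einsum contraction
    return jnp.einsum("bi,ij,bj->b", X, A, X)


def loss(A, X):
    return jnp.sum(batched_quadratic_forms(A, X))
```

Since the loss is sum_b x_b^T A x_b, its gradient w.r.t. A is sum_b x_b x_b^T = X^T X, which makes a convenient check on `jax.grad(loss)(A, X)`.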

At that point, I may seem to contradict myself: I came to Julia for a more imperative style of coding, being able to write all loops explicitly; but due to Zygote limitations I realized that high-level tools implementing e.g. einsum are so useful for me. Nevertheless, I believe this contradiction is not surprising. I’m not a specialist in implementation (and I believe nobody can encompass all the knowledge between low-level implementation and high-level maths), so I rely on other people’s work to make good optimization, like the excellent Tullio package. Another example: when doing matrix multiplication, I don’t want to write a for loop, but instead expect the code I’m using to leverage a century of research to use the most efficient routine available (which would be something in BLAS).

Summing up: my ideal ML framework

  • I need libraries that provide top-notch performance with automatic differentiation, where I can run my model on GPU with a single function call, and where I know precisely where the speed bottlenecks are. I’m not satisfied with Zygote for that (I have not tried the GPU part with Julia yet).
  • I love not having the two-language barrier. It’s great to be able to implement a specific function super efficiently when it would otherwise be a speed bottleneck, to overcome the limitations of high-level tools, which are by definition not specific enough.

From my reading, I understand that automatic differentiation is a very difficult problem and that there is a tradeoff between these two points. This excellent blog post explains how JAX manages to be super-efficient for quasi-static algorithms. It also highlights the risk of letting the tools we use (JAX, PyTorch) shape the research we do (only exploring neural-network-like models with a fixed number of training steps), while Julia aims for more versatility. However, for now, the mental effort required to make Julia code efficient is a bit too much for someone with little computer science background like me.
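
A minimal sketch of the "fixed number of training steps" structure mentioned here, under a toy assumption: the inner problem is minimising w^2 by gradient descent, unrolled with `jax.lax.scan` for a trip count fixed at trace time, so reverse-mode AD can differentiate the final loss w.r.t. the learning rate. `final_loss` and `N_STEPS` are hypothetical names.

```python
import jax
import jax.numpy as jnp

N_STEPS = 5  # fixed at trace time: the loop structure does not depend on data


def final_loss(lr, w0=1.0):
    # Toy inner problem: minimise w**2 with N_STEPS of plain gradient descent
    def step(w, _):
        return w - lr * 2.0 * w, None   # gradient of w**2 is 2w

    w_init = jnp.asarray(w0, dtype=jnp.float32)
    w_final, _ = jax.lax.scan(step, w_init, None, length=N_STEPS)
    return w_final ** 2


# Reverse-mode AD through the whole fixed-length "training loop"
dloss_dlr = jax.jit(jax.grad(final_loss))
```

Here the loss is (1 - 2 lr)^(2 N_STEPS), so `dloss_dlr(0.1)` should be close to -4 * 5 * 0.8^9 ≈ -2.684. By contrast, to my knowledge, reverse-mode differentiation through `jax.lax.while_loop`, whose trip count is data-dependent, is not supported.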

Sorry for the confusing answer and the contradictions; I’m still figuring out what I need from a programming language for my research haha. Let me know if I can clarify some points!
