State of machine learning in Julia

After the Twitter space Q&A @logankilpatrick hosted yesterday on “The future of machine learning and why it looks a lot like Julia,” I thought it would be useful to accumulate some community responses to a few questions about the current state of machine learning in Julia:

  1. Where does ML in Julia really shine today? Where do you see the ecosystem outperforming other popular ML frameworks (e.g. PyTorch, Flax, etc.) in the near future, and why?
  2. Where is Julia’s ML ecosystem currently inferior in features or performance? What’s the realistic timeline for it becoming competitive in these areas?
  3. How do Julia’s ML packages for “standard ML” (e.g. deep learning) compare with popular alternatives in terms of performance (faster, slower, same order of magnitude)? Are there regularly updated benchmarks somewhere?
  4. What don’t we know yet but suspect is an important experiment to make for benchmarking against popular ML alternatives?
  5. If a company or institution is considering creating multi-year positions to contribute to Julia’s ML ecosystem, what is the best case you can make why they should do this? What contributions would be most impactful?
  6. What is the best case you can make why independent developers who work with other frameworks should consider contributing to Julia’s ML ecosystem?
  7. What packages do you tend to reach for when doing specific tasks? Why those packages vs some other Julia package or one in another language? What do you wish existed but currently doesn’t?

Question 1: Where does Julia shine?

For scientific machine learning (also known as physics-informed learning, or science-guided AI, or expert-guided AI, etc.; see the note at the bottom), Julia is really good. If you don’t know what that is, check out this recent seminar talk which walks through SciML for model discovery in epidemics, climate modeling, and more.

The SciML Benchmarks for neural ODEs and other such dynamical models are pretty damn good: we’re talking 100x, 1000x, etc. speedups across tons of different examples. Here are some. Note that most examples don’t even run in torchdiffeq, since it uses the non-robust adjoints and has no real stiff ODE solvers, so the only ones benchmarked are the easiest cases, where torchdiffeq isn’t just exiting early and outputting Inf for the gradients (I guess that’s one way to be fast :sweat_smile:)

Similarly, we see that torchsde’s performance on standard SDEs is not great:

And those even have input from the package’s devs on how to optimize the benchmarks… so… :sweat_smile:. Though note that all of these benchmarks are on examples which are more representative of scientific machine learning: smaller kernels, more nonlinearity, at least mild stiffness, etc. I’ll address “big kernel computing” in the second part.

Another thing that people might not know about is that dynamical models also fall into this space. Simply using a good ODE solver is much better than using a “physics engine”. There’s a paper which shows DiffEqFlux outperforming MuJoCo and DiffTaichi by a good order of magnitude.

There’s a lot more to say too. Managing linear solvers and preconditioners is essential to this kind of stuff, and these other libraries don’t even have them. So really, direct usage of Sundials from C or Fortran is our only real competitor in this space, but even then it won’t automatically mix with the vjps of ML libraries, which we know decreases performance by two orders of magnitude, and we do beat it in most of the SciMLBenchmarks (most but not all). So even if you use C/Fortran Sundials, you have to define every Jacobian, vjp, sparsity, etc. function to even get close to what DifferentialEquations.jl does by default. So I would assume that the vast majority of people wouldn’t hit DiffEq speeds in this domain even in C/Fortran anymore :wink:.

The other thing is differentiable programming on non-quasi-static programs, and I’ll let this blog speak for itself.

Another way to gauge how far ahead we are is to note that one of the NeurIPS spotlight papers is a method that we wrote a tutorial on in 2019 (this kind of thing isn’t too uncommon; it must’ve happened like 10 times by now). Oh wait, they only did the Newton part, not Newton-Krylov, and they didn’t do forward-over-adjoint, which is known to be more efficient (as mentioned in Griewank’s AD book, though it’s easy to see with applications), so that paper doesn’t even match our tutorial :sweat_smile:. Feel free to make figures and publish our tutorials if you want.

Question 2: where is the Julia ML ecosystem currently inferior?

This blog touches on this topic in some detail:

(and the Hacker News discussion is quite good for this one). The Julia ML/AD tools are not inferior by design, but in some cases by focus. If you’re only dealing with big kernels, i.e. lots of big matrix multiplications, then PyTorch does not have measurable overhead. With such kernels, XLA will also perform fusion operations like turning separate A*v1 and A*v2 calls into a single A*[v1 v2], changing BLAS2 to BLAS3 and exploiting more parallelism. Another case is fusions in conv kernels: cuDNN has a billion kernels, and PyTorch listed them out. Here’s a snippet:


So instead of relu.(conv(x) .+ y), on GPUs you’d want to do cudnn_convolution_add_relu. Similarly, some of these extra kernels exist for RNNs. XLA, and thus both TensorFlow and Jax, will do these fusions automatically.
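To make the fusion idea concrete, here is a toy 1-D sketch in plain Python of the add-then-relu pattern that cudnn_convolution_add_relu targets. It is not how cuDNN works internally, and the function names are hypothetical: the unfused version walks the data three times and materializes two intermediate arrays, while the fused version finishes each output element in a single pass.

```python
def conv1d(x, w):
    """'Valid' 1-D cross-correlation (what ML libraries call convolution)."""
    k = len(w)
    return [sum(x[i + j] * w[j] for j in range(k)) for i in range(len(x) - k + 1)]

def unfused(x, w, y):
    # Three separate passes, each materializing an intermediate array.
    c = conv1d(x, w)                          # pass 1: convolution
    s = [ci + yi for ci, yi in zip(c, y)]     # pass 2: add the residual y
    return [max(si, 0.0) for si in s]         # pass 3: relu

def fused(x, w, y):
    # One pass: each output element is accumulated, shifted by y, and
    # clamped immediately, so no intermediate arrays are allocated.
    k = len(w)
    out = []
    for i in range(len(x) - k + 1):
        acc = 0.0
        for j in range(k):
            acc += x[i + j] * w[j]
        out.append(max(acc + y[i], 0.0))
    return out

x = [1.0, -2.0, 3.0, 0.5, -1.0]
w = [0.5, -1.0, 0.25]
y = [0.1, -0.2, 0.3]
print(unfused(x, w, y), fused(x, w, y))
```

Both versions do the same arithmetic in the same order, so they agree exactly. The win on a GPU comes from skipping the memory traffic for the intermediates, which pure Python cannot show.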

Julia does not perform those kinds of optimizations in its “ML compiler” because it’s just using the Julia compiler, and you wouldn’t necessarily want to do that on all Julia code because, like @fastmath, it changes the floating point results. So somehow we need to add such compiler optimizations to the language in a way that only applies in specific user contexts. Even better, it would be nice if mathematical/ML users could help build and maintain these compiler optimizations like they do with automatic differentiation rules for ChainRules.jl. Not everyone is a compiler engineer, but everyone can help with linear algebra equalities. That’s the purpose of the E-graph project:
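A tiny illustration of why such rewrites behave like @fastmath: floating-point addition is not associative, so a rewrite that is an exact identity over the reals can still change the computed result. The helper names below are just for illustration.

```python
def left_assoc(a, b, c):
    return (a + b) + c

def right_assoc(a, b, c):
    return a + (b + c)

lhs = left_assoc(0.1, 0.2, 0.3)   # 0.6000000000000001
rhs = right_assoc(0.1, 0.2, 0.3)  # 0.6
print(lhs, rhs, lhs == rhs)       # the "same" sum differs in the last bit
```

Fusions that regroup additions or factor out common terms (A*v1 + A*v2 into A*(v1 + v2), say) can shift results in the same way, which is why they would need to be opt-in.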

However, that is still in its early stages and cannot even apply rules to arrays (though Shashi has started that work).

So in theory getting ML performance is simple: you just use fast kernels. In practice, the hard part is letting users write simple Julia code while allowing the compiler, in limited contexts, to change the high-level calls in their code to more efficient kernels. The AbstractInterpreter will give us the tools to do this, but that is still in the near future. The Julia ML tools picked a much larger scope, and that’s good for some things, but the downside is that these optimizations are much harder to write.
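As a toy example of such a call-level rewrite, here is the matvec-batching fusion in plain Python: several products A*v with the same A become one product A*[v1 v2 ...], and the result is split back into columns. The names are hypothetical. Pure Python won't show a speedup. On real hardware the gain comes from a BLAS3 kernel reusing each loaded block of A across all the columns.

```python
def matvec(A, v):
    # BLAS2-style: one full pass over A per vector.
    return [sum(a * x for a, x in zip(row, v)) for row in A]

def matmul(A, B):
    # BLAS3-style: each row of A is reused for every column of B.
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def batched_matvecs(A, vs):
    """Rewrite [A*v for v in vs] as one A*[v1 v2 ...], then split the columns back out."""
    V = [list(r) for r in zip(*vs)]      # stack the vectors as columns of V
    AV = matmul(A, V)
    return [list(c) for c in zip(*AV)]   # column k of A*V equals A*vs[k]

A = [[1.0, 2.0], [3.0, 4.0]]
v1, v2 = [1.0, 0.0], [0.5, -1.0]
print(batched_matvecs(A, [v1, v2]))
```

The rewrite is only valid when every product shares the same A and none of the vectors depends on another product's result. Checking those conditions is exactly the kind of bookkeeping a compiler pass has to do.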

This does not affect my SciML use cases but is probably the biggest performance issue with standard ML in Julia. But while finishing these optimizations would make Julia really ergonomic for both building and using ML packages, I’m not convinced it will “win over” the standard ML community, because you would still only expect to match performance in that domain, so I’m not convinced we have a silver bullet there.

Question 3: How well does Julia perform in “standard ML”?

I don’t think each individual benchmark is interesting. They all say pretty much the same thing as my response to 2:

  1. Julia’s kernel speeds are fine. On CPUs we’re doing really well, beating most other things with Octavian.jl (a multi-threaded, BLAS-like library that provides pure-Julia matrix multiplication). On GPUs everyone is just calling the same cudnn etc., so it’s a battle of who calls the better kernels with the right arguments.
  2. Julia’s AD speeds are fine. Zygote can have some overhead, but it’s actually rather fast in most contexts compared to Jax/PyTorch/TensorFlow. Specifically, PyTorch’s overhead is much higher, but it’s not even measurable in standard ML workflows anyway: a single multiplication of large enough matrices swamps allocation overhead and other O(n) costs.
  3. Julia does not fuse kernels, so in most benchmarks if you look at it you just go “oh, it’s not fusing this conv” or “this RNN cudnn call”.

So a lot of the issues which mention PyTorch, like FluxML/Flux.jl issue #1365 (“RNN design for efficient CUDNN usage”), are really just re-design issues to try and make Flux more readily call better kernels by manually fusing some things. So that’s really the main standard ML issue at the moment.

Question 4: What important experiments and benchmarks should we be tracking?

XLA’s distributed scheduler is pretty good. As we think about scaling, we should probably ignore PyTorch and look at DaggerFlux vs TensorFlow/Jax. XLA has more freedom to change operations around, so I think it should be the winner here, and we will need to use e-graph tricks to match it.

One other thing to note though is that there is a “missing middle in automatic differentiation”, where one has to choose between loopy mutating code (with Enzyme) and kernel linear algebra code (with Zygote/Diffractor), and mixing the two code styles does not work right now. For a discussion on that, see:

That said, PyTorch, Jax, and TensorFlow also have this issue so it’s not really inferior, and Julia is closer to solving it than the others. But it’s a big PITA and something we really need to address to make Julia’s differentiable programming feel truly better than the alternatives.

Question 5: Which companies and institutions are looking to give multi-year positions?

I’m not sure, but I see it around. Both Julia Computing and Pumas-AI have many folks doing SciML though, and if you’re interested in SciML stuff please get in touch.

This topic also explains how the MIT Julia Lab is organized: there’s a SciML crew focusing on SciML models and applications, and a compiler/symbolics crew working on these kinds of “customizable compiler for science” issues.

Question 6: Why should independent developers consider contributing?

It’s easy to specialize it towards weird research. In fact, there’s an entire research domain in non-quasi-static ML algorithms that’s just full of potential gems. Writing your own kernels gets you new algorithms that aren’t in the rut. Our implicit layer tooling all composes, and we’re far enough ahead that there are entire papers that are just what we’ve been doing for years.

And of course, all of the differential equation stuff mentioned at the beginning.

Question 7: What packages do you use and which packages do you wish existed?

I tend to reach for Flux when I need to, but try to stick to DiffEqFlux. Flux is just the most complete in terms of the kernels that exist, but its style irks me. I wish there was a Flux which did not use implicit parameters and instead used explicit parameters. I would like those parameters to be represented by ComponentArrays.

If you haven’t seen the ComponentArrays DiffEqFlux.jl example it’s really nice:

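using ComponentArrays, UnPack  # UnPack.jl provides @unpack

function dudt(u, p, t)
    @unpack L1, L2 = p  # named blocks of the flat parameter vector p
    return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end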

That would mean no implicit global state: everything would be explicit, but with nice syntax. Since a ComponentArray is a flat contiguous vector with helper indexers (that return views), it works nicely with linear algebra, e.g. for putting into BFGS. Mixing that with a universal optimizer interface like GalacticOptim.jl would give me everything I need. It would be a whole lot easier to debug and optimize too. So I would really consider either making that kind of ML package or changing Flux to explicit parameters (which is somewhat underway with Optimisers.jl).
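Here is a rough Python analogue of the flat-vector-with-named-views idea. This is not the ComponentArrays.jl API: FlatParams and dense are hypothetical names, and blocks are assumed to be stored row-major. An optimizer sees one flat buffer, while model code reads and writes named blocks through zero-copy views.

```python
import array

class FlatParams:
    """Hypothetical sketch: named, shaped blocks laid out back to back
    in a single flat float64 buffer."""

    def __init__(self, shapes):
        self._layout = {}                 # name -> (offset, rows, cols), row-major blocks
        offset = 0
        for name, (rows, cols) in shapes.items():
            self._layout[name] = (offset, rows, cols)
            offset += rows * cols
        self.flat = array.array("d", [0.0] * offset)  # what an optimizer like BFGS would see
        self._buf = memoryview(self.flat)

    def view(self, name):
        """Zero-copy view of one block; writes go straight into self.flat."""
        offset, rows, cols = self._layout[name]
        return self._buf[offset:offset + rows * cols]

    def shape(self, name):
        return self._layout[name][1:]

def dense(p, prefix, x):
    """y = W x + b using the blocks named prefix + 'W' and prefix + 'b'."""
    rows, cols = p.shape(prefix + "W")
    W, b = p.view(prefix + "W"), p.view(prefix + "b")
    return [sum(W[i * cols + j] * x[j] for j in range(cols)) + b[i] for i in range(rows)]

p = FlatParams({"L1.W": (2, 2), "L1.b": (2, 1)})
W = p.view("L1.W")
W[0], W[3] = 1.0, 2.0          # set the diagonal of L1.W through the view
p.view("L1.b")[1] = 0.5
print(p.flat.tolist(), dense(p, "L1.", [3.0, 4.0]))
```

Because every view aliases the same buffer, an optimizer can update p.flat in place and the model sees the new values with no copying or re-packing.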

[Note on SciML terminology. It’s pretty funny that these days people attribute SciML to the SciML Scientific Machine Learning Open-Source Software Organization. The original attribution was a major workshop by the US DoE which established the term and essentially proclaimed it to be a field. We started working on this as JuliaDiffEq, but with half of the repos no longer being differential equation solvers, it made sense to change to being “The SciML Scientific Machine Learning Open-Source Software Organization”, which is always just abbreviated to SciML. Soon, the SciML org became synonymous with the term, so now people are less inclined to use the term as it refers more to us than the field, and now people ask why we didn’t adopt “standard” terminology like “science-guided AI”, which was first developed to avoid referring to us :laughing:. Fun anecdote to show where our community is in this space.]


A note for anyone wondering “how is this true when Zygote takes 20/60/1000 seconds or more to give me gradients?”: runtime performance is generally on par and helped by lower per-op overhead. Almost all of that latency is coming from compilation (the source-to-source part of Zygote). This forms the lion’s share of what has been called “time to first gradient (TTFG)”. If you’re seeing pathological compilation latency or significantly poorer runtime performance and can whip up a MWE, please file an issue.


I’ll offer a perspective from someone who (as a conscious choice) primarily uses Python over Julia. I work with, and maintain libraries for, all of PyTorch, JAX, and Julia.

For context my answers will draw some parallels between:

  • JAX, with Equinox for neural networks;
  • Julia, with Flux for neural networks;

as these actually feel remarkably similar. JAX and Julia are both based around jit-compilers; both ubiquitously perform program transforms via homoiconicity. Equinox and Flux both build models using the same Flux.@functor-style way of thinking about things, and so on.
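For readers who haven't met the @functor style, here is a toy Python analogue. It is neither Flux's nor Equinox's actual API, and Linear, MLP, and fmap are hypothetical names. A model is a tree of plain containers whose leaves are parameter arrays, and one generic traversal maps a function over every leaf while rebuilding the structure.

```python
from dataclasses import dataclass, fields, is_dataclass, replace

@dataclass(frozen=True)
class Linear:
    weight: list   # leaves are plain lists of floats in this toy
    bias: list

@dataclass(frozen=True)
class MLP:
    layers: tuple

def fmap(f, node):
    """Apply f to every leaf array and rebuild the surrounding structure."""
    if isinstance(node, list):                   # leaf: a parameter array
        return f(node)
    if isinstance(node, tuple):                  # container: recurse into children
        return tuple(fmap(f, child) for child in node)
    if is_dataclass(node):                       # "functor" type: recurse into its fields
        return replace(node, **{fl.name: fmap(f, getattr(node, fl.name)) for fl in fields(node)})
    return node                                  # anything else passes through untouched

model = MLP(layers=(Linear([1.0, -2.0], [0.5]), Linear([3.0], [0.0])))
doubled = fmap(lambda xs: [2.0 * x for x in xs], model)
print(doubled)
```

The same traversal can cast parameters, move them to a device, or zip them with gradients for an optimizer step, which is why both ecosystems build so much on it.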

Question 1: where does ML-in-Julia shine?

(A) Runtime speed.

Standard Julia story there, really. This is most noticeable compared to PyTorch, at least when doing operations that aren’t just BLAS/cuDNN/etc.-dominated. (JAX is generally faster in my experience.)

(B) Compilation speed.

No, really! Julia is substantially faster than JAX on this front. (It really doesn’t help that JAX is essentially a compiler written in Python. JAX is a lovely framework, but IMO it would have been better to handle its program transformations in another language.)

It’s been great watching the recent progress here in Julia.

(C) Introspection.

Julia offers tools like @code_warntype, @code_native etc. Meanwhile JAX offers almost nothing. (Once you hit the XLA backend, it becomes inscrutable.) For example I’ve recently had to navigate some serious performance bugs in the XLA compiler, essentially by trial-and-error.

(D) Julia is a programming language, not a DSL.

JAX/XLA have limitations like not being able to backpropagate through while loops, or to specify when to modify a buffer in-place. As a “full” programming language, Julia doesn’t share these limitations.

Julia offers native loop syntax, rather than e.g. jax.lax.fori_loop(...).

(PyTorch does just fine on this front, though.)

Question 2.

(A) Poor documentation.

If I want to do the equivalent of PyTorch’s detach or JAX’s stop_gradient, how should I do that in Flux?

First of all, it’s not in the Flux documentation. Instead it’s in the separate Zygote documentation. So you have to check both.

Once you’ve determined which set of documentation you need to look in, there are the entirely separate Zygote.dropgrad and Zygote.ignore.

What’s the difference? Unclear. Will they sometimes throw mysterious errors? Yes. Do I actually know which to use at this point? Nope.
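For anyone unfamiliar with the semantics being asked about: detach/stop_gradient passes a value through unchanged but makes differentiation treat it as a constant. Here is a minimal forward-mode sketch with dual numbers. Dual, stop_gradient, and derivative are hypothetical names, this is not how Zygote, PyTorch, or JAX implement it, and it says nothing about which of Zygote.dropgrad or Zygote.ignore matches it.

```python
class Dual:
    """Forward-mode AD value: val + eps * der, with eps**2 == 0."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __radd__ = __add__
    __rmul__ = __mul__

def stop_gradient(x):
    """Keep the value, drop the derivative: downstream code sees x as a constant."""
    return Dual(x.val, 0.0) if isinstance(x, Dual) else x

def derivative(f, x):
    return f(Dual(x, 1.0)).der

print(derivative(lambda x: x * x, 3.0))                 # d/dx x^2 = 2x
print(derivative(lambda x: x * stop_gradient(x), 3.0))  # second factor held constant
```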

(B) Inscrutable errors.

Whenever the developer misuses a library, the compilation error messages returned are typically more akin to “C++-template-verbiage” than “helpful-Rust-compiler”. (That is to say, less than helpful.) Especially when coupled with point (A), it can feel near-impossible to figure out what one actually did wrong.

Moreover at least a few times I’ve had cases where what I did was theoretically correct, and the error was actually reflective of a bug in the library. (Incidentally, Julia provides very few tools to library authors to verify the correctness of their work.)

Put simply, the trial-and-error development process is slow.

(C) Unreliable gradients

I remember all too un-fondly a time in which one of my models was failing to train. I spent multiple months on-and-off trying to get it working, trying every trick I could think of.

Eventually (eventually) I found the error: Zygote was returning incorrect gradients. After having spent so much energy wrestling with points (A) and (B) above, this was the point where I simply gave up. Two hours of development work later, I had the model successfully training… in PyTorch.

(D) Low code quality

It’s pretty common to see posts on this forum saying “XYZ doesn’t work”, followed by a reply from one of the library maintainers stating something like “This is an upstream bug in the new version a.b.c of the ABC library, which XYZ depends upon. We’ll get a fix pushed ASAP.”

Getting fixes pushed ASAP is great, of course. What’s bad is that the error happened in the first place. In contrast I essentially never get this experience as an end user of PyTorch or JAX.

Code quality is generally low in Julia packages. (Perhaps because there’s an above-average number of people from academia etc., without any formal training in software development?)

Even in the major well-known well-respected Julia packages, I see obvious cases of unused local variables, dead code branches that can never be reached, etc.

In Python these are things that a linter (or code review!) would catch. And the use of such linters is ubiquitous. (Moreover in something like Rust, the compiler would catch these errors as well.) Meanwhile Julia simply hasn’t reached the same level of professionalism. (I’m not taking shots at anyone in particular here.)
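As a rough illustration of the kind of check such linters automate, here is a toy unused-local detector built on Python's standard ast module. unused_locals is a hypothetical helper. Real tools such as pyflakes handle scoping, nested functions, and globals far more carefully.

```python
import ast

def unused_locals(source):
    """Report (function, name) pairs for locals that are assigned but never read.
    A toy: nested functions are lumped in with their enclosing function."""
    problems = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            stored, loaded = set(), set()
            for sub in ast.walk(node):
                if isinstance(sub, ast.Name):
                    if isinstance(sub.ctx, ast.Store):
                        stored.add(sub.id)
                    elif isinstance(sub.ctx, ast.Load):
                        loaded.add(sub.id)
            for name in sorted(stored - loaded):
                problems.append((node.name, name))
    return problems

src = "def f(x):\n    y = x + 1\n    z = 2\n    return y\n"
print(unused_locals(src))
```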

[Additionally there’s the whole include problem (#4600, FromFile.jl) that IMO hinders readability. But I’ve spoken about that extensively before, and it seems to be controversial here, so I’ll skip any further mention of that.]

(D.3) Math variable names

APIs like Optimiser(learning_rate=...) are simply more readable than those like Optimiser(η=...). (I suspect some here will disagree with me on this. After all, APL exists.)

(E) Painful array syntax

  • Julia makes a distinction between A[1] and A[1, :] (for a 2-D NumPy array they mean the same thing);
  • The need to put @view everywhere is annoying;
  • The need for selectdim over something (NumPy-style) like A[..., 1, :] reduces readability.
  • The lack of a built-in stack function is annoying (to the extent that Flux provides one!)

Array manipulation is such an important part of ML, and these really hinder usability/readability. One gets there eventually, of course, but my PyTorch/JAX code is simply prettier to read, and to understand.

(F) No built-in/ubiquitous way to catch passing arrays of the wrong shape; a very common error. (At least that I know of.)


  • JAX probably has the best usable offering for this. Incorrect shapes can be caught during jit compilation using an assert statement. The only downside is that actually doing so is very unusual (not even close to culturally ubiquitous), probably because of the extra code needed.
  • Hasktorch and Dex encode the entirety of an array’s shape into its type. Huge amounts of safety, ubiquitously. Only downside here is that both are experimental research projects.
  • PyTorch has torchtyping, which provides runtime or test-time shape checking, essentially as part of the type system.

My “dream come true” in this regard would be something with the safety of the Rust compiler and the array types of torchtyping.
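Here is a minimal runtime-checking sketch of the named-dimension idea. It is not torchtyping's API: check_shapes, shape_of, and linear are hypothetical, and nested Python lists stand in for arrays. Each argument declares its dimensions by name, and a name must have the same size everywhere it appears in one call.

```python
import functools
import inspect

def shape_of(x):
    """Shape of a nested-list 'array' (assumes rectangular nesting)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)

def check_shapes(**specs):
    """Decorator: specs map argument names to tuples of dimension names."""
    def decorate(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            dims = {}                                  # dimension name -> size seen so far
            for arg_name, spec in specs.items():
                actual = shape_of(bound.arguments[arg_name])
                if len(actual) != len(spec):
                    raise ValueError(f"{arg_name}: expected {len(spec)} dims, got shape {actual}")
                for dim_name, size in zip(spec, actual):
                    if dims.setdefault(dim_name, size) != size:
                        raise ValueError(f"{arg_name}: dim '{dim_name}' is {size}, expected {dims[dim_name]}")
            return fn(*args, **kwargs)
        return wrapper
    return decorate

@check_shapes(x=("batch", "d_in"), w=("d_in", "d_out"))
def linear(x, w):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

print(linear([[1.0, 2.0]], [[3.0], [4.0]]))
```

Moving a check like this from call time to compile time, as Hasktorch and Dex do, is what the "Rust compiler" half of the wish would add.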

Question 3

My experience has been that all of PyTorch/JAX/Julia are fast enough. I don’t really find myself caring about speed differences between them, and will pick a tool based on other considerations. (Primarily those listed above.)

Question 4

Maybe not an “experiment” in the sense you mean, but – more Q&As like this one, in particular at other venues where there’s likely to be more folks that have (either just for a project or more broadly) decided against Julia.

Question 5

Best case argument:

Imagine working in an environment that has both the elegance of JAX (Julia has arrays, a jit compiler, and vmap all built-in to the language!) and the usability of PyTorch (Julia is a language, not a DSL!) Julia still has issues to fix, but come and help pitch in if this is a dream you want to see become reality.

Impactful contributions:

  • Static compilation. Julia’s deployment story is simply nonexistent, and IMO this sharply limits its commercial applicability.
  • Better autodifferentiation. I know there’s ongoing work in this space (e.g. Diffractor.jl, which I haven’t tried yet), but so far IMO Julia hasn’t yet caught up to PyTorch/JAX on this front.
  • Fixing all of the negatives I raised above. Right now none of those are issues suffered by the major Python alternatives.

Question 6

As Q5.

Question 7

What packages? Right now, I’m tending to reach for JAX, Equinox, Optax (all in Python).

Why those packages? They provide the best trade-off between speed/usability for me right now.

What do I wish existed? Solutions to the above problems with Julia ML. I find that I really like the Julia language, but its ML ecosystem problems hold me back.


I think all of your comments were fair (although I don’t agree with everything; for example, the need to opt into @view is because Julia tends to be safe by default, and returning a view when slicing would not be), but I don’t quite get this point:

Why is that a problem? A[1] accesses the element with index 1, A[1, :] accesses the row with index 1, I fail to see what’s the problem.



That’s safe-by-default for functionality but not safe-by-default for speed.

Pragmatically speaking, I find that taking views is very common, whilst making a copy is unusual.

Special syntax for implicitly re-striding an array just seems a little odd to me. I don’t think I’ve ever done it deliberately in either Python or Julia.


On the contrary I always get a feeling of unease when leaving out indices in numpy, although I’ll admit that I sometimes like how the code turns out. One thing I dislike about it in numpy is that it seems kind of arbitrary why it means the same as A[1, :] rather than A[:, 1]. If the answer is that it’s natural based on the (default) strides, notice that it then would be natural to have the opposite meaning in Julia, which probably would drive everybody nuts. (Obviously it’s not possible to change this without removing the linear indexing feature from Julia.)
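To make the strides argument concrete, here is a small sketch of the two default memory layouts. It uses 0-based (i, j) positions, "C" for row-major (NumPy's default), and "F" for column-major (Julia's layout), and the helper names are hypothetical. It only illustrates layouts: NumPy's A[i] means a row regardless of an array's actual memory order.

```python
def storage_order(shape, order):
    """List the (i, j) positions of a 2-D array in the order they sit in memory."""
    rows, cols = shape
    if order == "C":   # row-major: the last index varies fastest
        return [(i, j) for i in range(rows) for j in range(cols)]
    if order == "F":   # column-major: the first index varies fastest
        return [(i, j) for j in range(cols) for i in range(rows)]
    raise ValueError(f"unknown order {order!r}")

def first_contiguous_chunk(shape, order):
    """The first run of memory one 'slice along the slowest axis' long."""
    rows, cols = shape
    run = cols if order == "C" else rows
    return storage_order(shape, order)[:run]

print(first_contiguous_chunk((2, 3), "C"))  # a row
print(first_contiguous_chunk((2, 3), "F"))  # a column
```

So under a "follow the strides" reading, a bare A[1] on a Julia matrix would naturally pick out a column. Julia's linear indexing does follow this layout: for a 2×3 matrix, A[3] is A[1, 2].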


I guess you’re not particularly eager to get stuck on this small side issue, but it seems like Julia has completely general syntax (with a very consistent rule about the dimensionality of the indices vs the output), while the Python syntax seems special and odd.

Otherwise, thanks for an interesting post.


For what it’s worth I’m fairly certain that that numpy syntax is internally consistent as well.


It is used in a fair amount of Julia code. I’ve used it. It’s handy, and it’s clear what’s happening.

But, obviously, this is the least important of your criticisms.

I am trying to find a robust way to deploy Julia inside a Python project. People have put a ton of work into Julia-Python interop in general, but the parts closer to deployment are not there yet. It would be a great place for people to lend a hand. (This has been known for a long time.) But it has to compete with the other pressing issues you mentioned.


@ChrisRackauckas on the topic of machine learning and E-Graphs, how do you view E-Graphs in comparison to the PyTorch developers’ work on TorchDynamo? While E-Graphs do seem to have wider-reaching goals than TorchDynamo, especially for SciML, I would be intrigued to hear how you see the two of them matching up.

Could E-Graphs fulfill a similar role to TorchDynamo for the ML ecosystem in Julia?


It’s not the same as, or similar to, the E-graph; instead it’s similar to the interfaces the E-graphs act on. Maybe the easiest way to describe it is by saying what is the same or similar. The Python bytecode is like “the Julia IR”. Of course, Julia being an optimizing compiler, there isn’t a singular IR; instead there are stages: the untyped IR, the typed IR, and the LLVM IR. Cassette and IRTools, the tools on which Zygote.jl was built (other notable tools built on them are AutoPreallocation.jl, SparsityDetection.jl, etc.), are probably the most similar to TorchDynamo, in that they transform untyped syntactic IR into another untyped syntactic IR.

It turns out that for Julia this was a bad idea because (a) the meaning of code can depend (heavily) on types, and (b) this is before compiler optimizations, and so mixing compiler optimizations with automatic differentiation is impossible. Thus Julia v1.7 added an AbstractInterpreter interface to Julia Base itself for acting on typed IR, which is then used by packages like EscapeAnalysis.jl and Diffractor.jl to write compiler passes on typed IR. And of course LLVM IR has standard interpretation techniques along with Enzyme.jl which is an AD written on LLVM IR.

So TorchDynamo is probably most similar to Cassette/IRTools, but you could also say it’s like AbstractInterpreter in that it acts on “the true IR of Python”. The difference is that the true IR of Julia is typed, so it has all of its information, while in Python it is not. But this story is why Zygote has its compile-time issues and higher-order AD issues, and why all of the tooling is moving not just to a new AD tool but to an entirely different IR target and compiler tool stack (note this doesn’t imply that will happen to TorchDynamo, unless they start rewriting their AD to be source-to-source on Python bytecode, but there’s precedent for that in Tangent, which didn’t find a nice home). Note that these tools aren’t just for AD. For example, there are PRs to Julia’s Base, written using the AbstractInterpreter compiler plugin interface, that automatically analyze loops and remove repeated allocations of immutable arrays.

So that still doesn’t answer how the heck E-graphs come into the story, because I haven’t described how you write a compiler pass. It doesn’t matter what level of IR you’re on; a pass is basically just a function IR->IR. So where their blog post says “just add code here”:

from typing import Callable

import torch.fx
import torchdynamo  # the early standalone TorchDynamo package (later merged into PyTorch as torch._dynamo)

def custom_compiler(graph: torch.fx.GraphModule) -> Callable:
    # do cool compiler optimizations here
    return graph.forward

with torchdynamo.optimize(custom_compiler):
    # any PyTorch code
    # custom_compiler() is called to optimize extracted fragments
    # should reach a fixed point where nothing new is compiled
    pass
# Optionally:
    # any PyTorch code
    # previously compiled artifacts are reused
    # provides a quiescence guarantee, without compiles

Well, that’s true in any of these systems, just like in macros. But if you’ve ever written a macro, you’ll know that walking expression graphs is a tedious process to get correct. Wouldn’t it be nice if compiler optimizations for mathematical ideas could be expressed mathematically, and the associated compiler pass could be generated? It turns out that, at its core, all of the Symbolics tooling is just tooling that performs rewrites on some IR. So Symbolics.jl has an IR that uses SymbolicUtils.jl’s rewriters and MetaTheory.jl’s E-graphs to transform symbolic IR → symbolic IR, but what we have done is make those rewrite tools generic to the IR, and boom, now it’s a compiler optimization pass generator.

That means you can, say, define an E-graph that acts on Julia typed IR and spits out the typed IR with the desired simplifications described mathematically. This is what we mean by “democratization of writing compiler passes”: we are trying to use this to build a system so that people who want to add a new linear algebra simplification pass to the Julia typed IR do not need to learn all of the details of the AbstractInterpreter and the Julia typed IR definition, and can instead just write a few mathematical equalities, and boom, it generates a compiler pass which then generates the transformed IR. So think of the E-graphs as replacing the requirement that someone write a function like def custom_compiler(graph: torch.fx.GraphModule) -> Callable: that digs through some expression graph. Instead you just write
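As a rough illustration of the "write equalities, get a pass" idea, here is a plain greedy term rewriter in Python. It is a deliberate simplification and not an e-graph: an e-graph keeps all equivalent forms at once and extracts the cheapest one afterward, instead of committing to whichever rule fires first. The expression encoding and all names (match, substitute, rewrite, RULES) are hypothetical. The first rule is exact over the reals but, like the other fusions discussed here, can change floating-point rounding.

```python
# Expressions are nested tuples ("op", arg, ...); leaves are strings.
# Pattern variables are strings starting with "?".

def match(pattern, expr, bindings):
    """Return extended bindings if expr matches pattern, else None."""
    if isinstance(pattern, str) and pattern.startswith("?"):
        if pattern in bindings:                  # a repeated variable must bind the same subterm
            return bindings if bindings[pattern] == expr else None
        return {**bindings, pattern: expr}
    if isinstance(pattern, tuple) and isinstance(expr, tuple) and len(pattern) == len(expr):
        for p, e in zip(pattern, expr):
            bindings = match(p, e, bindings)
            if bindings is None:
                return None
        return bindings
    return bindings if pattern == expr else None

def substitute(template, bindings):
    if isinstance(template, str) and template.startswith("?"):
        return bindings[template]
    if isinstance(template, tuple):
        return tuple(substitute(t, bindings) for t in template)
    return template

def rewrite(expr, rules):
    """Greedy bottom-up rewriting; assumes the rule set terminates."""
    if isinstance(expr, tuple):
        expr = tuple(rewrite(e, rules) for e in expr)
    for lhs, rhs in rules:
        bindings = match(lhs, expr, {})
        if bindings is not None:
            return rewrite(substitute(rhs, bindings), rules)
    return expr

# The whole "compiler pass", written as equalities oriented left to right:
RULES = [
    (("+", ("*", "?A", "?x"), ("*", "?A", "?y")), ("*", "?A", ("+", "?x", "?y"))),  # A*x + A*y -> A*(x + y)
    (("transpose", ("transpose", "?X")), "?X"),                                       # (X')' -> X
]

print(rewrite(("+", ("*", "A", "x"), ("*", "A", "y")), RULES))
```

The tree walking lives in generic code written once. Adding a new simplification means adding one more pair to RULES, which is the part a math-minded contributor would touch.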

Man, this came out longer than expected. But since it describes why Zygote is being replaced with Diffractor and Enzyme I guess it’s a useful description for many other reasons than the original question :sweat_smile:


Thanks for the shout-out on the CTPG paper! I whole-heartedly agree that the paper would not have been possible without Julia and DifferentialEquations.jl. It’s cool to see what can be done when you break out of the usual static-graph mode. Thanks and kudos to @ChrisRackauckas!

I work in ML, largely in Python, but I have a soft spot for Julia as well. I think @patrick-kidger’s response summarizes things very well. I’ll just chip in a few of my own experiences/thoughts:

What does ML even mean?

There are so many different types of models/problems/architectures these days that it’s worth pointing out that there’s a big difference between “conventional deep learning” – transformers, convnets, large models – and “other” more obscure models – differentiable physics, neural ODEs, implicit models, etc. So far I think Julia is doing better in the “other” category.

It’s worth noting that the requirements for “conventional” vs “other” can be drastically different. Everything from model parallelism to compute architecture to float32 vs float64.


Compilation speed is entirely irrelevant (cf. jax). What matters at the end of the day is iterations/second. Right now, JAX/XLA seem to have Julia beat in the “conventional large model” space since they have all kinds of optimizations for linear algebra, specific kernels, and TPUs. At this point just about every last drop of performance has been squeezed out of pytorch/TF/jax in the “conventional” large models space.

That being said, I am extremely bullish on the MetaTheory.jl line of work with e-graph based optimization. Ultimately I think this is a superior design than anything in the competition. But the devil will be in the details of making it production-ready esp. on GPUs/TPUs.


Like @patrick-kidger, I have been bit by incorrect gradient bugs in Zygote/ReverseDiff.jl. This cost me weeks of my life and has thoroughly shaken my confidence in the entire Julia AD landscape. As a result I now avoid using Julia AD frameworks if I can. At minimum, I cross-check all of their results against JAX… at which point I might as well just use the JAX implementation. (Excited to check out Diffractor.jl when it’s ready though!)

In all my years of working with PyTorch/TF/JAX I have not once encountered an incorrect gradient bug.

Ecosystem and library scope

I really wish there was just something like JAX in Julia. Flux.jl is too high-level for me most of the time. Zygote is often too low-level. I like the idea of source-to-source AD though. Maybe we just need new frameworks on top of Zygote/Diffractor to spring up? I don’t know. I expect that solutions here will emerge naturally as more investment is made in ML in Julia and people bump into the limitations of existing tooling…

I’m optimistic for the future of ML in Julia. I really am. For me personally, it’s not ready for what I need it to do just yet. But I’m optimistic that this may change over time.


Thanks for the very thoughtful post Patrick, and nice to see you around. Some thoughts on the above:

I’d argue this applies to most non-science/numerics projects and ~20-30% of scientific python code, but there is a long, long tail of projects that use no linting or any kind of static analysis. These projects tend to make many of the same faux pas you mentioned.

I think this speaks to three things:

  1. The benefits of centralization in the Python ecosystem. A majority of users doing data-y stuff can get away with just that top 20-30%.
  2. The relative size/resourcing of both ecosystems. Code quality may well be better for Julia and Python codebases of the same popularity, but if we look at relative in-ecosystem popularity then your point probably holds. I don’t have a good intuition on how much we should weight those two disparate perspectives. For example, the aforementioned exemplary Python projects have orders of magnitude more engineering time/money/infra to work with, and trying to replicate their quality without those is rather unrealistic.
  3. A need to get more automated tooling running in the Julia ecosystem. I want to say DocumentLint has been used in CI, but it’s still primarily an IDE thing. JET.jl leaves basically every Python type checker in the dust, but its relative novelty means that adoption is still low.

This is possible at runtime with libraries like NamedDims.jl (invenia/NamedDims.jl, for working with dimensions of arrays by name), and potentially statically with JET + named array libraries. The biggest challenge I see (one you’re likely familiar with from developing torchtyping) is adoption + standardization. Guido has been leading an effort on the Python side, so seriously exploring avenues like that could be fruitful here.
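Here is a toy illustration of what runtime dimension-name checking buys. `NamedMatrix` and `matmul` are hypothetical names, and plain nested tuples stand in for arrays. This is not NamedDims.jl's API, only the idea behind it.

```python
from dataclasses import dataclass
from typing import Tuple

Matrix = Tuple[Tuple[float, ...], ...]

@dataclass(frozen=True)
class NamedMatrix:
    """Hypothetical 2-D array (nested tuples) tagged with dimension names."""
    data: Matrix
    names: Tuple[str, str]

def matmul(a: NamedMatrix, b: NamedMatrix) -> NamedMatrix:
    # Contracted dimensions must agree by *name*; a real library would also check sizes.
    if a.names[1] != b.names[0]:
        raise ValueError(f"cannot contract {a.names[1]!r} with {b.names[0]!r}")
    cols = list(zip(*b.data))
    out = tuple(tuple(sum(p * q for p, q in zip(row, col)) for col in cols)
                for row in a.data)
    return NamedMatrix(out, (a.names[0], b.names[1]))

x = NamedMatrix(((1.0, 2.0), (3.0, 4.0)), ("batch", "features"))
w = NamedMatrix(((1.0,), (1.0,)), ("features", "hidden"))
print(matmul(x, w))
```

The interesting case is `matmul(x, x)`: both operands are 2x2, so a shape check passes, but the name check rejects contracting "features" with "batch".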


Thanks @Samuel_Ainsworth and @patrick-kidger for your frank thoughts! It’s really important to get this sort of feedback.

How long ago were both of you getting incorrect gradients? Were these errors on recent versions of Zygote? After the ChainRules switch?


If you’ll allow me to start from the end first:

This is not well documented and ought to be (PRs welcome if anyone is interested), but Flux is really an amalgamation of different sub-libraries:

  1. An AD (Zygote)
  2. A set of ML kernels (NNlib)
  3. A module system (Functors.jl)
  4. A set of layers, optimizers and training loop helpers (Flux itself)

Using just #1 and #2 is equivalent to torch.nn.functional. Using 1, 2, and 3 gets you a JAX equivalent. The plan is to move optimizers out from #4 into a separate package (see Optimisers.jl) so that you can use it just like JAX users use Optax. This kind of shared infrastructure is already being exploited in the ecosystem: Knet uses NNlib (NNlib dev is a collaboration between Knet and Flux) and offers a “lower level” interface you may be interested in, while Avalon.jl uses NNlib + Functors for a more PyTorch-esque framework.
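The Optax-style pattern mentioned above keeps optimizer state as an explicit value that is threaded through the training loop, instead of hiding it inside the model or a global. Below is a minimal Python sketch of that init/update shape. The function names are hypothetical and only mirror the general pattern of Optax and Optimisers.jl, not their exact APIs.

```python
from typing import Dict, Tuple

Params = Dict[str, float]

def momentum_init(params: Params) -> Params:
    """Optimizer state is a plain value, separate from the model parameters."""
    return {k: 0.0 for k in params}

def momentum_update(grads: Params, state: Params, lr: float = 0.1,
                    beta: float = 0.9) -> Tuple[Params, Params]:
    """Pure function: returns (updates, new_state) and mutates nothing."""
    new_state = {k: beta * state[k] + grads[k] for k in grads}
    updates = {k: -lr * v for k, v in new_state.items()}
    return updates, new_state

def apply_updates(params: Params, updates: Params) -> Params:
    return {k: params[k] + updates[k] for k in params}

# Minimize (w - 3)^2, threading params and optimizer state explicitly through the loop.
params = {"w": 0.0}
state = momentum_init(params)
for _ in range(300):
    grads = {"w": 2.0 * (params["w"] - 3.0)}
    updates, state = momentum_update(grads, state)
    params = apply_updates(params, updates)
print(params["w"])
```

Because the state is just data, swapping optimizers, checkpointing, or running the same update on different parameter sets needs no changes to the model code.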

Now to the broader, more philosophical point. I also use 100% Python for my own work, and the dynamics/motivation there are very similar to what you’ve described. Though not worded particularly pleasantly, I think this HN comment summarizes the struggle well:

My impression from your comment is that you don’t care that much about “standard” ML users. As a “standard” ML user (pytorch/jax), and a potential Julia user in the future, this is not what I like to hear.

Now, there have been some very good points made here and on different forums that trying to take the Python ML juggernaut on in its own territory is at best aspirational (Edit: after reading Chris’ response, the original, more forceful “fool’s errand” would’ve been more appropriate :stuck_out_tongue: ). What I don’t think has happened is saying the “quiet part” out loud and following that to its logical conclusion. Of course the Julia community is not a monolith and there will be divergent opinions on how to approach ecosystem development, but folks like the aforementioned HN commenter are looking for a clearer statement. That is, where do we fall between the two extremes, from “novel architectures/approaches are the only way to go; if they do it well then we shouldn’t bother” to “Julia ML should be #1 on everything”? And depending on the vision, what are some concrete steps that can be taken to support it?

Edit: to make sure I’m not underselling or misrepresenting things, there are some great and very clear roadmaps for parts of the ML space already. SciML and advanced AD come to mind. The question above is about the complement: what should be put into the “don’t expect anything big here unless you’re willing to help develop or fund it” bucket?


+1 to calling it “conventional” ML (or some other name), since there is already an important programming language called Standard ML (meta-language) that Julia packages take features from.

Yes, one thing to mention is that the Julia community is large and not a monolith and so there are many people developing these tools, all with their own reasons and aspirations. While there are some institutions that tend to have more of the developers for AD and ML libraries (specifically the Julia Lab and Julia Computing), those entities are large and not monoliths themselves. Even at the Julia Lab, I have no control over why people work on these problems, rather I just work with the students and research software engineers to guide them towards successful projects. Many people are doing it as ML for ML’s sake, and that’s fine.

But I think everyone should just be honest and clear as to some of the technical aspects and how they relate to the higher level decisions that have developed such large labs around this topic.

trying to take the Python ML juggernaut on in its own territory is at best aspirational

No, that’s an understatement. Let’s make it absolutely clear: there is nothing in the technical approach of differentiable programming that will make “conventional ML” faster. Period. A perfect Zygote or Diffractor will not make matrix multiplication kernels faster, it will not make convolutional kernels faster, and it will not make Transformer kernels faster. For large “big data” conventional machine learning, calls to the kernels are on the order of tens to hundreds of seconds. The AD overhead of a slow AD like PyTorch or even just AutoGrad is in the milliseconds per operation. A source-to-source AD that cuts that down to close to zero is not getting even a 1% gain in those applications. Source-to-source AD is a much larger and harder project, which takes on a lot of added complexity in exchange for applicability to full dynamism and lower overhead (+ JIT compilation of all reverse paths). Conventional ML models like transformers do not use this dynamism, and they do not have to worry about this overhead. The current AD work will not magically some day give you something compelling enough to make conventional ML users pack up and switch from Python. If that was the purpose of those projects, then those projects would be an extremely dumb idea. Why build a brand new multi-million dollar stadium for your kid’s elementary school football team? It’s not a fit-for-purpose idea, and it will actually hold the Julia ecosystem back for a bit in this domain because of the added complexity.
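The overhead argument comes down to simple arithmetic: the share of wall time spent in AD bookkeeping is overhead / (kernel time + overhead). The sketch below uses the orders of magnitude quoted above for the large-kernel case. The small-kernel numbers are an illustrative assumption, included to show the loopy, small-kernel regime where cutting AD overhead does pay off.

```python
def ad_overhead_fraction(kernel_s: float, overhead_s: float) -> float:
    """Fraction of each operation's wall time spent in AD bookkeeping."""
    return overhead_s / (kernel_s + overhead_s)

# Large-kernel regime from the text: ~10 s kernel calls, ~1 ms AD overhead per op.
big = ad_overhead_fraction(kernel_s=10.0, overhead_s=1e-3)
# Small-kernel regime (assumed numbers): a ~1 microsecond kernel with the same overhead.
small = ad_overhead_fraction(kernel_s=1e-6, overhead_s=1e-3)
print(f"large kernels: {big:.4%} overhead, small kernels: {small:.2%} overhead")
```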

Maybe having full language support will make some ergonomic gains, like it will integrate with the profiler and debuggers better than DSLs generally do, and if someone happens to write a model in the “wrong” way it could play nicer than say something like Jax where if you write something that isn’t functional and pure :man_shrugging: incorrect gradients can occur. But we’re talking minor gains at the end of the day for those applications.

But let’s dig even deeper. Zygote’s purpose was to not unroll loops, so that the AD could JIT compile loopy code with small kernels. That’s a very nice improvement for domains that need loopy code with small kernels. You can expect some pretty good performance gains, and you should choose Zygote if that’s your domain. Conventional ML is not in that domain. :man_shrugging: sorry. The whole emerging SciML domain happened to fit that profile, and that’s how it found a home there, which launched the organization and such. With that lens, it should be no surprise that in conventional ML Julia did not capture the whole audience, whereas in SciML it became a big chunk of the (still rather small) field. It’s not random, and it’s not just sweat and grit; there are real technical reasons behind it that you shouldn’t just gloss over.

Diffractor.jl’s driving emphasis was a category theoretic formulation for higher order derivatives. That gives you some massive speedups if you’re calculating third or fourth derivatives. But in conventional ML, who’s doing that? People don’t take Hessians of neural networks, let alone anything higher. Yes, there will be some spillover effects for how this improves conventional ML cases because of changing the target towards typed IR (potential compile-time improvements, maintainability, etc.). But flipping the Diffractor switch won’t be a day where Flux is suddenly a whole lot better for conventional ML. The reason for this kind of tool is applications like physics-informed neural networks which routinely take 3rd order derivatives and above. That’s the kind of application that funded it (specifically for use in NeuralPDE.jl). That’s a growing field, enough so that the NVIDIA CEO keeps mentioning physics-informed neural networks, and that’s an area where this kind of tool will cause a substantially noticeable difference. But that’s not NLP or image processing with convnets and transformers. For those cases, Diffractor would be a very hard project to get little gains, it would make no sense. If the purpose of Diffractor was those domains, it would be a bad idea.
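To show concretely what efficient higher-order derivatives involve, here is Taylor-mode forward AD. A truncated Taylor series is pushed through the computation, and the k-th derivative is read off as k! times the k-th coefficient. One pass yields all derivatives up to order K, rather than nesting a first-order AD K times. This is a minimal sketch of the general technique (only `+` and `*` are supported), not Diffractor.jl's implementation or API.

```python
from math import factorial
from typing import List

class Taylor:
    """Truncated Taylor series c[0] + c[1] t + ... + c[K] t^K, propagated through arithmetic."""
    def __init__(self, coeffs: List[float]):
        self.c = list(coeffs)

    @classmethod
    def variable(cls, x: float, order: int) -> "Taylor":
        # Expansion of the identity map around x: x + 1*t (order must be >= 1).
        return cls([x, 1.0] + [0.0] * (order - 1))

    def _lift(self, other) -> "Taylor":
        if isinstance(other, Taylor):
            return other
        return Taylor([float(other)] + [0.0] * (len(self.c) - 1))

    def __add__(self, other):
        other = self._lift(other)
        return Taylor([a + b for a, b in zip(self.c, other.c)])
    __radd__ = __add__

    def __mul__(self, other):
        other = self._lift(other)
        # Cauchy product of the two series, truncated at the same order.
        return Taylor([sum(self.c[i] * other.c[k - i] for i in range(k + 1))
                       for k in range(len(self.c))])
    __rmul__ = __mul__

    def derivative(self, k: int) -> float:
        # The k-th derivative is k! times the k-th Taylor coefficient.
        return factorial(k) * self.c[k]

def f(x):
    return x * x * x * x + 2 * x  # f(x) = x^4 + 2x

y = f(Taylor.variable(2.0, order=3))
print([y.derivative(k) for k in range(4)])
```

At x = 2 this gives f, f', f'', f''' in one forward pass. Derivatives of the 3rd order and above are exactly what physics-informed losses need.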

So let’s refocus a little. Let’s say your goal is to improve conventional ML. How would you do it? Here are a few things that come to mind for me:

  1. You could focus a project on conventional ML researchers by making it easier to develop faster kernels. This would help people out of the “ML is stuck in a rut” problem, where better ideas can be slower than worse ideas simply because of how much the standard kernels have been optimized. If you want to do this, you should develop an AD that is really good at differentiating compute kernels. Zygote and Diffractor are not the tools for this; Enzyme.jl is. See the paper on generating adjoints of GPU kernels as an example. Or you could develop tools like LoopVectorization.jl that are instead targeted at GPUs (cf. KernelAbstractions.jl).
  2. You could focus a project on making it easier to capture more high level kernel fusions to optimize the kernel-centric code. That’s the e-graphs projects, and that’s what the folks at Google are doing with XLA. That’s what MLIR is aiming to do.
  3. You could focus a project on making it easier to do distributed multi-GPU training. The ergonomics here are still rather difficult, even with TensorFlow/XLA, particularly easy installation and running on local compute clusters. DaggerFlux is probably the closest project we have to this, other than XLA.jl.
  4. You could focus on writing faster GPU kernels for specific tasks.
  5. You could make packages with experimental APIs to improve the ergonomics of conventional training workflows. Integrate some automation in there. Automatic MLops? ML libraries without implicit global parameter references?
  6. You could, instead of waiting for Zygote and Diffractor to be “complete”, skip ahead and do ML on small DSLs. DSLs will always be easier to optimize given their constrained nature. Yota.jl is a great example of this. It uses a tracer, Ghost.jl, to get a simpler IR and does some nice things on that.
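Item 1 above is about differentiating compute kernels. Here is what that means concretely: for an in-place loop kernel, reverse mode must produce a second loop that accumulates adjoints, including into shared scalars. On a GPU, that shared accumulation would need atomic or reduced updates. Below is a hand-written sketch for an axpy-style kernel. Tools like Enzyme.jl aim to generate such adjoints automatically, and nothing here is Enzyme's API.

```python
from typing import List, Tuple

def axpy_kernel(a: float, x: List[float], y: List[float]) -> None:
    """Forward kernel: y[i] += a * x[i], in place (a typical loop-style compute kernel)."""
    for i in range(len(x)):
        y[i] += a * x[i]

def axpy_adjoint(a: float, x: List[float], dy: List[float]) -> Tuple[float, List[float]]:
    """Hand-written reverse pass: given dL/dy_out, return (dL/da, dL/dx).

    Since y_out[i] = y_in[i] + a * x[i], the gradient w.r.t. y_in is dy itself.
    """
    da = 0.0
    dx = [0.0] * len(x)
    for i in range(len(x)):
        da += x[i] * dy[i]   # every iteration accumulates into the same scalar
        dx[i] += a * dy[i]
    return da, dx

a, x = 1.5, [1.0, 2.0, 0.5]
y = [0.0, 0.0, 0.0]
axpy_kernel(a, x, y)
# With loss L = sum(y_out), dL/dy_out is all ones, so dL/da = sum(x) and dL/dx = [a] * n.
print(y, axpy_adjoint(a, x, [1.0, 1.0, 1.0]))
```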

Noticeably absent from that list are the current ADs and differentiable programming work. That will do almost nothing for the conventional ML domain except maybe, just maybe, a few ergonomic improvements when everything works out. There are much better projects to work on if conventional ML was the goal. But for me and large parts of the Julia Lab, conventional ML is not the goal, which is why there is so much work and publications in differentiable programming tools. Hopefully this line of reasoning makes it as clear as daylight.
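Item 6 in the list above mentions Yota.jl's tracer, Ghost.jl. Here is roughly how a tracer gets a simpler IR: run the function on stand-in values that record each operation, which produces a flat list of instructions with control flow already unrolled. This is a minimal operator-overloading sketch in Python. The `Tracer` class and tuple instruction format are illustrative, not Ghost.jl's or Yota.jl's API.

```python
from typing import List, Tuple

class Tracer:
    """Records every arithmetic op into a shared linear tape (a tiny straight-line IR)."""
    def __init__(self, tape: List[Tuple], name: str):
        self.tape, self.name = tape, name

    def _emit(self, op: str, other) -> "Tracer":
        rhs = other.name if isinstance(other, Tracer) else repr(other)
        out = Tracer(self.tape, f"%{len(self.tape)}")
        self.tape.append((out.name, op, self.name, rhs))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)

def trace(fn, n_args: int) -> List[Tuple]:
    tape: List[Tuple] = []
    fn(*[Tracer(tape, f"x{i}") for i in range(n_args)])
    return tape

def model(x, y):
    h = x
    for _ in range(2):   # Python control flow runs at trace time and is unrolled
        h = h * y + 1.0
    return h

for instr in trace(model, 2):
    print(instr)
```

The resulting tape has no loops or branches, which is what makes it easy to optimize and differentiate. The trade-off is that data-dependent control flow gets frozen to whatever path the trace took.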


ML in Julia has a bright future, and is currently very strong in certain areas. I am constantly impressed by the intelligence of those working in the Julia AD space. Everything is possible in Julia. In fact everything is trivial in Julia, if you are very clever. This is the current problem:

ML in Julia requires high existing knowledge or a lot of time searching/doing trial and error.

Previous answers have discussed the lack of technical limitations to reaching parity with PyTorch/JAX for general deep learning, but there are other important factors that drive adoption: thorough documentation, blogs, useful error messages, stability, and a vague feeling of “trust”. It can be tempting to think that these things follow naturally from the technical possibilities, but they are often driven by that special type of contributor who prefers taking off the rough edges to adding new ones.

Getting the Julia ML ecosystem to the scale of PyTorch requires drawing in these types of contributions, and without the luxury of big tech support or resources. As others have said, it’s not the primary focus of many Julia developers (myself included).

On a personal level I am currently in the wild woods of developing novel differentiable algorithms in Julia. I love it, but it has come with a constant stream of cryptic errors, lack of features and incorrect gradients. Sometimes I long for the warm blanket that is PyTorch.

Julia should seek to become that warm blanket. It already has for general scientific programming.


I would like to give my 2c about the topic.

First, I will speak from my point of view, considering both typical machine learning, dominated by scikit-learn in Python, and deep learning, dominated by TensorFlow, PyTorch, and the more recent Flax.

  1. Where does ML in Julia shine today?

Well, it is a difficult question, because it depends on the ecosystem being compared (scikit-learn, TensorFlow, PyTorch, caret in R, …).

In my opinion, Julia shines in the fact that packages/libraries do not have to implement too much, which encourages compatibility between them. For instance, MLJFlux, or the fact that all of them can work not only with DataFrames but with any other structure that implements the Tables.jl interface.
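The compatibility point rests on small shared interfaces: a consumer written against the interface works with any table type that implements it. Below is a Python analogue using `typing.Protocol`. The method names loosely echo Tables.jl's `columnnames`/`getcolumn`, but the classes are hypothetical and this is not Tables.jl itself.

```python
from typing import Dict, List, Protocol, Sequence

class ColumnTable(Protocol):
    """Minimal 'table interface': anything that can hand out named columns."""
    def columnnames(self) -> List[str]: ...
    def getcolumn(self, name: str) -> Sequence[float]: ...

class DictTable:
    """Column-oriented storage."""
    def __init__(self, cols: Dict[str, List[float]]):
        self.cols = cols
    def columnnames(self) -> List[str]:
        return list(self.cols)
    def getcolumn(self, name: str) -> Sequence[float]:
        return self.cols[name]

class RowTable:
    """A different storage layout: a list of row dicts."""
    def __init__(self, rows: List[Dict[str, float]]):
        self.rows = rows
    def columnnames(self) -> List[str]:
        return list(self.rows[0]) if self.rows else []
    def getcolumn(self, name: str) -> Sequence[float]:
        return [r[name] for r in self.rows]

def column_means(t: ColumnTable) -> Dict[str, float]:
    """A 'model' package depends only on the interface, not on any concrete table type."""
    return {n: sum(t.getcolumn(n)) / len(t.getcolumn(n)) for n in t.columnnames()}

a = DictTable({"x": [1.0, 3.0], "y": [2.0, 4.0]})
b = RowTable([{"x": 1.0, "y": 2.0}, {"x": 3.0, "y": 4.0}])
print(column_means(a), column_means(b))
```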

It does not shine in speed, because in my opinion TensorFlow, PyTorch, scikit-learn, and JAX are more mature libraries and their implementations use C/C++ and GPUs. However, Julia can be more flexible, because everything is in the same language.

Also, it shines in the simplicity of the implementations: you can read Flux, for instance, and understand a lot of it.

  2. Julia’s ML ecosystem is currently inferior in features. For instance, in ML, there is less support for tackling imbalanced classes and for advanced data transformations (like discretization, …). Also, in MLJ, using the packages that implement the models takes more time than the equivalent using scikit-learn (though in that case the error messages are a lot worse). In DL, Flux is more general, and it has a lot fewer features than TF/PyTorch (reading files, segmentation, preprocessing, …). There is some work, like FastAI, but it is still in progress.

3 and 4. I don’t have solid information about performance; there are no really good benchmarks (or at least I do not know of any).

  5. I think the documentation should be improved, detected bugs fixed, and required features added to reach the level of other packages.

  6. Because it can be a great alternative, one that could become a lot better without putting a lot of resources into it.

  7. I used to use R more for data processing, and pandas in Python, but now I use Julia a lot more. For ML I have used MLJ, the Julia framework; however, for imbalance and tuning I usually use Python more. Nowadays I use TF or PyTorch for DL, and I hope to be able to replicate that work in Julia, but until FastAI arrived recently, a lot of the preprocessing was not implemented.