State of machine learning in Julia

You might want to track the progress of ONNX.jl - when it’s ready, you should be able to load any of the thousands of pretrained models into Julia (e.g. see instructions for HuggingFace models) as well as run Julia models on mobile. We are not there yet, but some simple examples like ResNet are already functional.


One obstacle though, and in my opinion the biggest challenge for Julia, is precisely that it provides an open and extensible ecosystem for machine learning. For comparison, JAX has ~200 operators, ONNX ~250, and PyTorch around 2000. Every operator there is guaranteed to support backpropagation, work on GPU/TPU, have good documentation, etc. Supporting all of them is sometimes hard, yet possible.

In Julia, we essentially have an infinite list of operators scattered over multiple packages. There are no formal requirements for adding new functions, and authors usually only cover their own needs. For instance, someone adding a new ChainRules.rrule() may not check if it works with CuArrays, somebody creating a new closure-based layer may not think about exporting it to ONNX, etc.

For a couple of days I’ve been thinking of a curated list of high-quality operators and language features with particular guarantees, e.g.:

  • support autodiff (e.g. via ChainRules.rrule)
  • support GPU
  • support low-precision element types (Float32, Float16)
  • have docstrings & be discoverable (e.g. be listed in some popular document)
  • have performance tests
  • support import/export to/from ONNX, XLA, etc.

We can then create automatic tests to ensure the quality of the listed features. Tests themselves don’t have to be exhaustive - these things must be tested in their packages anyway - instead, they will show the level of maturity of the ecosystem and highlight potential issues.
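As a rough sketch of what one such automatic test could look like, the snippet below checks two of the listed guarantees (callable at all, and Float32 in gives Float32 out) using only Base Julia. The helper name `check_guarantees` and the two toy activations are hypothetical and exist only for illustration; a real harness would also cover autodiff, GPU arrays, docstrings, and so on.

```julia
# Hypothetical checklist helper; not part of any package mentioned in this thread.
function check_guarantees(f, x::AbstractArray{Float64})
    results = Dict{Symbol,Bool}()

    # Guarantee 1: the function can be called on a reference input at all.
    results[:callable] = try
        f(x)
        true
    catch
        false
    end

    # Guarantee 2: low-precision input is not silently promoted to Float64.
    results[:keeps_float32] = try
        eltype(f(Float32.(x))) == Float32
    catch
        false
    end

    return results
end

# Two toy operators (assumptions for the example, not library code).
relu(x) = max.(x, 0)              # integer literal: element type is preserved
leaky(x) = max.(x, 0.01 .* x)     # Float64 literal: Float32 input gets promoted

report_relu = check_guarantees(relu, [1.0, -2.0])
report_leaky = check_guarantees(leaky, [1.0, -2.0])
```

Here `leaky` is callable but fails the Float32 check because of the `0.01` literal, which is the kind of small issue a report across many operators would surface.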

To me that’s just a reframing of the same problem. If you’re restricted to a smaller set of centralized blessed ops, then you might as well use PyTorch.

Ideally there would be a way to typecheck functions, or an API that enforces semantics which are more easily optimizable, so that an ecosystem could be both distributed and fast/correct with AD. Right now it’s a bit magical. Type stability can be hard enough, but now there are ever-shifting, inscrutable code patterns and corner cases that people have to worry about for GPU and AD, so of course they won’t compose well with Flux.

The easier it is technically, the less social coordination has to happen.

BTW, ONNX.jl is amazing. Whatever happens in Julia, it’s a huge boon to have access to all those python models. Thanks for your work there!

I also have my eye on Yota.jl.

Having a list of operators with certain guarantees doesn’t restrict you in any way. Take mutation for example. Most AD frameworks either don’t favor or explicitly don’t support it (e.g. see the long-standing issue in Zygote). However, it doesn’t mean you can’t use mutation in non-AD code (e.g. during data preprocessing) or hide it behind AD primitives (e.g. via pure rrules).
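To make the "hide it behind AD primitives" part concrete, here is a minimal sketch assuming ChainRulesCore is installed. The function `running_sum` is a made-up example: it fills its output with in-place writes, but the hand-written `rrule` means a ChainRules-aware AD never has to trace through the loop.

```julia
using ChainRulesCore

# Hypothetical example function: a running sum written with in-place updates.
function running_sum(x::AbstractVector)
    y = similar(x)
    acc = zero(eltype(x))
    for i in eachindex(x)
        acc += x[i]
        y[i] = acc    # array mutation, the kind of code AD systems like Zygote reject
    end
    return y
end

# From the outside the rule is pure: the mutation stays inside the primal call.
function ChainRulesCore.rrule(::typeof(running_sum), x::AbstractVector)
    y = running_sum(x)
    function running_sum_pullback(ȳ)
        dy = unthunk(ȳ)
        # Input i contributes to outputs i, i+1, ..., so its gradient is a reversed running sum.
        x̄ = reverse(cumsum(reverse(dy)))
        return NoTangent(), x̄
    end
    return y, running_sum_pullback
end

y, pullback = ChainRulesCore.rrule(running_sum, [1.0, 2.0, 3.0])
_, x̄ = pullback([1.0, 1.0, 1.0])
```

Note that the scalar-indexing loop is also something that would likely not run well on GPU arrays, so a rule like this could satisfy the autodiff guarantee while still failing the GPU one.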

This is in contrast to, say, JAX that has this restriction on the framework level. If you need a mutable array there, you have to fall back to NumPy (without GPU support) or some other library, ending up in a weird mix of technologies and endless conversions.

I think having a list of blessed ops is a good stopgap for now, I agree.

And yeah, data handling, non-differentiable simulation code, etc. are much more pleasant in Julia.

For a couple of days I’ve been thinking of a curated list of high-quality operators and language features with particular guarantees

I fully support such an idea! Moreover, I would add that the same should be done with operators that are known to be buggy or problematic (i.e. with open/known issues). From my experience using Zygote, this would have made my life so much easier. When you debug AD stuff in Julia, it is often very tricky to isolate which operator/line is crashing. With such a list, it would be much easier to narrow down the potentially problematic operators/functions and go debug those directly. And of course, to make it less of a pain to maintain, only the tougher, long-standing bugs could be included. New issues with an easy fix in sight wouldn’t need to be added, for simplicity’s sake.

Such lists could be posted in the docs of each library and merged together in JuliaDiff to have a nice overview.

Maybe rely more on ChainRulesTestUtils.jl: https://github.com/JuliaDiff/ChainRulesTestUtils.jl

I’ve just created HQDL.jl to start tracking the list of such operators. The package provides a macro @inspect (and a very similar @analyze, see their docstrings for the difference) that:

  • checks if a function is even callable
  • runs ChainRulesTestUtils.test_rrule() with Array and CuArray types, Float64 and Float32 precision
  • checks if the function has a docstring

The first report obviously has many false positives, but it uncovers a few interesting observations. For example, many broadcasted activation functions from NNlib fail on test_rrule() for unclear reasons. Many other functions don’t define rrule() for broadcasting and thus rely on AD to handle it. Usually, AD is able to handle them as long as rrule() is defined for the element-wise function.
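For readers who haven’t used `test_rrule()` before, this is roughly what a single check looks like in isolation. It is only a sketch, not HQDL’s actual code: `sumsq` and its rule are hypothetical, and the snippet assumes ChainRulesCore and ChainRulesTestUtils are installed. `test_rrule` compares the hand-written pullback against finite differences.

```julia
using ChainRulesCore
using ChainRulesTestUtils
using Test

# Hypothetical operator with a hand-written reverse rule.
sumsq(x::AbstractVector) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(sumsq), x::AbstractVector)
    # d/dx sum(x.^2) = 2x, scaled by the incoming cotangent ȳ.
    sumsq_pullback(ȳ) = (NoTangent(), 2 .* unthunk(ȳ) .* x)
    return sumsq(x), sumsq_pullback
end

@testset "sumsq rrule (Array, Float64)" begin
    # Checks the rule against a finite-difference estimate on a random input.
    test_rrule(sumsq, randn(5))
end
```

Repeating the same call with `Float32` inputs or GPU arrays would presumably need looser tolerances and extra setup, which may be part of why some failures in the report are hard to interpret.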

If this initial effort looks interesting to the community, I’ll add tests/benchmarks on several popular AD frameworks, interop with ONNX, as well as manual notes, e.g. mentions in the docs, links to known issues, etc. Also, so far I have only added a handful of operators, so many more are to come.

This looks awesome.

Amazing! What about adding a JET.jl performance analysis pass? https://github.com/aviatesk/JET.jl/blob/84aea0c97ecc83f955ab2fde455dce05d4d1736f/docs/src/optanalysis.md

Great work and really good overview!

It seems similar to Tim Holy’s efforts on invalidations, in the sense that you need a tool to diagnose what’s wrong before you start fixing things.

Don’t we need a big library “à la Plots.jl” with different backends, like the following?

  • Flux
  • Knet
  • Yota

It might be helpful if there were a table showing each library’s name for common functions.

Right after I say this :stuck_out_tongue: https://twitter.com/aleks_madry/status/1483523047273512978… It would be interesting to try to replicate ffcv in Julia with Metatheory.jl and native JIT compilation. That would really turn heads if it was actually faster!

Crazy speed improvement on their end. Gotta check this out :relaxed:.

This quote makes me happy.
kwargs > pargs all day every day.

In case you haven’t seen it,

Oh I know, I’m co-author on the paper it’s being written for :sweat_smile:. It’ll get renamed and stuff first though, so it’s ready for a slack share but I wouldn’t throw it to twitter and everything yet.

Just chiming in to see how the ONNX.jl progress is coming along. I feel it’d be a huge boost to the entire Julia community, making the switch to Julia a lot easier.

We’ve tagged a new release of ONNX.jl recently. Although it doesn’t have direct integration with Flux yet, it can already load and execute a number of popular operators.

At the moment, the main issue with further progress is the lack of workforce - around 80% of pending operators are trivial to implement. But maintainers typically have 4-6 other projects to care about, so number_of_hands * fraction_of_time is still pretty small.

(Perhaps I need to add a section for new contributors to the README :thinking: - Edit: done)
