What lessons could Julia's autodiff ecosystem learn from Tinygrad and Stan?

I agree it’s limited, but it’s an AD. A lot of the AD literature describes similarly minimal implementations. So while we can accuse them of not being useful, I don’t think it’s fair to use the number of existing ops as a purity test. But that’s the less important point. The more substantive one is that Greenspun’s tenth law basically applies to ADs in ML libraries: try as you might to avoid it, you’ll inevitably invent a half-baked one yourself.
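
To make “similarly minimal implementations” concrete, here is a micrograd/tinygrad-style scalar reverse-mode AD in Python. It is an illustrative sketch only: the `Value` class and its methods are my own names, not tinygrad’s API. Each operation records its parents and local derivatives, and `backward` applies the chain rule in reverse topological order.

```python
import math


class Value:
    """A scalar that remembers how it was computed, for reverse-mode AD."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents          # Values this one was computed from
        self._local_grads = local_grads  # d(self)/d(parent), one per parent

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    __rmul__ = __mul__

    def sin(self):
        return Value(math.sin(self.data), (self,), (math.cos(self.data),))

    def backward(self):
        # Topologically sort the graph, then push adjoints from the output back.
        order, seen = [], set()

        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)

        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            for p, g in zip(v._parents, v._local_grads):
                p.grad += g * v.grad


x = Value(2.0)
y = x * x + x.sin()  # y = x^2 + sin(x)
y.backward()
print(x.grad)        # 4 + cos(2), about 3.58385
```

Adding an operation here means adding one method with its local derivative. That is why the op count says little about whether something counts as an AD.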

But back to the thread: this is not just about Tinygrad but also about Stan’s AD, which is more general purpose and has come up many times before in threads about PPL performance. In both cases, one can reasonably point to Julia libraries that are more featureful. So why do we devote proportionally fewer resources to those libraries (if any at all)? And why do we instead put all our eggs in the basket of experimental new tech, again and again? I believe it’s worth exploring why this happens and what lessons we can take from it. That would help us improve the AD ecosystem and avoid similar problems in the future.

To me, every Python AD feels like a walled garden, as each comes with its own incompatible tensor types and a slightly different (sub)set of supported operations. Some of these gardens are very large by now; e.g., PyTorch is aiming to reduce its set of operators from 2000+ to around 250 core ones.
In Julia, I can just write my model (possibly combining several libraries) and then try different ADs on it. AbstractDifferentiation in particular makes it very easy to change the AD backend. In my experience, ForwardDiff and ReverseDiff have worked quite reliably, even in some crazy use cases that I would not have tried in Python to begin with. I did face several issues with Zygote, though. I either hit some limitations or got silently wrong gradients, so I tend to avoid it for now.
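
The backend-swapping workflow can be sketched in Python with duck typing. The model is plain arithmetic, and a `derivative(backend, f, x)` entry point dispatches to interchangeable backends; its shape loosely mirrors AbstractDifferentiation’s `AD.derivative`. The two toy backends are illustrative assumptions, not a real library: forward-mode dual numbers and central finite differences.

```python
class Dual:
    """Forward-mode dual number: a value paired with its derivative."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(v):
        return v if isinstance(v, Dual) else Dual(v)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


class ForwardModeBackend:
    """Differentiates by evaluating f on a dual number seeded with derivative 1."""

    def derivative(self, f, x):
        return f(Dual(x, 1.0)).der


class FiniteDifferenceBackend:
    """Central finite differences; needs nothing from f beyond float inputs."""

    def __init__(self, h=1e-6):
        self.h = h

    def derivative(self, f, x):
        return (f(x + self.h) - f(x - self.h)) / (2 * self.h)


def derivative(backend, f, x):
    # Single entry point: the model never mentions which backend is in use.
    return backend.derivative(f, x)


def model(x):
    # Plain arithmetic, written with no particular AD in mind: x^3 + 2x.
    return x * x * x + 2 * x


for backend in (ForwardModeBackend(), FiniteDifferenceBackend()):
    print(type(backend).__name__, derivative(backend, model, 1.5))  # both about 8.75
```

The model code stays identical across backends; only the first argument to `derivative` changes.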

You can say many things about AD in Julia, but “putting all our eggs in one basket” isn’t really one of them :wink:. And there’s clearly a lot of maintenance effort that has gone into existing systems (e.g. ReverseDiff.jl, Zygote.jl, & ChainRules) while people work on successors (e.g. Enzyme.jl).

But ultimately there is no authority here directing the allocation of resources, so it’s never clear to me what is meant by “we” in discussions like this. If someone with deep pockets wants to pour resources into a particular AD, they are welcome to, but you can’t stop research groups from trying to invent better AD systems, and it is a healthy sign that Julia is an attractive language for this kind of research.

I’m also not completely clear on which AD systems you are holding up as the example worth emulating. tinygrad’s AD is already far surpassed by what we have now. JAX’s walled garden is not something we can emulate without vastly more resources. stan::math?

They finished doing that earlier this year with PyTorch 2.0; I think they’re down to just over 200 now.

That seems like the opposite extreme to Zygote. I find ReverseDiff.jl general enough for most use cases (definitely more general than JAX), and I’m happy with that middle ground. (Or at least I would be, if ReverseDiff had basic features like working on GPU.) Alternatively, we could try to improve ecosystem support and maintenance for Yota.jl, which in theory should hit 95% of Zygote’s use cases with 5% of its engineering effort.

I don’t think there’s anything wrong with trying to build something more advanced than JAX. I just think it’s a mistake to try that before we’ve even managed to match the features of JAX. Trying to wait around another 5 years for advanced autodiff before we can use basic autodiff is going to kill Julia (assuming it hasn’t already).

No, Chebyshev polynomials were already there in Boost, and someone actually did the work to incorporate them into Stan, ran their one model or something, and then didn’t push to get it merged into Stan because they didn’t need it anymore.

What’s wrong here isn’t that Stan can’t do Chebyshev polynomials, it’s that Stan autodiff is a walled garden designed to support the mc-stan compiler + sampler ecosystem. To add some feature to it you need “permission” or “buy-in” or to maintain your own fork, and you need to be a C++ programmer who understands deeply templated Stan related stuff.

The idea in Julia is that, because ApproxFun.jl and ReverseDiff exist, you automatically have differentiable Chebyshev polynomial fits in Turing.jl, etc.
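
Here is a Python sketch of that composability argument, assuming nothing beyond operator overloading. `chebyshev_eval` evaluates a Chebyshev series with the Clenshaw recurrence and was written with no AD in mind. Yet passing in a forward-mode dual number yields the derivative. This stands in for ApproxFun.jl plus an AD only loosely; `Dual` and `chebyshev_eval` are illustrative names.

```python
class Dual:
    """Forward-mode dual number: a value paired with its derivative."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(v):
        return v if isinstance(v, Dual) else Dual(v)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def chebyshev_eval(coeffs, x):
    """Evaluate sum(c_k * T_k(x)) with the Clenshaw recurrence (no AD awareness)."""
    b1, b2 = 0.0, 0.0  # b_{k+1} and b_{k+2}
    for c in reversed(coeffs[1:]):
        b1, b2 = c + 2 * x * b1 - b2, b1
    return coeffs[0] + x * b1 - b2


coeffs = [1.0, 2.0, 3.0]  # 1*T0 + 2*T1 + 3*T2 = 6x^2 + 2x - 2
y = chebyshev_eval(coeffs, Dual(0.3, 1.0))
print(y.val, y.der)       # derivative should be 12*0.3 + 2 = 5.6 (up to rounding)
```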

I want to echo how good ForwardDiff and ReverseDiff are, and that ReverseDiff is in my opinion better than Stan precisely because you can differentiate through packages like ApproxFun.jl or whatever else we might have.

I started using Stan as soon as it was available, but by 2019 or so I was moving on to Julia, and I haven’t looked back. I’d rather program in Julia than Stan a thousand times over; it’s just more general purpose. Also, it’s always worked. I haven’t had any issues with wrong gradients, etc. (using ForwardDiff or ReverseDiff in Turing.jl).

A final thing I’d mention is that I believe I have an excellent derivative-free sampling scheme for use with Turing.jl. However, I don’t understand enough about how to implement samplers for Turing.jl, so I think there’s room for making it easier to bring people into developing infrastructure. I want to help improve Turing’s documentation related to all this, but I have been stymied a bit by unrelated issues.

As someone in the trenches here, I would respectfully disagree. A lot of effort has gone in, but not nearly enough. At this point it’s mostly :see_no_evil: when it comes to Zygote bugs. ReverseDiff seems to have done better, but for whatever reason Zygote keeps being recommended instead. And Tracker is left in (shamelessly stealing this from @Krastanov) the “dustbin of history”.

That’s easy: downstream and (end) users. AIUI, one of the reasons ReverseDiff still gets contributions is that the Turing team and others contribute back. Zygote only gets any attention because it and Flux are, unfortunately, joined at the hip.

This is why I had my disclaimer about Enzyme above :slight_smile: . It’s great that we have experimentation in this space, and I absolutely agree with your statement. The problem arises when complaints about the status quo are met with “oh, just wait for the better thing to be done”. No effort is then made to even explore whether what we have now can be improved to at least help bridge the gap. I understand such reactions are mostly knee-jerk and not intentionally meant to be defeatist. But their effect is to demotivate work on existing libraries while simultaneously raising expectations ever higher for new ones. Lest we forget, this is what led to many people getting burned on Diffractor before it was re-scoped.

Good to know, thanks.

There is still an important difference, though, and maybe also a reason why ADs in Python appear more bullet-proof. Let’s say I want to use some special function in JAX, PyTorch, etc. Then it is either supported (and most often also has an AD rule defined for it) or not.
In Julia, by contrast, the function may already be implemented somewhere without any rule defined for it (e.g., via ChainRules). An AD will nevertheless happily step into the function and try to differentiate whatever it finds there, e.g., while loops, scalar operations, etc. In the end, this may or may not work as expected.
Here is an example:

# PyTorch
import torch
x = torch.tensor([1.0], requires_grad=True)
y = torch.special.i0(x)
y.backward()
print(x.grad)  # tensor([0.5652])

# torch.special has no Hankel function, so there is no PyTorch counterpart to besselh(0, x)

# Julia
using SpecialFunctions
import AbstractDifferentiation as AD
import ForwardDiff, ReverseDiff
AD.derivative(AD.ReverseDiffBackend(), x -> besseli(0, x), 1.0)  # returns (0.565159103992485,)
AD.derivative(AD.ForwardDiffBackend(), x -> real(besselh(0, x)), 1.0)  # returns (-0.4400505857449335,); the same call fails with ReverseDiff due to a StackOverflowError
# Checking the methods reveals that frule and rrule are defined for besseli, but not for besselh

This is also what happens with Python ADs. Suppose someone wrote an equivalent of Bessels.jl (JuliaMath/Bessels.jl on GitHub: Bessel functions for real arguments and orders) that used only PyTorch or JAX operations. You’d see the same behaviour, even without a custom rule defined for it. That’s what separates AD from more manual systems for differentiating functions.
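
As a Python sketch of that point, here is a power-series `bessel_i0` built only from basic arithmetic and a while loop, with no derivative rule anywhere. Forward-mode dual numbers step straight through the loop; `Dual` is an illustrative class, not any library’s type. The result matches I₀′(1) = I₁(1) ≈ 0.565159, the value PyTorch and ReverseDiff report earlier in the thread. Note that the stopping test looks only at the primal value. Whether that is also a good stopping rule for the derivative is exactly the kind of thing that “might or might not work”.

```python
class Dual:
    """Forward-mode dual number: a value paired with its derivative."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(v):
        return v if isinstance(v, Dual) else Dual(v)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def primal(v):
    """Strip the derivative part, so control flow only ever sees plain floats."""
    return v.val if isinstance(v, Dual) else v


def bessel_i0(x, rtol=1e-16):
    """I0(x) = sum_k (x^2/4)^k / (k!)^2, summed until the terms stop mattering."""
    term, total, k = 1.0, 1.0, 0
    while True:
        k += 1
        term = term * (x * x) * (0.25 / (k * k))  # ratio of consecutive terms
        total = total + term
        if abs(primal(term)) <= rtol * abs(primal(total)):
            return total


d = bessel_i0(Dual(1.0, 1.0))
print(d.val, d.der)  # about 1.266066 and 0.565159
```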

So the difference lies elsewhere, but where? The torch.Tensor page of the PyTorch 2.1 documentation offers some ideas. Notice how there’s a very limited number of differentiable types in PyTorch? This would be as if Julia ADs only accepted Array{<:Union{BlasFloat,Integer}} and Julia disallowed adding new subtypes of BlasFloat and Integer. There are no array wrappers (those are handled as runtime properties of torch.Tensor) and no fancy custom number types. The smaller number of “primitive” operations allowed does help too. Just as important, only relatively well-behaved operations are allowed: no push!, delete!, etc., which can cause headaches for Julia AD maintainers.

before we can use basic autodiff is going to kill Julia (assuming it hasn’t already).

I’ve avoided engaging in this discussion so far, but people use JuMP’s autodiff in Julia all the time to solve some very large and complicated use-cases. “basic autodiff” should not need to mean mutating special function support on a GPU.

We strictly follow a “do a few things right” approach instead of trying to support a wide range of operators (see the Overview page of the MathOptInterface documentation):

julia> import MathOptInterface as MOI

julia> function eval_univariate(f, x)
           model = MOI.Nonlinear.Model()
           variable = MOI.VariableIndex(1)
           MOI.Nonlinear.set_objective(model, f(variable))
           evaluator = MOI.Nonlinear.Evaluator(
               model,
               MOI.Nonlinear.SparseReverseMode(),
               [variable],
           )
           MOI.initialize(evaluator, [:Grad])
           g = [NaN]
           MOI.eval_objective_gradient(evaluator, g, [x])
           return g[1]
       end
eval_univariate (generic function with 1 method)

julia> eval_univariate(x -> :(besselj0($x)), 1.0)
-0.4400505857449335

julia> eval_univariate(x -> :(real(besseli(0, $x))), 1.0)
ERROR: MathOptInterface.UnsupportedNonlinearOperator: The nonlinear operator `:real` is not supported by the model.

The problem is that it is a walled garden focused on nonlinear optimization, which is why it hasn’t been adopted by projects outside of JuMP…
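
Here is a rough Python analogue of that design. An expression tree is differentiated in reverse mode against a closed registry of operators. Anything outside the registry is rejected before any derivative is computed, much like the `UnsupportedNonlinearOperator` error above. The tuple-based expression format, the registry contents, and `UnsupportedOperator` are illustrative assumptions, not MOI’s actual data structures.

```python
import math


class UnsupportedOperator(Exception):
    """Raised for any operator outside the closed registry."""


# Closed registries: operator name -> (function, derivative(s)).
UNIVARIATE = {
    "sin": (math.sin, math.cos),
    "exp": (math.exp, math.exp),
    "log": (math.log, lambda a: 1.0 / a),
}
BINARY = {
    "+": (lambda a, b: a + b, lambda a, b: (1.0, 1.0)),
    "*": (lambda a, b: a * b, lambda a, b: (b, a)),
}


def gradient(expr, env):
    """Value and reverse-mode gradient of an expression tree.

    expr is a variable name, a number, or a tuple (op, arg, ...);
    env maps variable names to floats.
    """
    grads = {name: 0.0 for name in env}

    def forward(e):
        # Returns (value, back), where back(adjoint) pushes adjoints to the leaves.
        if isinstance(e, str):
            def back_leaf(adj):
                grads[e] += adj
            return env[e], back_leaf
        if isinstance(e, (int, float)):
            return float(e), lambda adj: None
        op, *args = e
        if op in UNIVARIATE and len(args) == 1:
            f, df = UNIVARIATE[op]
            a, back_a = forward(args[0])
            return f(a), lambda adj: back_a(adj * df(a))
        if op in BINARY and len(args) == 2:
            f, df = BINARY[op]
            (a, back_a), (b, back_b) = forward(args[0]), forward(args[1])
            da, db = df(a, b)

            def back(adj):
                back_a(adj * da)
                back_b(adj * db)

            return f(a, b), back
        raise UnsupportedOperator(f"operator {op!r} is not supported")

    value, back = forward(expr)
    back(1.0)
    return value, grads


print(gradient(("+", ("*", "x", "x"), ("sin", "x")), {"x": 2.0}))
try:
    gradient(("real", "x"), {"x": 1.0})
except UnsupportedOperator as err:
    print(err)  # operator 'real' is not supported
```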

But this is precisely my point: Python’s AD will only work if the function is (re)written to use “only PyTorch or JAX operations”. That is, someone has to implement the function inside the walled AD ecosystem, and probably test it there. In Julia, by contrast, the function was just written in base Julia without AD in mind and, in particular, was never tested against any AD. It might still work, though…

Julia + stdlib and Python + stdlib are not comparable in this instance, though. The former has a far richer numerical programming API, and anyone who uses the latter has to rely on a 3rd-party library. If we compare Julia + stdlibs to Python + Numpy/Scipy, then my analogy still holds, because libraries like JAX have array types that can duck type Numpy arrays. Hence it’s perfectly possible to write code that has no awareness of AD but can still easily be ADed through. It also shows the limitations of trying to make AD work against such a large API surface and number of types. Neither JAX’s nor PyTorch’s scipy compatibility layer supports some of the fancier types in Numpy (e.g. structured arrays), and operator coverage is <100%.

But I think what you’re getting at is that Julia libraries are more likely to be Julia all the way down, which gives AD more to work with before it bottoms out on an operation it doesn’t understand. Whereas with a library like Numpy, you’re hitting a C or Fortran call much earlier in the stack. This means that one is more likely to hit the Python equivalent of a MethodError than a fallback routine deep inside a stdlib which makes a few too many assumptions about its inputs. I’d reckon the latter category has caused no shortage of headaches for Julia AD maintainers.
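
Here is a small Python sketch of the two failure modes, under deliberately simplified assumptions. `Dual` and `LossyDual` are toy types, and `math.sqrt` stands in for any routine that bottoms out in non-AD-aware code. A number type with no conversion hits a `TypeError` immediately. One that defines `__float__` slips through the “fallback” and silently loses its derivative. That is roughly the Julia situation where a Dual or TrackedReal wanders into a method that assumes too much about its input.

```python
import math


class Dual:
    """Minimal forward-mode dual number with no conversion to float."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __mul__(self, other):
        # type(self) keeps subclasses (LossyDual below) closed under multiplication.
        if isinstance(other, Dual):
            return type(self)(self.val * other.val,
                              self.der * other.val + self.val * other.der)
        return type(self)(self.val * other, self.der * other)

    __rmul__ = __mul__


class LossyDual(Dual):
    """Same, plus a 'helpful' __float__ that lets generic float code accept it."""

    def __float__(self):
        return self.val  # the derivative part is silently dropped here


def magnitude(x):
    # Generic-looking code that bottoms out in a routine expecting a plain float.
    return math.sqrt(x * x)


try:
    magnitude(Dual(3.0, 1.0))
except TypeError as err:
    print("loud failure:", err)  # the Python analogue of a MethodError

print("silent failure:", magnitude(LossyDual(3.0, 1.0)))  # 3.0, a plain float
```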

Right, which is why I don’t think this satisfies the requirements for a general autodiff system, even though it’s really good for constrained optimization specifically.

A walled garden can’t sustain a herd of cattle, unless it’s massive, but just leaving them in a completely unenclosed space isn’t feasible either, unless you’re willing to constantly look out for wolves. That’s why we graze cattle on large fields surrounded by cheap fences. I’m still looking for the equivalent of the fenced field in Julia autodiff. I’m happy to make some promises (limits on while loops, runtime dispatch, or control flow), but others I won’t (like not using anything other than JuMP). If you can do it in PyTorch, you should be able to do it in Julia; otherwise, it’s fine to exclude it for now.

Agreed, using duck typing it might in principle be possible to write NumPy-like code that is ADable. But I have never seen a Python library that lets you use the same code for your model and simply swap the AD backend, i.e., pass in PyTorch or JAX or whatnot.
Are there any such initiatives for AD-agnostic ML frameworks in Python? With AbstractDifferentiation this already works in Julia; e.g., I commonly use Flux with ReverseDiff.

While I also think that things that work in PyTorch should be doable in Julia as well, I don’t think it should stop there. As far as I recall, a major reason PyTorch got an ODE solver was that Julia supported AD through ODEs earlier, which created demand in Python as well.

@ToucheSir If you had 5M dollars to spend on solving these issues in Julia, what would you do?

I would spend 60% on overhead to MIT and then have 2 PhD students. Or a solid team of 4 experienced people, with no coursework to do, over the same time period. 5M isn’t much from an organizational perspective, but it is enough to keep a project alive.

There’s no shortage of headaches because Zygote made a major misstep in not committing to mutation support. Under the hood, almost everything in the Julia standard library uses mutation. Not committing to mutation support effectively means it will be a non-general AD unless every single thing in the standard library is wrapped. The idea behind it was this chain: functional programming is beautiful, therefore Julia should be a functional programming language, and therefore mutation does not need to be supported. But this is the kind of speculative wishful thinking that cornered it. The starting question of any AD in Julia needs to be “how do you support mutation well?”, working back from there. Otherwise you will never support “standard” Julia code. This is why Diffractor never stood a chance (it had this missing from the start), and why Enzyme is doing well.

With JAX that essentially never happens. You need to change all control flow to use lax objects, and you have to use pure, non-mutating functions. JAX has had time to be adopted and is also in Python, but in the end it is really only picking up mindshare in the nerdiest of circles.

And it shouldn’t be surprising given how people talk about it:

https://www.reddit.com/r/MachineLearning/comments/11myoug/d_jax_vs_pytorch_in_2023/

I’ve seen the same thing over and over and over. “Functional programming is better, so therefore all we have to do is make everyone see the light and then when everyone realizes functional programming is the master race, X will be the best”. You can find Stack Overflow threads from 2010 espousing the same concept:

If there’s one story that has played over and over in programming languages, it’s exactly this “future” of functional programming. And time and time again, actual developers have thought “that’s cool” and have stuck to programming in multi-paradigm languages where it’s easier to develop.

People keep targeting functional programming because it’s easier for compilers, but compilers don’t adopt a language, developers do.

I think the moral of the story is this. In “late Zygote”, the project dropped mutation support entirely because it knew it couldn’t make it fast. It fell back instead on saying everyone should do functional programming. That was living in a dream world for compilers, not a world for humans. You do need to meet people where they are, even if performance is sometimes not perfectly optimal. PyTorch has gotten a lot of adoption even though its performance is not always optimal.

Automatic differentiation needs to stress the “automatic” before trying to differentiate itself on performance.

For the specific issue being replied to? Not sure, it’s probably more the wheelhouse of those interested in interfaces and more automated verification in Julia.

For AD in general? I’m sure the Enzyme team would appreciate some, as mentioned above, but 5M ought to be way more than necessary to incrementally improve the existing workhorse ADs. The respective authors are better qualified than I to comment on specifics. But there is a long list of features and sharp edges that haven’t had an opportunity to be addressed. Requests I’ve seen at some point include:
- GPU support
- SoA/StructArrays support
- differentiating w.r.t. arbitrary struct types containing Duals/TrackedReals
- better integration with ChainRules (or some other unified rules system), so that more downstream packages can define cross-AD rules
- better non-compiled tape performance
- more SIMD-ification

To my knowledge, these have not been tested for feasibility (let alone implemented).
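
To illustrate what “cross-AD rules” buys, here is a Python sketch in which a package author writes one rule table and two different AD strategies consume it. Forward mode uses the `frule`-style entries and reverse mode the `rrule`-style ones. The names echo ChainRules’ `frule`/`rrule`, but the signatures here are simplified assumptions, and this “AD” only handles a chain of unary functions.

```python
import math


# One rule table, written once by the package author (ChainRules-like in spirit).
def sin_frule(x, dx):
    return math.sin(x), math.cos(x) * dx


def sin_rrule(x):
    return math.sin(x), lambda ybar: math.cos(x) * ybar


def square_frule(x, dx):
    return x * x, 2 * x * dx


def square_rrule(x):
    return x * x, lambda ybar: 2 * x * ybar


RULES = {"sin": (sin_frule, sin_rrule), "square": (square_frule, square_rrule)}


def forward_derivative(chain, x):
    """Forward mode: push a tangent through each step using its frule."""
    dx = 1.0
    for name in chain:
        frule, _ = RULES[name]
        x, dx = frule(x, dx)
    return dx


def reverse_derivative(chain, x):
    """Reverse mode: collect pullbacks on the way forward, apply them backwards."""
    pullbacks = []
    for name in chain:
        _, rrule = RULES[name]
        x, pullback = rrule(x)
        pullbacks.append(pullback)
    xbar = 1.0
    for pullback in reversed(pullbacks):
        xbar = pullback(xbar)
    return xbar


chain = ["square", "sin"]  # x -> sin(x^2)
print(forward_derivative(chain, 0.7), reverse_derivative(chain, 0.7))
```

A new AD only has to learn to consume the table. It does not need its own copy of every package’s derivatives.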

To be clear, I wasn’t referring specifically to Zygote. One can find plenty of examples of Duals and TrackedReals ending up where they shouldn’t be because of some fallback rabbit hole. Zygote not supporting mutation makes this worse, but it’s something all current Julia ADs have to contend with.

Regardless of how popular JAX itself is, people converting numpy code to PyTorch or TF are also eschewing mutation and data-dependent control flow[1]. The PyTorch team did some very interesting, in-depth analysis of how frequently both appear in a wide range of ML models as part of the development of torch.compile and found they rarely do. For those interested, have a read through some of the TorchDynamo threads on https://dev-discuss.pytorch.org.

That’s not my recollection of the historical context. What I recall is that Zygote’s implementation of mutation was (and still is) so slow and so buggy that it made sense for basically nobody. If you wanted a lot of it, you were better off with ReverseDiff or even ForwardDiff. If you were migrating from Python ML frameworks, you probably weren’t using much mutation already.

It’s also worth noting that both the PyTorch and TF APIs are becoming more FP-oriented as time goes on. There are a number of reasons for this. But motivation aside, I don’t see either suffering for adoption these days, as evidenced by the framework adoption figure above. Do I like this direction? Not particularly. Does JAX take it too far? Probably. Is it hurting other frameworks? No evidence of that yet; if anything, their adoption is growing.

But since we’re sharing morals for this particular story, here’s one: before replacing existing AD libraries with a new one, make sure the new one can actually replace them[2]. Even now, the majority of people using NNs would be better off using Tracker.jl or Autograd.jl. Those who need control flow and fast scalar support would be better off with ReverseDiff or ForwardDiff. Zygote eventually found its niche in straight-line, mixed scalar and vectorized code, but that’s a smaller niche than even the one JAX occupies. Pushing it as the AD to end all ADs played no small part in the dynamic that @ParadaCarleton is commenting on.


  1. For the uninitiated, this is the only kind of control flow for which you need the lax DSL primitives being referred to. It’s also worth noting that ReverseDiff has the exact same limitation when its compiled tape mode is enabled. Many Turing models work with a compiled tape, which suggests to me that a lot of Julia code can get away with limited or no data-dependent control flow too.

  2. Again, I’m very glad to see this is the approach being taken with Enzyme and to some extent Diffractor. But if there isn’t consensus on acknowledging the failure to do this as a significant past mistake, I fear the chance of it happening again increases.
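
Footnote 1’s point about compiled tapes can be sketched in Python: trace a function once, then replay the recorded operations at new inputs without re-running Python control flow. The branch taken at record time is baked into the tape. Replaying at an input that should take the other branch therefore gives a wrong value and gradient. `Tape`, `Traced`, `record`, and `replay_gradient` are hypothetical names for this toy, not ReverseDiff’s API.

```python
class Tape:
    """A flat list of recorded ops; slot k holds the result of op k."""

    def __init__(self):
        self.ops = []

    def push(self, op, *inputs):
        self.ops.append((op, inputs))
        return len(self.ops) - 1


class Traced:
    """A value that records arithmetic onto a tape while computing it eagerly."""

    def __init__(self, tape, slot, value):
        self.tape, self.slot, self.value = tape, slot, value

    def __mul__(self, other):
        slot = self.tape.push("mul", self.slot, other.slot)
        return Traced(self.tape, slot, self.value * other.value)

    def __neg__(self):
        return Traced(self.tape, self.tape.push("neg", self.slot), -self.value)

    def __gt__(self, other):
        # Comparisons use the concrete value and leave no trace on the tape.
        return self.value > other


def f(x):
    # Data-dependent control flow: the branch depends on the value of x.
    if x > 0:
        return x * x
    return -x


def record(func, x):
    tape = Tape()
    out = func(Traced(tape, tape.push("input"), x))
    return tape, out.slot


def replay_gradient(tape, out_slot, x):
    """Re-run the recorded ops at a new x (no Python branching), then backprop."""
    vals = []
    for op, ins in tape.ops:
        if op == "input":
            vals.append(x)
        elif op == "mul":
            vals.append(vals[ins[0]] * vals[ins[1]])
        elif op == "neg":
            vals.append(-vals[ins[0]])
    adj = [0.0] * len(vals)
    adj[out_slot] = 1.0
    for k in reversed(range(len(tape.ops))):
        op, ins = tape.ops[k]
        if op == "mul":
            a, b = ins
            adj[a] += adj[k] * vals[b]
            adj[b] += adj[k] * vals[a]
        elif op == "neg":
            adj[ins[0]] -= adj[k]
    return vals[out_slot], adj[0]  # slot 0 is the single input


tape, out = record(f, 3.0)
print(replay_gradient(tape, out, 3.0))   # (9.0, 6.0)
print(replay_gradient(tape, out, -2.0))  # (4.0, -4.0), but f(-2) = 2 with derivative -1
```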

It was only buggy if one did not adopt the safety copy. It’s fine if the adjoints make a safety copy of the values that can be mutated. So it would just have been slow due to the extra copies. Some considered that too slow, and so mutation support was dropped entirely. Julia generally has a sense of “make it work, then make it fast”, a development scheme that was not followed in this case.
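
Here is a Python sketch of why the safety copy matters. It assumes a reverse rule whose pullback keeps references to its inputs; `dot_rrule` and `run` are hypothetical names, not Zygote’s API. Mutating an input after the forward pass silently changes the gradient, unless the rule snapshots the values its pullback will read.

```python
def dot_rrule(a, b, safety_copy):
    """y = sum(a[i] * b[i]) plus a pullback that needs the original a and b."""
    if safety_copy:
        a, b = list(a), list(b)  # snapshot the inputs the pullback will read
    y = sum(ai * bi for ai, bi in zip(a, b))

    def pullback(ybar):
        # Gradients w.r.t. a and b, computed from whatever a and b hold *now*.
        return [ybar * bi for bi in b], [ybar * ai for ai in a]

    return y, pullback


def run(safety_copy):
    w, x = [1.0, 2.0], [3.0, 4.0]
    y, pullback = dot_rrule(w, x, safety_copy)
    x[0] = 100.0  # in-place mutation after the forward pass
    grad_w, _ = pullback(1.0)
    return y, grad_w


print(run(safety_copy=False))  # (11.0, [100.0, 4.0]): wrong, d/dw should be [3.0, 4.0]
print(run(safety_copy=True))   # (11.0, [3.0, 4.0])
```

The copy makes the result correct but costs an extra allocation on every call, which is the speed trade-off described above.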

Indeed, and I think anyone pushing Zygote as the AD to end all ADs was pretty wrong from the start. Even the foundational papers on dP in Julia were pretty clear this would need to be the case. For example let’s pull the classic:

https://dspace.mit.edu/handle/1721.1/137320

See:

Our system can be directly used on existing Julia packages, handling user-defined types, general control flow constructs, and plentiful scalar operations through source-to-source mixed-mode automatic differentiation, composing the reverse-mode capabilities of Zygote.jl and Tracker.jl with ForwardDiff.jl for the “best of both worlds” approach. In this paper we briefly describe how we achieve our goals for a ∂P system and showcase its ability to solve problems which mix machine learning and pre-existing scientific simulation packages.
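
The mixed-mode idea in that quote can be sketched in Python. An outer reverse-mode graph treats a scalar helper as a single node. It obtains that node’s local derivative by running the helper once on a forward-mode dual number. This is only a loose analogue of composing Zygote/Tracker with ForwardDiff.jl; `Value`, `Dual`, `apply`, and `cubic` are illustrative names.

```python
class Dual:
    """Forward-mode dual number: a value paired with its derivative."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(v):
        return v if isinstance(v, Dual) else Dual(v)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


class Value:
    """Reverse-mode node: data, accumulated gradient, parents, local derivatives."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data, self.grad = data, 0.0
        self._parents, self._local_grads = parents, local_grads

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def apply(self, scalar_fn):
        # Mixed mode: scalar_fn has no reverse rule, so run it once on a dual
        # number and record the resulting derivative as this node's local gradient.
        d = scalar_fn(Dual(self.data, 1.0))
        return Value(d.val, (self,), (d.der,))

    def backward(self, adjoint=1.0):
        # Recursive propagation: sums over all paths, so it is correct for any
        # graph, but it revisits shared nodes (fine for this tiny example).
        self.grad += adjoint
        for parent, local in zip(self._parents, self._local_grads):
            parent.backward(adjoint * local)


def cubic(t):
    # Horner's rule for t^3 + 2t, written with no AD in mind.
    acc = 0.0
    for c in (1.0, 0.0, 2.0, 0.0):
        acc = acc * t + c
    return acc


w, x = Value(2.0), Value(1.5)
y = w * x.apply(cubic)  # y = w * (x^3 + 2x)
y.backward()
print(w.grad, x.grad)   # 6.375 17.5
```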

The core point is that the tenets of a differentiable programming language are:

  1. The underlying code should be agnostic to the AD engine it’s used with.
  2. AD engines for dP should directly support packages in the language.

And I think this laid a good foundation for Julia’s dP ecosystem. Things like SciML and Lux are agnostic to which ADs they are used with, and tooling like ChainRules.jl and AbstractDifferentiation.jl are direct consequences of this grand scheme. This gives a ripe ecosystem where competition between ADs can readily occur, since swapping the backend is possible. This is what we wanted, and I believe we have achieved that grand goal quite well.

However, while the landscape for AD competition was set to be ripe, the actual ADs competing on said landscape have ended up lagging for a bit, though Enzyme is really pulling through as of late in the contexts we’ve been testing.

[Enzyme lead here, also caveat I only skimmed this thread, so please call me out on anything].

I think it’s important to note here that Enzyme was created by a few grad students, working part time on the project, to build and learn about AD+Compilers – not explicitly to be a replacement or the end solution for Julia AD.

Now, Enzyme has become one of the more critical Julia AD systems in many people’s eyes, for a variety of reasons. One is that Enzyme is well-designed, which led to good performance, mutation support, and a community larger than just Julia. Another is that many of the other Julia AD systems have not received a lot of love. However, I think looking at its history is useful for understanding where it is, where it intends to go, and how to help!

In Julia, we’ve always tried to target a smaller set of code and do that extraordinarily well. This began as basically GPU-compatible code, an incredibly strict subset that disallows most Julia features. Almost all of the issues on Enzyme.jl are feature requests to expand that scope and differentiate through more Julia features. We’ve been adding these, but we intentionally limit the subset of code we handle (throwing an error if unsupported) while expanding outward. The biggest recent additions have been custom rules, garbage collection, and a significant part of type-unstable code. All of these are Julia-specific and require work outside of Enzyme core, on the Julia-specific frontend (other languages, like C, don’t have GC or type-unstable calls). We also still have plenty of new things we’re looking into, including BLAS, scheduling, etc.

We are a growing community, but almost all of the Julia side of things has been a few people working part time on this, as funded by specific means. I, for example, have often been funded by the DOE. So a lot of my time on Julia work has gone toward the scientific computing use cases of those who paid my PhD stipend. I’d love to add good ML support within Julia and Enzyme (and think it should be fairly quick). But that’s not where my personal funding was coming from, so it hasn’t been a priority for me personally.

That also said, I have now graduated and am starting a new position as a Prof of CS at Illinois, so if anyone does have funds (or wants to apply for a grant together) on compiler based AD, or other related things, please reach out (and we can try to accelerate getting things done)!

Is that because of problem domain though? Classic deep learning vs more quirky/customized sciml stuff?