Yeah, the transition to 1.11 has been a big undertaking for Enzyme due to the introduction of `Memory` and the rebasing of `Array` on top of it. You should definitely submit a GitHub issue with this example.
But for the love of god, don't submit any example involving DI. Craft a pure-Enzyme MWE.
Yes, I did that twice I think; I will avoid it, I promise.
Confirming that you get a runtime activity error on 1.11, and the workarounds are to either make `weights_ctx` Duplicated as above, or change the mode to `Enzyme.set_runtime_activity(Enzyme.Reverse)`. The latter seems to be a hair faster.
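In call form, the two workarounds look roughly like this (a sketch with a placeholder loss, since the actual `weights_ctx` setup lives in the earlier posts):

```julia
using Enzyme

# placeholder loss; stands in for the real function from this thread
f(w, weights_ctx) = sum(abs2, w .* weights_ctx)

w = rand(10);           dw   = Enzyme.make_zero(w)
weights_ctx = rand(10); dctx = Enzyme.make_zero(weights_ctx)

# Workaround 1: pass weights_ctx as Duplicated so Enzyme has a shadow to write into
Enzyme.autodiff(Enzyme.Reverse, f,
                Enzyme.Duplicated(w, dw), Enzyme.Duplicated(weights_ctx, dctx))

fill!(dw, 0)  # Enzyme accumulates into dw, so reset it before the second call

# Workaround 2: keep weights_ctx Const but enable runtime activity on the mode
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse), f,
                Enzyme.Duplicated(w, dw), Enzyme.Const(weights_ctx))
```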
I made the MWE:
```julia
using Enzyme

function foo(x, y)
    y2 = reshape(y, 2, 5)
    return sum(x .+ y2[:])
end

x = rand(10)
y = rand(10)
dx = Enzyme.make_zero(x)
Enzyme.autodiff(Enzyme.Reverse, foo, Enzyme.Duplicated(x, dx), Enzyme.Const(y))
dx
```
I'll file the issue. It seems like it's not only about `reshape`; `y2 = y.^2` does the same. Any idea for the name of this issue? Just "issue on 1.11"? Maybe "Enzyme on 1.11 needs more Cache than on 1.10"?
Looks like it's already reported here, at least for `reshape`: error when differentiating `reshape` · Issue #2214 · EnzymeAD/Enzyme.jl · GitHub
The `y.^2` issue hits a different codepath (broadcasting machinery, not `GenericMemory`), so I suppose it might warrant a separate issue.
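For reference, the broadcast variant mentioned above would be essentially the same MWE with the `reshape` swapped out for `y .^ 2` (untested sketch, same call pattern):

```julia
using Enzyme

function foo_bc(x, y)
    y2 = y .^ 2               # broadcast instead of reshape
    return sum(x .+ y2)
end

x = rand(10); dx = Enzyme.make_zero(x)
y = rand(10)
Enzyme.autodiff(Enzyme.Reverse, foo_bc, Enzyme.Duplicated(x, dx), Enzyme.Const(y))
```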
OK, it's the one I used anyway.
I think this shows how good JAX is and how Julia isn’t there yet, unfortunately. Instead of time-to-first-plot I’m now struggling with time-to-first-gradient (TTFG). That’s in a language specifically geared towards data science, ML and statistics, where gradients are first-class citizens.
The smallest TTFG I got (`Zygote.gradient` is 54000 TIMES slower than `jax.gradient` - #29 by ForceBru) was 4 seconds, whereas the equivalent JAX time (first JITted gradient with `jax.block_until_ready`, so compilation time is included, like in the Julia version) is 10 times less, at 0.44 seconds.
Julia times after compilation are mostly good: Mooncake consistently delivers 4 ms, while in JAX I get around 2.89 ms (mean of 1000 runs using `timeit.timeit`).
In JAX, I wrote the first code that came to mind and it was fast straight away. In Julia, I encountered a massive (literally 54 thousand times slower!) performance hit and had to seek help from autodiff gurus who can of course optimize the heck out of everything.
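(To be concrete about what I'm calling TTFG: the wall time of the very first gradient call, compilation included, versus a warm call. Roughly this, with a toy loss standing in for the real one:)

```julia
using Zygote

# toy loss; stands in for the actual model loss discussed in this thread
loss(p) = sum(abs2, p)
p = rand(1_000)

ttfg = @elapsed Zygote.gradient(loss, p)   # first call: compilation included
warm = @elapsed Zygote.gradient(loss, p)   # second call: steady-state time
println("TTFG = $ttfg s, warm = $warm s")
```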
Enzyme got 0.9 ms at the second gradient, so I think we're definitely there. For TTFG you're right, though; but who cares, you only do that once and wait 1 s instead of 0.4 s.
This is the age-old debate on compilation time. When you compute a million gradients, it doesn’t make a difference whether the first one is fast or just okay. But of course it being unreasonably slow is not great for user experience.
If you wanna add any info.
I care, because irl I have 1 million 128-dimensional parameters and half a million data points (compared to half a million 2D params here and 100 datapoints). With my original code, I waited more than an hour for the first gradient to compute, rage quit, spent hours factoring out the part responsible for computing the gradient, debugging it, writing the MWE here, writing the JAX MWE (okay, that took about a minute), trying out various things with the Julia code etc.
Note that I’m doing all this instead of doing my actual job (estimating the model). Sure, getting help from the experts here is great and very valuable, so I’m not wasting my time here, but I’d also very much like to just get the job done without debugging TTFG issues.
Now I get the good timings with the 2D parameters, but my next task is to see how well it’ll work with the 128-D params. JAX runs in 181 ms averaged across 1000 runs and automatically uses multithreading.
I don't know if compilation depends on the size of the input; it seems weird, since you should have the same overhead with small and big arrays, so I guess this used the first Zygote example, which indeed would lead to big differences. About Julia being really hard to debug, you're right; that's what DI tries so hard to make easier (still a lot easier than C++). You should definitely use Reactant.jl then if you want parallel code without doing anything kernel-wise.
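A minimal sketch of the Reactant.jl pattern, from memory (double-check the Reactant docs for the current API; the toy function is just illustrative):

```julia
using Reactant

# toy function; Reactant traces it and compiles it through XLA,
# which takes care of fusion and parallel execution for you
f(x) = sum(abs2, sin.(x) .+ cos.(x))

x = Reactant.to_rarray(rand(10_000))
f_compiled = @compile f(x)   # compile once: this is where the JIT cost is paid
f_compiled(x)                # fast afterwards
```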
Weird, yes, but that's what I got. I may make a table to show how TTFG depends on the number of dimensions. Currently `loss_mcabbott_mean` doesn't seem to depend on it, either for Zygote or for Mooncake.
The only case where TTFG would depend on that is if types are changing, since there will be a lot more numbers that change types (fully guessing, don't hesitate to correct me if I'm wrong). If you do make a table, please make it (size | TTFG-TTG), so we can see this.
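A rough way to produce that table, reading it as one column for TTFG and one for the warm time-to-gradient, with a toy loss in place of `loss_mcabbott_mean`. Note that each size needs a fresh Julia session, otherwise only the first size pays the compilation cost (the element type never changes):

```julia
# run as e.g.:  julia ttfg_row.jl 1000   (hypothetical script name)
using Zygote

loss(p) = sum(abs2, p)        # toy stand-in for loss_mcabbott_mean

n = parse(Int, ARGS[1])
p = rand(n)
ttfg = @elapsed Zygote.gradient(loss, p)   # first gradient: compilation included
ttg  = @elapsed Zygote.gradient(loss, p)   # second gradient: warm timing
println("$n | $ttfg | $ttg")
```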
There are a few related things here to pick apart.
Firstly, `jax.jit` seems to have a much lower fixed compilation overhead than most Julia ADs. However, I've heard enough reports of people favouring Julia ADs despite the long TTFG because it was faster than JAX for their large models. Either way, TTFG is a big problem and we need better solutions for it.
Secondly, there's a bit of a technical mixed with a philosophical problem here. If you asked someone to write the original loss function in pure Julia, they probably would've written something like what @yolhan_mannes did: minimal allocation and very loopy. Not surprisingly, this does well with certain ADs. If you asked someone to write the same function like they were a Python programmer, they probably would've written something like @mcabbott's best-performing examples: fully vectorized operations, minimal looping and exploiting vectorized operator fusion where possible. That does well with other ADs. What does not work well is doing something in between: for example, looping over slices of an input array and applying vectorized (i.e. expensive and allocating) operations to each.
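As a rough illustration of the three styles on a generic toy loss (not the actual function from this thread):

```julia
# "Julian" style: scalar loops, no temporaries
function loss_loopy(W, xs)                 # xs has one data point per column
    s = 0.0
    for j in axes(xs, 2), i in axes(W, 1)
        acc = 0.0
        for k in axes(W, 2)
            acc += W[i, k] * xs[k, j]
        end
        s += acc^2
    end
    return s
end

# "Pythonic" style: one fully vectorized expression
loss_vectorized(W, xs) = sum(abs2, W * xs)

# The in-between anti-pattern: loop over slices, allocating a temporary per slice
loss_inbetween(W, xs) = sum(sum(abs2, W * x) for x in eachcol(xs))
```

All three compute the same number; they just stress the ADs very differently.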
So why do we see this kind of code again and again in the real world? One factor is that most people who are familiar with AD in Python and Julia are generally more comfortable with the former than the latter. This means that code examples are more likely to be in the “in between” style which is a worst case for Julia ADs. This IMO is an education and documentation problem: we need to direct people towards either writing more Pythonic vectorized code or more Julian scalarized code depending on their use case.
But the other side is that the Python AD libraries provide a better "pit of success" for users trying to write idiomatic code. For example, the JAX example in the OP uses `vmap`, while @mcabbott's examples had to do some/all of that vectorization by hand. Granted, I think there's still a cultural/familiarity aspect, in that a generator comprehension over array slices would be immediately flagged as a performance problem by any proficient JAX/PyTorch/NumPy user, but this challenge of making the fast path the obvious one has been an evergreen one ever since I started following the Julia AD ecosystem.
I don’t know your background, but it’s worth keeping in mind that “the first code that came to mind” may look very different depending on your prior experience. If you’ve used JAX a lot you’re likely to gravitate towards idioms that work well with JAX, and the same is true for Julia. Hence, what looks like “optimizing the heck out of everything” to one user may just be “the first code that came to mind” to another.
That said, your perspective is appreciated. Hours of compilation time is certainly not a good look for Julia.
(Then there's the standard list of excuses for Julia which you may or may not find compelling: Julia AD packages are solving a harder problem, trying to differentiate the entire language, while Python-based frameworks like JAX constrain you to their limited DSL; Julia AD packages are developed by solo researchers while JAX is developed by Google; et cetera. I suppose they may afford Julia some sympathy points, but what does that matter if you were able to solve your problem in JAX and not in Julia?)
I agree; the ability to do bad things in Julia is really high, and there will be a moment when it won't be possible to make the language better at catching those. Then we will have to choose between "let people do bad things and tell them why as much as possible" and "make Julia 2.0 a lot less permissive and only allow well-written code" (no idea what that would mean, btw). We say Julia is really close to Python, but someone coming from Fortran will write 1000x faster code than someone coming from Python/MATLAB.
Why didn’t you include the Enzyme result?
I can't get it to work. I keep getting errors about its inability to prove that the function doesn't modify its arguments.