Zygote vs Jax performance example

Hi, it’s Saturday and I somehow had a Jax tab open in my browser, so I decided to compare it against Zygote using their third derivative of tanh example:

# jax cpu v0.1.57
from jax.numpy import tanh
from jax import grad, jit
foo = grad(grad(grad(tanh)))  # runs instantly
foo_jit = jit(foo)  # runs instantly
%timeit foo(1.)  # 6.8ms
%timeit foo_jit(1.)  # 40us

Now Zygote:

# julia v1.5
using Zygote #v0.5.14
using BenchmarkTools
D(f) = x -> gradient(f,x)[1]
D(f,n) = n>1 ? D(D(f),n-1) : D(f)
g = D(tanh,3)
g(1.) # first call compiles for about a minute...
@benchmark $g(1.) # mean 1.37ms

So Zygote is 5 times faster than standard jax, but jitted jax is 34 times faster. Am I doing something wrong? Is there some package which gets this to jax speed (other than my pencil)? I know that there's a PR for speeding up tanh, but for other functions the picture is similar. Also, is there a way to reduce the Zygote compile time in such a case?

Btw:
1) The analytical solution takes 80 ns or so.
2) The jax result is a little off from the true value, as it only uses 32-bit floats.
3) In julia I of course have the choice of type and can get high accuracy :>
4) This is of course no comprehensive performance comparison, rather a point measurement.
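For reference, the analytical solution mentioned above has a simple closed form: with t = tanh(x), the third derivative is -2(1 - t²)(1 - 3t²). A quick sketch (in Python just for illustration; a Julia one-liner would be analogous):

```python
import math

def tanh_d3(x):
    """Analytical third derivative of tanh: -2*(1 - t^2)*(1 - 3*t^2), t = tanh(x)."""
    t = math.tanh(x)
    return -2.0 * (1.0 - t * t) * (1.0 - 3.0 * t * t)

print(tanh_d3(1.0))  # ~0.6216266807712962, matching the AD results below
```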

It looks like you are doing the benchmarking in global scope with non-const globals, which can significantly harm performance.

Not sure if this fixes it, but can you try

g = D(tanh, 3)  # tanh is a const, but f isn't. 
@benchmark $g(1.0) # interpolate g

?

It also looks a bit unfair that you hard-code grad(grad(grad(tanh))) for jax but give Julia a general recursive definition (though perhaps the compiler unrolls this, I'm not sure).

So 1.6 makes tanh about 3x faster, but one place where this could still improve is that we don't have a good less-accurate-but-faster version of tanh yet. Allowing 3 ULPs of inaccuracy would probably give a further 2x speedup or so.
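For scale: a ULP (unit in the last place) is the gap between adjacent floating-point numbers, so a 3-ULP error budget near tanh(1.0) is a relative error on the order of 1e-16. A quick check using Python's stdlib math.ulp (available since 3.9):

```python
import math

x = math.tanh(1.0)         # ~0.7615941559557649
one_ulp = math.ulp(x)      # gap from x to the next representable float64
budget = 3 * one_ulp       # the 3-ULP tolerance discussed above

print(one_ulp, budget / x) # absolute gap, relative error budget
```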

Edit: it turns out that the way we defined sech means that my PR that speeds up cosh also speeds up sech, so 1.6 should be much faster here.


If high-order derivatives of scalar functions are what you care about, then you might want this:

julia> using TaylorSeries

julia> @btime tanh(1 + Taylor1(3)).coeffs[4] * 6
  549.156 ns (10 allocations: 944 bytes)
0.6216266807712962

(Here coeffs[4] is the degree-3 Taylor coefficient, and multiplying by 3! = 6 recovers the third derivative.)

or better, in fact, this:

julia> using ForwardDiff

julia> DF(f) = x -> ForwardDiff.derivative(f,x);

julia> @btime DF(DF(DF(tanh)))(1.0)
  139.202 ns (3 allocations: 80 bytes)
0.6216266807712962
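For intuition, ForwardDiff computes derivatives by running the function on dual numbers, and nesting them gives higher-order derivatives. Here is a minimal, illustrative sketch of that idea (in Python; this is not ForwardDiff's actual implementation, just the technique it is built on):

```python
import math

class Dual:
    """Minimal forward-mode dual number: val + eps * der (illustration only)."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def _wrap(self, o):
        return o if isinstance(o, Dual) else Dual(o)

    def __add__(self, o):
        o = self._wrap(o)
        return Dual(self.val + o.val, self.der + o.der)
    __radd__ = __add__

    def __sub__(self, o):
        o = self._wrap(o)
        return Dual(self.val - o.val, self.der - o.der)

    def __rsub__(self, o):
        return Dual(o) - self

    def __mul__(self, o):
        o = self._wrap(o)
        # product rule carried by the dual part
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)
    __rmul__ = __mul__

def mytanh(x):
    """tanh that also accepts (possibly nested) Dual inputs."""
    if isinstance(x, Dual):
        t = mytanh(x.val)
        return Dual(t, x.der * (1 - t * t))  # chain rule: tanh' = 1 - tanh^2
    return math.tanh(x)

def derivative(f, x):
    return f(Dual(x, 1.0)).der

# Nest three levels deep, like DF(DF(DF(tanh)))(1.0) above:
d3 = derivative(lambda a: derivative(lambda b: derivative(mytanh, b), a), 1.0)
print(d3)  # ≈ 0.6216266807712962
```

Each level of nesting wraps the input in another Dual, so one forward pass through tanh carries all the derivative information along; no symbolic expression swell and no reverse pass, which is why this approach is so cheap for scalar inputs.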

If not, and you care about very different problems, then I wouldn’t conclude too much from this comparison.


The upcoming Diffractor.jl will do substantially better than Zygote for nested derivatives, as well as on any problems where constant overhead is a problem.


And anything with type inference issues, which is the first problem that the new compiler infrastructure solves over Cassette/IRTools.


Any hope for Diffractor on 1.6 (even if it isn't perfect, we could at least start using the frontend interface)?

I knew you+julia wouldn’t let me down :smiley: thank you!
@mcabbott’s answer is super fast and is what I was looking for. Somehow I can’t mark it as the solution.
I am looking forward to 1.6 and the improvements there (thanks Oscar). And after hearing all the good news about Diffractor I can’t wait to see it in action! By the way: benchmarking in global/local scope didn’t make any difference in this case, so I posted the former.


This PR won’t merge until after 1.6 branches, so I don’t think so.
Using opaque closures will also help Zygote’s type inference problems.

I’ll point out @Keno’s SPLASH talk on Diffractor:

And also Matt Bauman’s talk on applications:

We’re already able to do a lot even with suboptimal tools today. I really think that once the new stuff lands, we’ll get to a new level.


Great thread.
Slightly off topic: where can I find the code for Diffractor.jl? Googling doesn’t lead me anywhere.
Thanks again!


There isn’t a repo for it yet. Various pieces are strewn across a few public repos, but I haven’t yet drawn it all together. My plan was to do that once OpaqueClosure is in, so people can actually try it.


Ok! Thanks for the answer!

Maybe (I don’t know exactly how it works) you should use foo_jit(1.).block_until_ready() to get an accurate benchmark, since JAX dispatches computations asynchronously. Can you test this and report whether it changes anything?
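The pitfall with asynchronous dispatch is that timing the call only measures how long it takes to *launch* the work, not to finish it. A plain stdlib analogue (not JAX itself, just the same timing hazard) shows the difference between timing the dispatch and timing until the result is actually available:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def slow_square(x):
    time.sleep(0.05)   # stand-in for an expensive kernel
    return x * x

with ThreadPoolExecutor(max_workers=1) as pool:
    t0 = time.perf_counter()
    fut = pool.submit(slow_square, 3)    # returns almost immediately
    dispatch_time = time.perf_counter() - t0
    result = fut.result()                # blocks, like block_until_ready()
    total_time = time.perf_counter() - t0

print(result, dispatch_time < total_time)
```

If a benchmark only measures up to the submit, it can look arbitrarily fast regardless of the real cost.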


That didn’t change the timings.