Hi, it’s Saturday and I somehow had a JAX tab open in my browser, so I decided to compare it against Zygote using their third-derivative-of-tanh example:
# jax cpu v0.1.57
from jax.numpy import tanh
from jax import grad, jit
foo = grad(grad(grad(tanh))) # runs instantly
foo_jit = jit(foo) # runs instantly (jit is lazy; compilation happens on first call)
%timeit foo(1.) # 6.8ms
%timeit foo_jit(1.) # 40us
Now Zygote:
# julia v1.5
using Zygote #v0.5.14
using BenchmarkTools
D(f) = x -> gradient(f,x)[1]      # derivative of f via Zygote
D(f,n) = n>1 ? D(D(f),n-1) : D(f) # nth derivative by repeated application
g = D(tanh,3)
g(1.) # first call compiles for about a minute...
@benchmark $g(1.) # mean 1.37ms
So Zygote is 5 times faster than standard JAX, but jitted JAX is 34 times faster than Zygote. Am I doing something wrong? Is there some package which gets this to JAX speed (other than my pencil)? I know that there’s a PR for speeding up tanh, but for other functions the picture is similar. Also, is there a way to reduce the Zygote compile time in such a case? Btw:
1) The analytical solution takes 80ns or so (see the sketch below).
2) The JAX result is a little off from the true value, as it only uses 32-bit floats.
3) In Julia I of course have the choice of type and can get high accuracy :>
4) This is of course no comprehensive performance comparison, just a point measurement.
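For reference, here is the hand-derived version behind the ~80ns figure (my own sketch; the name tanh3 is mine). With t = tanh(x), differentiating three times gives f''' = 2(1 - t^2)(3t^2 - 1):
# analytical third derivative of tanh:
# f'   = 1 - t^2
# f''  = -2t(1 - t^2)
# f''' = 2(1 - t^2)(3t^2 - 1)
tanh3(x) = (t = tanh(x); 2 * (1 - t^2) * (3t^2 - 1))
tanh3(1.0) # ≈ 0.62163
@benchmark tanh3(1.0) # direct evaluation, in the tens-of-ns range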
It looks like you are doing the benchmarking in global scope with non-const globals, which can significantly harm performance.
Not sure if this fixes it, but can you try
g = D(tanh, 3) # tanh is a const, but f isn't.
@benchmark $g(1.0) # interpolate g
?
It also looks a bit unfair that you hardcode grad(grad(grad(tanh))) for JAX but give Julia a general recursive definition (though perhaps the compiler unrolls it, I’m not sure).
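For parity, one could hardcode the nesting on the Julia side too (just a sketch; g3 is my name for it). Note that, if I read the definition right, D(tanh, 3) already reduces to exactly this nesting when g is constructed, so the per-call cost should be the same:
g3 = D(D(D(tanh))) # mirrors JAX's grad(grad(grad(tanh)))
@benchmark $g3(1.0)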
So 1.6 makes tanh about 3x faster, but one place where this could still improve is that we don’t have a good low-accuracy (fast) version of tanh yet. Allowing 3 ulps of inaccuracy would probably give a further 2x speedup or so.
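For illustration only (my own sketch, not the planned implementation): one classic speed-for-accuracy trade-off computes tanh from a single exp, which saves work but loses accuracy to cancellation near zero:
# tanh(x) = (1 - e) / (1 + e) with e = exp(-2x); valid as written only for x >= 0,
# since exp(-2x) overflows for large negative x. A real version would reduce to
# abs(x) first and restore the sign.
tanh_fast(x) = (e = exp(-2x); (1 - e) / (1 + e))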
Edit: turns out that the way we defined sech means that my PR that speeds up cosh also speeds up sech, so 1.6 should be much faster here.
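Background, in case it’s not obvious why a cosh PR helps here: sech is defined in terms of cosh (roughly the one-liner below; check Base for the exact definition), so any cosh speedup carries over automatically:
my_sech(x) = inv(cosh(x)) # essentially how sech is defined, modulo details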
The upcoming Diffractor.jl will do substantially better than Zygote for nested derivatives, as well as on any problem where constant overhead dominates.
I knew you + Julia wouldn’t let me down, thank you! @mcabbott’s answer works super fast and is what I was looking for. Somehow I can’t mark it as the solution.
I am looking forward to 1.6 and the improvements there (thanks, Oscar). And after hearing all the good news about Diffractor I can’t wait to see it in action! By the way: benchmarking in global vs. local scope didn’t make any difference in this case, so I posted the former.
There isn’t a repo for it yet. Various pieces are strewn across a few public repos, but I haven’t yet drawn it all together. My plan was to do that once OpaqueClosure is in, so people can actually try it.
I don’t know exactly how JAX works, but maybe you should use foo_jit(1.).block_until_ready() to get an accurate benchmark, since JAX dispatches asynchronously. Can you test this and report whether it changes anything?