Zygote vs Jax performance example

If high-order derivatives of scalar functions are what you care about, then you might want this:

julia> using BenchmarkTools, TaylorSeries  # BenchmarkTools provides @btime

julia> @btime tanh(1 + Taylor1(3)).coeffs[4] * 6
  549.156 ns (10 allocations: 944 bytes)
0.6216266807712962

or better, in fact, this:

julia> using ForwardDiff

julia> DF(f) = x -> ForwardDiff.derivative(f,x);

julia> @btime DF(DF(DF(tanh)))(1.0)
  139.202 ns (3 allocations: 80 bytes)
0.6216266807712962
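
As a sanity check on the number both snippets report, here is a minimal sketch that evaluates the third derivative of tanh from its closed form, using only Base Julia. The helper name `tanh_third_derivative` is hypothetical, made up for this illustration. It also shows why the TaylorSeries version works: `coeffs[4]` is the coefficient of t^3 (Julia indexes from 1), which equals f'''(1)/3!, so multiplying by 6 recovers f'''(1).

```julia
# Closed-form third derivative of tanh, written in terms of t = tanh(x):
#   tanh'(x)   = 1 - t^2
#   tanh''(x)  = -2t * (1 - t^2)
#   tanh'''(x) = (1 - t^2) * (6t^2 - 2)
# `tanh_third_derivative` is a hypothetical helper name, not part of any package.
function tanh_third_derivative(x)
    t = tanh(x)
    return (1 - t^2) * (6t^2 - 2)
end

d3 = tanh_third_derivative(1.0)
println(d3)
```

The printed value should agree with the 0.6216266807712962 above up to floating-point rounding; the last digit or two may differ because the operations are ordered differently.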

If not, and you care about very different problems, then I wouldn’t conclude too much from this comparison.
