Any faster way of computing small gradients?

Isn’t JAX multithreaded by default? If so, the Julia side should be run with multiple threads as well for a fair comparison. Also, what data type is JAX using? Could it be defaulting to Float32?
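One quick way to check the dtype question on the JAX side is to print the dtype of the input and of the gradient. This is only a rough sketch: `f` is a made-up toy function standing in for whatever is being benchmarked, and I am assuming a recent JAX where the `jax_enable_x64` flag can be set via `jax.config.update`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function, not the one from the benchmark.
def f(x):
    return jnp.sum(x ** 2)

# Inspect the dtype JAX picks without any configuration.
x = jnp.ones(3)
print("default input dtype:", x.dtype)
print("default gradient dtype:", jax.grad(f)(x).dtype)

# Opt in to 64-bit floats so the comparison with Julia's Float64 is like for like.
# The JAX docs recommend setting this at startup, before other JAX work.
jax.config.update("jax_enable_x64", True)

x64 = jnp.ones(3)
g64 = jax.grad(f)(x64)
print("x64 input dtype:", x64.dtype)
print("x64 gradient dtype:", g64.dtype)
```

If the first two lines report a 32-bit float, then the timings are comparing Float32 work in JAX against Float64 work in Julia, and one of the two sides should be changed before reading much into the numbers.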

See, e.g., this thread for similar benchmarks.