I want to say that the Jax (and Python in general) installation experience is horrible. pip-installing Jax tells me to also install jaxlib, and jaxlib somehow downloads the TensorFlow runtime and its own LLVM and then starts compiling with Bazel for 10 minutes. All that just so I can run your code and verify my hypothesis…
Again, I want to point out that:
```
In [2]: a = np.arange(10)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [3]: a
Out[3]: DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
```
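For what it's worth, the `int32` above is Jax's default dtype, not a CPU-fallback quirk, and it can be overridden. A minimal sketch, assuming a reasonably recent Jax where `jax.config.update` is available: the `jax_enable_x64` flag has to be set at startup, before any arrays are created (setting the `JAX_ENABLE_X64=1` environment variable is the equivalent).

```python
import jax

# Opt into 64-bit mode; this must happen before any Jax arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

a = jnp.arange(10)   # integer arange now uses the 64-bit default int
b = jnp.ones(3)      # floating-point constructors now default to float64

print(a.dtype)  # int64
print(b.dtype)  # float64
```

Without that flag the same two calls give `int32` and `float32`, which is the behaviour the transcript shows.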
Almost all of Python's “ML” backends default to 32-bit numbers: