Python defaults to that (well, it's the only option for built-in floats; NumPy offers others, but it's still the default there, I think), and so does Numba (trusting your word). But in case you missed it, there's a comment in the code below: “need Float32 because Jax defaults to it” (also Int32).
If you want bit-identical results, then of course you should use the same precision in both (and FMA in both or in neither; threading should be OK, for this at least). For speed, though, Float32 (or smaller) is in general better, although I'm skeptical that it matters for this case. I'm curious: is there any noticeable difference in the output with lower precision? Or, more likely, how far would you have to zoom in to see one?
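On the zoom question, here is a rough sketch in Python (standard library only) of how one could estimate it. It counts how many distinct x-coordinates survive across a row of pixels when the coordinates are rounded to Float32. The center `-0.75`, the 1000-pixel row, and the function names are my own assumptions for illustration, not taken from your code. It also only rounds the final coordinate, so it's a best case; doing all the arithmetic in Float32 would likely degrade sooner.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def distinct_pixels(center: float, width: float, n: int, f32: bool) -> int:
    """Count distinct x-coordinates among n pixels spanning `width` around `center`."""
    coords = set()
    for i in range(n):
        x = center - width / 2 + width * i / (n - 1)
        coords.add(to_f32(x) if f32 else x)
    return len(coords)


# Hypothetical view: 1000 pixels wide, centered at x = -0.75.
for width in (1e0, 1e-3, 1e-5, 1e-6, 1e-7):
    n32 = distinct_pixels(-0.75, width, 1000, f32=True)
    n64 = distinct_pixels(-0.75, width, 1000, f32=False)
    print(f"width={width:g}: Float32 distinct={n32}, Float64 distinct={n64}")
```

With these assumed numbers, all 1000 pixels stay distinct at width 1 in both precisions, while at width 1e-7 the Float32 coordinates collapse to a handful of values and Float64 still keeps all 1000 apart. So, roughly, the difference should only become visible once the pixel spacing gets close to the Float32 spacing near the coordinates you're looking at (about 6e-8 around 0.75).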