As luck would have it, there was a thread just the other day on this: Flux - LSTM - Issue with input format for multiple features - #8 by Mateusz_K.
One additional thing I’ll note is that PyTorch and TensorFlow had to bend over backwards to make cuDNN RNN interop work. If you’ve ever noticed all the caveats and info boxes about using their RNNs on GPU, this is why. In contrast, most JAX frameworks just roll their own, and I haven’t seen reports of poor performance (though they can lean heavily on XLA to help with that).
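To make "roll their own" concrete, here’s a minimal sketch of how a hand-written LSTM looks in JAX: the recurrence is just a plain function driven by `jax.lax.scan`, and XLA compiles/fuses the whole loop rather than dispatching to cuDNN. This is illustrative code, not any particular library’s implementation; the parameter layout and function names are my own.

```python
import jax
import jax.numpy as jnp

def lstm_step(carry, x_t, params):
    # One LSTM time step; carry holds (hidden, cell) state.
    h, c = carry
    W, U, b = params  # input weights, recurrent weights, bias (4 gates stacked)
    z = x_t @ W + h @ U + b
    i, f, g, o = jnp.split(z, 4)
    c_new = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h_new = jax.nn.sigmoid(o) * jnp.tanh(c_new)
    return (h_new, c_new), h_new

def lstm(params, xs, hidden):
    # Run the step over the whole sequence; scan is what XLA optimizes.
    init = (jnp.zeros(hidden), jnp.zeros(hidden))
    step = lambda carry, x: lstm_step(carry, x, params)
    (h, c), ys = jax.lax.scan(step, init, xs)
    return ys  # hidden state at every time step

in_dim, hidden, T = 3, 5, 7
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (in_dim, 4 * hidden)),
          jax.random.normal(k2, (hidden, 4 * hidden)),
          jnp.zeros(4 * hidden))
xs = jax.random.normal(k3, (T, in_dim))
ys = jax.jit(lstm, static_argnums=2)(params, xs, hidden)  # shape (T, hidden)
```

Because the whole thing is ordinary JAX code, you get `jit`, `vmap` (for batching), and `grad` for free, with none of the cuDNN-specific restrictions on dropout placement, bidirectionality, or variable-length batches.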