I have seen that Flux.jl does not use cuDNN for RNNs. Is there a specific reason for this?
Going through the issues is really confusing: it seems that support was removed because tests kept failing, and there are also one or two PRs that reimplement it. Maybe someone can clear up my confusion.
One additional thing I’ll note is that both PyTorch and TF had to bend over backwards to make cuDNN RNN interop work. If you’ve ever noticed all the caveats and info boxes about using their RNNs on GPU, this is why. In contrast, most JAX frameworks just roll their own, and I haven’t seen reports of poor performance (though they can lean heavily on XLA to help with that).
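For concreteness, here is roughly what "rolling your own" looks like in JAX: a plain tanh RNN cell unrolled over time with `jax.lax.scan`, which XLA compiles into a single fused loop. This is a minimal sketch of the general pattern (the parameter names and sizes are my own), not any particular framework's implementation:

```python
import jax
import jax.numpy as jnp

def init_params(key, input_size, hidden_size):
    # Parameters for a simple Elman RNN: input->hidden, hidden->hidden, bias.
    # (Hypothetical helper for illustration, not from any framework.)
    k1, k2 = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(hidden_size)
    return {
        "Wxh": jax.random.uniform(k1, (input_size, hidden_size), minval=-scale, maxval=scale),
        "Whh": jax.random.uniform(k2, (hidden_size, hidden_size), minval=-scale, maxval=scale),
        "b": jnp.zeros(hidden_size),
    }

def rnn(params, xs, h0):
    # lax.scan threads the hidden state across time steps; under jit,
    # XLA fuses the per-step matmuls and element-wise ops.
    def step(h, x):
        h = jnp.tanh(x @ params["Wxh"] + h @ params["Whh"] + params["b"])
        return h, h  # (new carry, per-step output)
    h_final, hs = jax.lax.scan(step, h0, xs)
    return hs, h_final

key = jax.random.PRNGKey(0)
params = init_params(key, input_size=8, hidden_size=16)
xs = jnp.ones((20, 8))   # (time, features)
h0 = jnp.zeros(16)
hs, h_final = jax.jit(rnn)(params, xs, h0)
```

No hand-written CUDA kernel is involved; the whole loop is lowered through XLA, which is presumably why these frameworks can skip the cuDNN interop machinery entirely.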