Attempted to port an RL algorithm from PyTorch to Flux and it's 10x slower

One thing is that you aren’t very consistent about using Float32 vs Float64; for example:

https://github.com/DevJac/gdrl-with-flux/blob/master/Ch8.jl#L50

those are all Float64 constants. Doing everything in Float32, i.e. writing them as 1f0, is probably a good idea. With all of this switching around, you might be running the backward pass through the network in Float64, which would be one performance hit. Other than that, the issue might be your BLAS. Have you tried it with MKL.jl to see whether it’s an OpenBLAS vs MKL issue?
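
For illustration, here is a minimal sketch (the model and names are made up, not taken from the linked repo) of how a single Float64 literal promotes a Flux network's Float32 outputs, and how to check which BLAS is active:

```julia
using Flux, LinearAlgebra

# Flux layers use Float32 parameters by default.
model = Chain(Dense(4, 32, relu), Dense(32, 2))
x = rand(Float32, 4)

# A Float64 literal promotes the whole computation (and hence the backward pass) to Float64:
eltype(model(x) .* 0.99)    # Float64 -- 0.99 is a Float64 constant
# Writing the constant as a Float32 literal keeps everything in Float32:
eltype(model(x) .* 0.99f0)  # Float32

# To test the OpenBLAS vs MKL question, load MKL.jl before doing any linear algebra
# (on Julia >= 1.7 this swaps the BLAS backend):
# using MKL
BLAS.get_config()  # reports which BLAS backend is active (Julia >= 1.7)
```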
