There are a couple of things at play here:
- Flux and PyTorch are equivalent in that both frameworks execute all non-accelerated (e.g. not BLAS, NNPACK, CuDNN) code single-threaded and shell out to ops that are usually multi-threaded. There can be gains from running non-BLAS Flux code on multiple threads/processes, but that should (in theory) not be required to match a normal PyTorch training loop and isn't even the same programming model (i.e. the PyTorch equivalent would require a GIL-less Python). A quick way to check how many Julia-level vs. BLAS threads are in play is sketched after this list.
- Current Flux incurs a large time penalty for the TTFG (time to first gradient) on both the forward and backward pass. Unless you're only calling `gradient` once per REPL session/script, the recommended way to benchmark is to "pre-warm" the forward pass by inserting a `model(<dummy data>)` call before the training loop you want to benchmark. This won't reduce absolute execution times, but it will help isolate the parts you can control (a minimal warm-up sketch is included after this list).
- `exp` and derived functions like `tanh` are the biggest bottleneck for Flux on CPU right now. IIRC the performance delta is under 10% for a normal dense relu network. I'd have a look at some of the "fast exp" options mentioned in Flux vs pytorch cpu performance - #50 by Elrod and see if those work for you. Also worth a run through a profiler to see what other bottlenecks may be present (a profiling sketch follows after this list).
- Flux and PyTorch performance are far closer on GPU. If you intend to train and experiment on GPU anyhow, it's worth testing there as well.
- Are both Flux and PyTorch using Float64? PyTorch is on Float32 by default, so if your Flux model or data ends up in Float64 that could also be a performance advantage for PyTorch (a quick check is sketched below).
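Not from the original discussion, but here's a small sketch of how to see the thread split from the first bullet, assuming Julia ≥ 1.6 for `BLAS.get_num_threads`:

```julia
# Quick check of the two separate thread pools involved; neither setting
# affects the other.
using LinearAlgebra

@show Threads.nthreads()       # Julia-level threads (julia -t N / JULIA_NUM_THREADS)
@show BLAS.get_num_threads()   # threads used by the accelerated BLAS kernels

# BLAS.set_num_threads(4)      # pin BLAS threads explicitly if needed
```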
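Here's a minimal sketch of the pre-warming approach; the model, data sizes, loss, and optimiser are placeholders, not from the original post:

```julia
using Flux

# Placeholder model, data, and loss — swap in your own.
model = Chain(Dense(10, 32, relu), Dense(32, 1))
x, y = rand(Float32, 10, 64), rand(Float32, 1, 64)
loss(x, y) = Flux.Losses.mse(model(x), y)
ps  = Flux.params(model)
opt = Descent(0.01)

# Pre-warm: trigger compilation of the forward and backward passes once,
# outside the timed region.
model(x)
gradient(() -> loss(x, y), ps)

# Now time only the loop you actually care about.
@time for _ in 1:100
    gs = gradient(() -> loss(x, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
```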
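And a rough profiling sketch using the stdlib Profile module, reusing `loss`, `x`, `y`, and `ps` from the warm-up sketch above:

```julia
using Profile

gradient(() -> loss(x, y), ps)   # warm up so compilation doesn't dominate the profile
Profile.clear()
@profile for _ in 1:50
    gradient(() -> loss(x, y), ps)
end
Profile.print(mincount = 10)     # show only hot frames; ProfileView.jl gives a flame graph
```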
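Finally, a quick sketch for checking and forcing Float32, again assuming the `model` and `x` from the warm-up sketch; `f32` is the Flux utility for converting parameter element types:

```julia
using Flux

eltype(first(Flux.params(model)))   # Float32 or Float64?

model = f32(model)                  # force Float32 parameters
x = Float32.(x)                     # inputs must match, or the computation promotes to Float64
```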
I don’t want to sugar-coat things: Flux does not currently benchmark well on CPU for “relatively linear” (i.e. non-SciML) models, and the “Zygote era” has certainly not helped. On the bright side, here are a couple improvements you can look forward to in the not-so-distant future:
- A new-and-improved AD system that should dramatically improve upon Zygote’s performance (the biggest piece of TTFG).
- CPU-specific tuning in Flux and NNlib. For example, https://github.com/FluxML/NNlib.jl/pull/191 (mentioned in your previous thread) has since landed in an NNlib release.