Flux benchmark being too slow vs JAX

Thanks! Using the technique from "ahead-of-time lowering and compilation for jit" (google/jax PR #7997 on GitHub), I had a look at the compiled output of apply_model. It looks like JAX/XLA fuses a number of the layernorm operations together. In practice, this should reduce the number of allocations required and thus allocation-related overhead.
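For anyone who wants to reproduce this kind of inspection, here is a minimal sketch of the AOT workflow from that PR, using a toy layernorm stand-in (the `apply_model` here is hypothetical, not the benchmark's actual model; the `.lower(...)`/`.compile()`/`.as_text()` API is available in recent JAX versions):

```python
import jax
import jax.numpy as jnp

def layernorm(x):
    # A plain layer norm so XLA has elementwise ops to fuse.
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + 1e-5)

def apply_model(x):
    # Toy stand-in for the benchmark's apply_model.
    return layernorm(layernorm(x) + x)

x = jnp.ones((8, 128))
lowered = jax.jit(apply_model).lower(x)  # ahead-of-time lowering
compiled = lowered.compile()             # run XLA's optimization passes
# Post-optimization HLO: fused kernels typically show up as `fusion` ops.
print(compiled.as_text())
```

Searching the printed HLO for `fusion` shows which groups of operations XLA merged into single kernels.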

The bad news is that, lacking something like XLA, we can't easily do similar automatic optimizations for Flux models. The silver lining is that, as @mcabbott noted, this overhead is amortized over larger input and model sizes. If you are working at small scale most of the time, SimpleChains.jl (PumasAI/SimpleChains.jl on GitHub) may be worth a look.
