Just-in-time gradient optimization idea

This is exactly what happens if you are using Lux + Reactant (see the Training utilities in the Lux.jl docs) with `return_gradients` in `single_train_step!` set to `Val(false)`.
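
A minimal sketch of that setup, for readers who haven't seen it. The toy model, data, and hyperparameters below are placeholders, not from the thread, and the keyword placement of `return_gradients` is assumed to follow the Training docs linked above:

```julia
using Lux, Reactant, Enzyme, Optimisers, Random

dev = reactant_device()
model = Chain(Dense(2 => 32, tanh), Dense(32 => 2))
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 2, 64) |> dev
y = rand(Float32, 2, 64) |> dev

ts = Training.TrainState(model, ps, st, Adam(0.001f0))

# With return_gradients=Val(false) the gradient buffers are not handed
# back to the caller, so Reactant/XLA is free to reuse that memory
# inside the compiled training step.
_, loss, _, ts = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), ts; return_gradients=Val(false)
)
```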

Essentially, Reactant/XLA automatically analyzes the lifetimes of intermediate buffers and reuses their memory across operations where possible. This is one of the many reasons we outperform the CUDA.jl baselines (see Lux.jl/perf at main · LuxDL/Lux.jl · GitHub).
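
To make the "where" concrete: this buffer planning happens when the function is traced and compiled, not at call time. A minimal sketch (toy model, illustrative only):

```julia
using Lux, Reactant, Random

model = Chain(Dense(4 => 16, relu), Dense(16 => 4))
ps, st = Lux.setup(Random.default_rng(), model)

# Move inputs and parameters to Reactant arrays.
x_ra  = Reactant.to_rarray(rand(Float32, 4, 8))
ps_ra = Reactant.to_rarray(ps)
st_ra = Reactant.to_rarray(st)

# XLA's buffer-assignment pass runs during compilation: intermediates
# whose lifetimes do not overlap can be planned to share the same memory.
forward = Reactant.@compile model(x_ra, ps_ra, st_ra)

y, st_new = forward(x_ra, ps_ra, st_ra)
```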

Maybe not completely, indeed, but I have no idea where those allocations could come from.

We need to allocate buffers to hold the partial derivatives of the intermediate ops. Unless you cache all of the intermediate allocations, an autodiff call will always allocate.
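
A hand-written reverse-mode sketch of why that is (this is not the actual Reactant/Enzyme machinery, just an illustration): the backward pass reuses values produced in the forward pass, so those values have to live in buffers somewhere, and each backward step produces new arrays unless you preallocate and reuse them.

```julia
function forward_and_gradient(W, x)
    h = W * x              # intermediate buffer, needed later for dW
    y = tanh.(h)           # intermediate buffer, needed for the tanh pullback
    loss = sum(abs2, y)

    # Backward pass: every line below allocates a fresh array unless
    # these buffers are cached and reused across calls.
    dy = 2 .* y
    dh = dy .* (1 .- y .^ 2)   # tanh'(h) = 1 - tanh(h)^2, computed from the stored y
    dW = dh * x'
    dx = W' * dh
    return loss, dW, dx
end

W = rand(Float32, 3, 3); x = rand(Float32, 3)
loss, dW, dx = forward_and_gradient(W, x)
```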
