Ok, it wasn’t clear what you wanted to do with the per-sample grads afterwards, but now I see you want exactly what is described on that PyTorch docs page.
The first optimization would be to avoid recomputing the forward pass across different samples. This can be done with a trick similar to the one `Zygote.jacobian` exploits:
```julia
using LinearAlgebra, Zygote

function map_grad2(m, xs)
    # One forward pass over the whole batch; summing over the output dimension
    # leaves one scalar per sample (one per column of xs)
    y, back = Zygote.pullback(m -> vec(sum(m(xs); dims=1)), m)
    # seed the pullback with one standard basis vector per sample
    eye = I(length(y))
    return [back(seed) for seed in eachcol(eye)]
end
```
Notably, this allows us to make use of vectorization in the forward pass. For example, matrix multiplications can still run over the full batch. The result is about 7x faster than `loop_grad` and 8x slower than `mean_batch_grad`:
```julia
julia> @btime mean_batch_grad($m, $xs);
  85.567 μs (93 allocations: 240.23 KiB)

julia> @btime map_grad2($m, $xs);
  641.063 μs (497 allocations: 1.05 MiB)

julia> @btime loop_grad($m, $xs);
  4.589 ms (24064 allocations: 7.30 MiB)
```
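For anyone skimming, `mean_batch_grad` and `loop_grad` are the baselines from earlier in the thread; roughly something like the following (my reconstruction, the exact definitions upthread may differ):

```julia
# Rough reconstructions of the baselines (the originals may use a different loss).

# Single gradient of the batch-averaged objective: fully vectorized,
# but only yields the averaged gradient, not per-sample gradients.
mean_batch_grad(m, xs) = Zygote.gradient(m -> sum(m(xs)) / size(xs, 2), m)

# Naive per-sample loop: one full forward + backward pass per column of xs.
loop_grad(m, xs) = [Zygote.gradient(m -> sum(m(x)), m) for x in eachcol(xs)]
```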
What `vmap` in PyTorch and JAX does on top of this forward-pass batching is to also vectorize the loop in the backwards pass. This is a lot easier for the Python ML libraries since they only have to deal with a finite set of operations, but trickier in Julia land because the operation space is basically unbounded. There have been some attempts in the past to replicate either the PyTorch or the JAX implementation using dispatch or compiler transforms, respectively: see this old discussion, this ancient package, or the more recent torfjelde/Batching.jl on GitHub (cc @torfjelde). My understanding is that most attempts fizzle out trying to boil the ocean of possible Julia functions.