I am trying to speed up the calculation of per-sample (or per-example) gradients. In Julia this can be done via a loop or a Jacobian, but both are slow. Several Python libraries can batch the calculation of the individual gradients to speed things up: for example, JAX has vmap and PyTorch has function transforms (see: Per-sample-gradients — PyTorch Tutorials 2.2.0+cu121 documentation).
Is there any way to get a similar speedup when calculating the individual gradients with a Julia package? Preferably something that works on the GPU and with a neural network library. Some packages seem related (LoopVectorization and Tullio), but they do not seem obviously capable of what I want.
Below is an example of what I want: map_grad, loop_grad, and jac_grad each calculate the gradient of every data point. All of these options are slow, scaling linearly in the batch size. PyTorch and JAX can calculate the individual gradients almost as quickly as the average gradient (e.g., in JAX, map_grad can be made as quick as mean_batch_grad).
import BenchmarkTools: @btime
using Flux

# Average gradient over the whole batch in a single backward pass.
function mean_batch_grad(m, xs)
    Flux.gradient(m -> sum(m(xs)), m)
end

# Per-sample gradients via an explicit loop over the columns of xs.
function loop_grad(m, xs)
    for i = 1:size(xs, 2)
        _ = Flux.gradient(m -> sum(m(xs[:, i])), m)
    end
end

# Per-sample gradients via map over a vector of columns.
function map_grad(m, xs)
    xs = [xs[:, i] for i = 1:size(xs, 2)]
    grad_f = x -> Flux.gradient(m -> sum(m(x)), m)
    map(grad_f, xs)
end

# Per-sample gradients as the Jacobian of the per-sample losses w.r.t. the parameters.
function jac_grad(m, xs)
    Flux.jacobian(() -> sum(m(xs), dims = 1), Flux.params(m))
end

m = Chain(Dense(10, 32), Dense(32, 32), Dense(32, 10))
xs = randn(Float32, (10, 256))

@btime mean_batch_grad(m, xs);
@btime loop_grad(m, xs);
@btime map_grad(m, xs);
@btime jac_grad(m, xs);
julia> @btime mean_batch_grad(m, xs);
76.585 μs (96 allocations: 240.42 KiB)
julia> @btime loop_grad(m, xs);
3.872 ms (24576 allocations: 7.33 MiB)
julia> @btime map_grad(m, xs);
2.966 ms (14602 allocations: 2.61 MiB)
julia> @btime jac_grad(m, xs);
19.069 ms (30589 allocations: 26.05 MiB)
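For concreteness, here is a rough sketch of the kind of batching I am hoping a package could automate. It is not a general solution, just an illustration for a single Dense layer with an identity activation: because the per-sample loss is a plain sum, the gradient of the loss with respect to the layer output is a column of ones per sample, and the per-sample weight gradient is the outer product of that column with the corresponding input column, so all 256 gradients can be formed in one broadcast instead of a loop. The names deltas, per_sample_dW, and per_sample_db are just illustrative.

using Flux

d = Dense(10, 32)                    # a single layer, same shape as the first layer above
xs = randn(Float32, (10, 256))

# Gradient of the per-sample loss sum(y) w.r.t. the layer output y:
# for a plain sum this is all ones, one column per sample.
deltas = ones(Float32, 32, 256)

# Per-sample weight gradients as a batched outer product, stacked along dim 3: (32, 10, 256).
per_sample_dW = reshape(deltas, 32, 1, :) .* reshape(xs, 1, 10, :)
# Per-sample bias gradients are just the output gradients themselves.
per_sample_db = deltas

# Sanity check against a single-sample gradient for the first column.
g1 = Flux.gradient(d -> sum(d(xs[:, 1])), d)[1]
per_sample_dW[:, :, 1] ≈ g1.weight   # true

Generalizing this to the full Chain would require the intermediate activations and the backpropagated deltas for every layer, which is essentially the bookkeeping I was hoping a vmap-style transform could handle for me.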