Speeding up per-sample gradients?

I am trying to speed up the calculation of per-sample (or per-example) gradients. In Julia this can be done with a loop or a Jacobian, but both are slow. Several Python libraries can batch the calculation of individual gradients to speed things up. For example, Jax has vmap and PyTorch has function transforms (see: Per-sample-gradients — PyTorch Tutorials 2.2.0+cu121 documentation).

Is there any way to get a similar speedup in calculating the individual gradients using any Julia package? Preferably something that works on GPU and with a neural network library. Some packages seem related (LoopVectorization and Tullio), but they do not obviously do what I want.

Below is an example of what I want: map_grad, loop_grad and jac_grad calculate the gradient for each data point. All of these options are slow, scaling linearly in the batch size. PyTorch and Jax can calculate the individual gradients almost as quickly as the average gradient (e.g. in Jax, map_grad can be made as quick as mean_batch_grad).

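import BenchmarkTools: @btime
using Flux

# Gradient of the summed output over the whole batch (a single backward pass)
function mean_batch_grad(m, xs)
    Flux.gradient(m -> sum(m(xs)), m)
end

# One gradient call per sample (column of xs); results are discarded
function loop_grad(m, xs)
    for i = 1:size(xs, 2)
        _ = Flux.gradient(m -> sum(m(xs[:, i])), m)
    end
end

# Same as loop_grad, but collecting the per-sample gradients with map
function map_grad(m, xs)
    xs = [xs[:, i] for i = 1:size(xs, 2)]
    grad_f = x -> Flux.gradient(m -> sum(m(x)), m)
    map(grad_f, xs)
end

# Jacobian of the per-sample outputs with respect to the (implicit) parameters
function jac_grad(m, xs)
    Flux.jacobian(() -> sum(m(xs), dims = 1), Flux.params(m))
end

m = Chain(Dense(10, 32), Dense(32, 32), Dense(32, 10))
xs = randn(Float32, (10, 256))

@btime mean_batch_grad(m, xs);
@btime loop_grad(m, xs);
@btime map_grad(m, xs);
@btime jac_grad(m, xs);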

julia> @btime mean_batch_grad(m, xs);
  76.585 μs (96 allocations: 240.42 KiB)
julia> @btime loop_grad(m, xs);
  3.872 ms (24576 allocations: 7.33 MiB)
julia> @btime map_grad(m, xs);
  2.966 ms (14602 allocations: 2.61 MiB)
julia> @btime jac_grad(m, xs);
  19.069 ms (30589 allocations: 26.05 MiB)

So this is not my area of expertise and I am unsure what exactly you are asking. But I’ll try to help 🙂 You want fast per-sample gradients, and the way you get them is by batching, which is exactly what your benchmarks reflect. So you already know how to do it. So the point of your question is “how do I speed up this batched computation”, right?

IIUC, the “function transforms” needed in Python are pretty much equivalent to what we can do in Julia by just passing in an Array with an extra dimension. So I think this route of optimization is exhausted.

You suggest GPU, but haven’t tested it so far. Are you unsure how to do it? In theory it should be easy: in essence you install a suitable GPU package and then convert your model and your xs to the corresponding array type. See the Flux documentation here.


Let me clarify: I want the gradient with respect to each individual data point (which is what loop_grad, map_grad and jac_grad calculate). I do not want the average over the data points (which is what mean_batch_grad calculates). PyTorch and Jax can calculate the individual gradients almost as quickly as calculating the average by using vmap over the batch dimension. I was wondering if there is anything in Julia that can do the same.

I’ve tested the functions on GPU, and the map, loop and Jacobian versions are all much slower.


IIUC you are talking about some kind of kernel fusion, i.e. an optimization step that optimizes away repeated computations within the batch. I don’t think that this will happen with Zygote.

Did you try using a map over the batch for the forward pass and then differentiating that?

IIUC Enzyme should make this easier with batched reverse mode.


I’m not familiar enough with the internals of AD or Zygote to know for sure, but that does sound like what I am looking for.

Unfortunately mapping over the batch will require a call to Flux.jacobian which is the slowest of all the methods I have tried.

I expect that next-gen AD, like Enzyme or Diffractor, will be capable of this, and I would be happy to use them now if it is already possible.

I guess if you don’t need the whole batch of network gradients to be output at once, you could use a single shadow in Enzyme and overwrite it for each sample. Whether that works depends on what you need the gradients for. But if you’re gonna change the weights in between samples, what you seem to be describing wouldn’t work anyway.

So Jax.vmap does what you want in a single backward pass and gives you n_batch different model gradients at the same complexity as one?

There are a few use cases where this is needed; I am using it for an empirical Fisher approximation to the Hessian, but there are several other uses for it too.
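For reference, the empirical Fisher is commonly defined as the average outer product of the per-sample loss gradients. This is the standard textbook form rather than something stated in the thread; here θ are the model parameters, ℓ(θ; x_i) is the loss on sample i and N is the batch size:

```latex
\hat{F}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \ell(\theta; x_i)\, \nabla_\theta \ell(\theta; x_i)^{\top}
```

Every individual gradient appears inside its own outer product, so the gradients cannot be summed before this step; an unnormalized sum of outer products differs only by the factor 1/N.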

I am not sure about the specifics of Jax.vmap, as I have been primarily using Julia/Flux the past 2-3 years. I have not used Jax or PyTorch extensively, and I haven’t benchmarked the complexity myself. Looking at the PyTorch tutorial on per-sample grads (benchmark at the bottom), using vmap is about 10x faster than not using vmap in calculating 64 gradients of a loss with respect to a fixed weight but different datapoints.

Could it be that we are talking just about parallelization here? If yes, broadcasting the per-sample gradient kernel over a GPU array could speed things up.

Mostly. Just broadcasting won’t work as well as vmap, however, because some of the operations being broadcasted are already vectorized (e.g. BLAS). vmap will actually modify those calls (using dispatch in PyTorch and source code transforms in JAX) to use batched implementations whenever it encounters them.
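A minimal, package-free sketch of the BLAS point (the layer and batch sizes are illustrative assumptions): mapping a per-sample matrix-vector product over the batch gives the same numbers as a single matrix-matrix product, and rewriting the former into the latter is the kind of substitution vmap makes.

```julia
using LinearAlgebra

W = randn(Float32, 32, 10)   # weights of a hypothetical Dense(10 => 32) layer
X = randn(Float32, 10, 256)  # 256 samples stored as columns

# Per-sample code: 256 separate matrix-vector products (BLAS level 2)
Y_loop = reduce(hcat, [W * x for x in eachcol(X)])

# What a vmap-style rewrite turns W * x into: one matrix-matrix product (BLAS level 3)
Y_batched = W * X
```

Plain broadcasting keeps the first form, which is why it does not recover the speedup on its own.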

Now that said, some functions and Flux layers are flexible enough to already work for this without a vmap-like treatment. See this topic posted about a month ago: Flux loss with contribution gradient is slow - #5 by Jonas208. Basically, changing your loss function to compute a loss for each sample individually and then summing should be enough for an MLP. In fact, mean_batch_grad currently returns exactly the sum of the per-sample gradients from map_grad, because sum(map(x -> sum(model(x)), eachcol(xs))) == sum(model(xs)) (up to floating-point rounding)!

Yes, I’m aware that map_grad and mean_batch_grad can return essentially the same thing if we sum over the outputs of map_grad.

In the end, I want M gradients with respect to the same parameter but different datapoints. I want to perform operations on the M gradients other than summing / averaging them. For example, I would like to compute the sum of outer products (sum(first(destructure(g)) * transpose(first(destructure(g))) for g in map_grad(m, xs))).
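Once the per-sample gradients exist as flat vectors, the sum of outer products does not need a loop: stacking them as columns of a matrix G turns it into a single product G * Gᵀ. A small sketch, where the random vectors are hypothetical stand-ins for the flattened gradients (for example what first(Flux.destructure(g)) returns for each g) and the sizes are arbitrary:

```julia
using LinearAlgebra

# Hypothetical stand-ins for N flattened per-sample gradients of length P
P, N = 200, 256
grads = [randn(Float32, P) for _ in 1:N]

# Direct translation: one P×P outer product per sample, then a sum
F_loop = sum(g * transpose(g) for g in grads)

# Same matrix from a P×N matrix whose columns are the gradients: one matrix product
G = reduce(hcat, grads)
F_mat = G * transpose(G)
```

This only speeds up the post-processing; computing the per-sample gradients themselves is still the expensive part.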

Isn’t the objective in my example already doing what you suggest? Unless you mean that map_grad and mean_batch_grad can be made equally fast by specializing map_grad to compute only the average/sum gradient. Again, this is not what I am after. I just want a way to compute gradients with respect to a parameter but with different individual datapoints in a way that scales better than looping and calculating each gradient independently.

Edit: looking at the post you linked and the other comments, are you saying broadcasting could be faster than mapping? Unfortunately, it is not much faster than map.

m = Chain(Dense(10, 32), Dense(32, 32), Dense(32, 10))
xs = randn(Float32, (10, 256))

# Broadcast Flux.gradient over one closure (and one model reference) per sample
function broad_grad(m, xs)
    xs = [xs[:, i] for i = 1:size(xs, 2)]
    Flux.gradient.([m -> sum(m(x)) for x in xs], [m for x in xs])
end

@btime broad_grad(m, xs);
#     2.327 ms (14606 allocations: 2.62 MiB)

Ok, it wasn’t clear what you wanted to do with the per-sample grads afterwards but now I see you want exactly what is described on that PyTorch docs page.

The first optimization would be to avoid recomputing the forward pass for each sample. This can be done using a similar trick to the one Zygote.jacobian exploits:

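using LinearAlgebra: I
import Flux.Zygote
# One batched forward pass, then one pullback per sample (seed = i-th unit vector)
function map_grad2(m, xs)
    y, back = Zygote.pullback(m -> vec(sum(m(xs); dims=1)), m)
    eye = I(length(y))
    return [back(seed) for seed in eachcol(eye)]
end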

Notably, this allows us to make use of vectorization in the forward pass. For example, matrix multiplications can still run over the full batch. The result is about 7x faster than loop_grad and 8x slower than mean_batch_grad:

julia> @btime mean_batch_grad($m, $xs);
  85.567 μs (93 allocations: 240.23 KiB)

julia> @btime map_grad2($m, $xs);
  641.063 μs (497 allocations: 1.05 MiB)

julia> @btime loop_grad($m, $xs);
  4.589 ms (24064 allocations: 7.30 MiB)
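To see why seeding the pullback with unit vectors yields per-sample gradients, here is a package-free sketch. It replaces Zygote.pullback with a hand-written vector-Jacobian product for the hypothetical function f(W) = vec(sum(W * X; dims=1)), which has one scalar output per sample; the sizes are illustrative:

```julia
X = randn(Float32, 10, 256)   # 256 samples as columns
W = randn(Float32, 3, 10)

y = vec(sum(W * X; dims=1))   # forward pass, done once for the whole batch

# Hand-written vector-Jacobian product: for a cotangent ybar (one entry per
# sample), return the gradient of dot(ybar, y) with respect to W
back(ybar) = ones(Float32, size(W, 1)) * transpose(X * ybar)

# One pullback per one-hot seed gives the per-sample gradients, as in map_grad2
N = length(y)
per_sample = [back(Float32.((1:N) .== i)) for i in 1:N]
```

The forward pass is shared, but each seed still triggers its own backward computation, so the cost grows with N.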

What vmap in PyTorch and JAX do on top of this is to also vectorize the loop in the backwards pass. This is a lot easier for the Python ML libraries since they only have to deal with a finite number of operations, but trickier in Julia land because the operation space is basically unbounded. There have been some attempts in the past to replicate either the PyTorch or JAX implementation using dispatch or compiler transforms respectively: see this old discussion, this ancient package or the more recent GitHub - torfjelde/Batching.jl (cc @torfjelde). My understanding is that most attempts fizzle out trying to boil the ocean of possible Julia functions.
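A toy illustration of the dispatch-based route, and of why it is hard to make complete: wrap a batch in a new type and give each primitive a batched method. Batched and layer are hypothetical names for this sketch, not an existing package API:

```julia
# Hypothetical batch wrapper: column i of data holds sample i
struct Batched{T}
    data::Matrix{T}
end

# W * x for one sample becomes a single matrix-matrix product for the batch
Base.:*(W::AbstractMatrix, b::Batched) = Batched(W * b.data)
# Adding a bias vector broadcasts it over every column
Base.:+(b::Batched, v::AbstractVector) = Batched(b.data .+ v)
# Elementwise functions need their own rule as well
Base.map(f, b::Batched) = Batched(map(f, b.data))

# Per-sample code written for a single vector x, used unchanged on a batch
layer(W, bias, x) = map(tanh, W * x + bias)

W, bias = randn(4, 3), randn(4)
X = randn(3, 5)
out = layer(W, bias, Batched(X))
```

Every operation the per-sample code touches needs such a method, which is the "boil the ocean" problem for arbitrary Julia code.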


Interesting, that helps a lot. Thank you!

If you could compile back to a GPU kernel and then broadcast it, how different would it be from Jax.vmap?

I would say it still isn’t equivalent. The JAX equivalent of map_grad2 would be something like:

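import jax
import jax.numpy as jnp

def map_grad2(apply_fn, params, xs):
    # xs is batch-first, so summing over axis 1 leaves one output per sample
    y, back = jax.vjp(lambda p: apply_fn(p, xs).sum(axis=1), params)
    eye = jnp.eye(len(y), dtype=y.dtype)
    return [back(seed)[0] for seed in eye]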

Compiling to GPU kernels (you’d need more than one here) would be similar to applying jax.jit to map_grad2. It does make each pullback call slightly faster, but you’re still missing out on the parallelism and work sharing vmap can enable.

I only just got around to trying your proposed solution. My benchmark shows map_grad2 is actually slower. Am I missing something, benchmarking wrong or do we have version differences? I am on Julia 1.10, Flux 0.14.9 and Zygote 0.6.68.

import BenchmarkTools: @btime
using Flux
import Flux.Zygote
import LinearAlgebra

function map_grad(m, xs)
    xs = [xs[:, i] for i = 1:size(xs, 2)]
    grad_f = x -> Flux.gradient(m -> sum(m(x)), m)
    gs = map(grad_f, xs)
    [g for g in gs]
end

function map_grad2(m, xs)
    y, back = Zygote.pullback(m -> vec(sum(m(xs); dims=1)), m)
    eye = LinearAlgebra.I(length(y))
    return [back(seed) for seed in eachcol(eye)]
end

m = Chain(Dense(10, 32), Dense(32, 32), Dense(32, 10))# |> Flux.gpu
xs = randn(Float32, (10, 256))# |> Flux.gpu

@btime map_grad($m, $xs);
@btime map_grad2($m, $xs);

julia> @btime map_grad($m, $xs);
  2.551 ms (14604 allocations: 2.62 MiB)
julia> @btime map_grad2($m, $xs);
  13.419 ms (11563 allocations: 23.34 MiB)

Unfortunately, I had the wrong dim in the original example. After editing it to use the batch dim, the performance is not as good, and is about what you see.

This goes to show that being able to vectorize the backwards pass as well is the most important part for performance. To illustrate, here’s what jax.vmap rewrites a grad call to for a single Dense layer:

In [71]: jax.make_jaxpr(jax.vmap(jax.grad(loss), in_axes=(None, 0)))(variables, xs)
{ lambda ; a:f32[8] b:f32[8,8] c:f32[256,8]. let
    d:f32[256,8] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c b
    e:f32[1,8] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 8)] a
    f:f32[256,8] = add d e
    _:f32[256] = reduce_sum[axes=(1,)] f
    g:f32[8] = broadcast_in_dim[broadcast_dimensions=() shape=(8,)] 1.0
    h:f32[8,256,8] = dot_general[dimension_numbers=(([], []), ([], []))] g c
    i:f32[256,8,8] = transpose[permutation=(1, 2, 0)] h
    j:f32[256,8] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(256, 8)] g
  in (j, i) }

Now you could write these operations out by hand in Julia code. There are even packages which provide batched functions, like NNlib: Reference · NNlib.jl. For interop with existing code, though, you still need either a dispatch-based (overloading) or a code-rewriting solution like the Python ML libraries have. Enzyme.jl may be an option as @simsurace mentioned, but I’m not sufficiently familiar with how their gradient batching works to help here. If that’s a route you’re interested in, I’d try the usual Enzyme help channels.
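As a sketch of writing it out by hand, here is a hypothetical two-layer MLP y = W2 * tanh.(W1 * x .+ b1) .+ b2 with per-sample loss sum(y), the same loss as in the benchmarks. Every step of the backward pass runs on the whole batch; only the weight gradients are expanded into one slice per sample, as batched outer products (the same pattern as the dot_general with no contracted dims in the jaxpr). The function names and sizes are assumptions for illustration, and the gradients are derived manually rather than with an AD package:

```julia
# a is m×N, b is k×N; result[:, :, n] = a[:, n] * b[:, n]' for every sample n
batched_outer(a, b) = reshape(a, size(a, 1), 1, :) .* reshape(b, 1, size(b, 1), :)

function mlp_per_sample_grads(W1, b1, W2, b2, X)
    # forward pass over the whole batch (matrix-matrix products)
    H = tanh.(W1 * X .+ b1)
    Y = W2 * H .+ b2
    # backward pass, still batched
    dY = ones(eltype(Y), size(Y))   # d sum(y_n) / d y_n for every sample
    dH = transpose(W2) * dY         # one matmul for all samples
    dZ = dH .* (1 .- H .^ 2)        # derivative of tanh
    return (W1 = batched_outer(dZ, X), b1 = dZ,
            W2 = batched_outer(dY, H), b2 = dY)
end

W1, b1 = randn(32, 10), randn(32)
W2, b2 = randn(10, 32), randn(10)
X = randn(10, 256)
g = mlp_per_sample_grads(W1, b1, W2, b2, X)

# Check sample 1 against gradients derived for a single vector
x = X[:, 1]
h = tanh.(W1 * x .+ b1)
dz = (transpose(W2) * ones(10)) .* (1 .- h .^ 2)
@assert g.W1[:, :, 1] ≈ dz * transpose(x)
@assert g.W2[:, :, 1] ≈ ones(10) * transpose(h)
```

Summing the slices over the third dimension recovers the ordinary batch gradient, which is a quick sanity check for hand-written rules like these.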

That’s too bad, I appreciate your help though! I am looking forward to Enzyme, but it does not seem ready for regular use yet. I guess I will have to use Jax for this project.
