Speeding up per-sample gradients?

Yes, I’m aware that map_grad and mean_batch_grad can return essentially the same thing if we sum over the outputs of map_grad.

In the end, I want M gradients with respect to the same parameters but for M different datapoints, and I want to perform operations on those M gradients other than summing or averaging them. For example, I would like to compute the sum of outer products, sum(first(destructure(g)) * transpose(first(destructure(g))) for g in map_grad(m, xs)).
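Concretely, using Flux's destructure so that each per-sample gradient is a flat vector, the kind of computation I have in mind looks roughly like this (just a sketch of the idea, not my actual map_grad):

using Flux
using Flux: destructure

m  = Chain(Dense(10, 32), Dense(32, 32), Dense(32, 10))
xs = randn(Float32, (10, 256))

flat, re = destructure(m)   # flat parameter vector plus a function that rebuilds the model
# one gradient per datapoint, each a plain Vector{Float32}
grads = [Flux.gradient(p -> sum(re(p)(x)), flat)[1] for x in eachcol(xs)]
# sum of outer products of the per-sample gradients
F = sum(g * g' for g in grads)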

Isn’t the objective in my example already doing what you suggest? Or do you mean that map_grad and mean_batch_grad can be made equally fast by specializing map_grad to compute only the average/sum gradient? Again, that is not what I am after. I just want a way to compute gradients with respect to the same parameters but for different individual datapoints that scales better than looping and calculating each gradient independently.
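For what it's worth, the per-sample gradients are exactly the rows of the Jacobian of the vector of per-sample losses with respect to the flat parameters, so something along these lines would also give me what I want (again only a sketch, reusing m, xs, flat and re from above; I don't know whether Zygote's jacobian is actually any faster than looping internally):

using Zygote

# per-sample losses for the whole batch, computed in one batched forward pass
persample_losses(p) = vec(sum(re(p)(xs); dims = 1))
# J has size M × length(flat); row i is the gradient for datapoint i
J = Zygote.jacobian(persample_losses, flat)[1]
# the sum of outer products is then just J' * J
F = J' * J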

Edit: looking at the post you linked and the other comments, are you saying that broadcasting could be faster than mapping? Unfortunately, it is not much faster than map:

using Flux, BenchmarkTools

m = Chain(Dense(10, 32), Dense(32, 32), Dense(32, 10))
xs = randn(Float32, (10, 256))

function broad_grad(m, xs)
    # split the batch into per-sample columns, then broadcast gradient over them
    xs = [xs[:,i] for i = 1:size(xs, 2)]
    Flux.gradient.([m -> sum(m(x)) for x in xs], [m for x in xs])
end

@btime broad_grad(m, xs);
#     2.327 ms (14606 allocations: 2.62 MiB)
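
For reference, the map-style version I'm comparing against is essentially this (not the exact map_grad from my earlier code, but the same idea):

function map_grad_like(m, xs)
    # one gradient call per column of xs, no attempt at batching
    map(eachcol(xs)) do x
        Flux.gradient(m -> sum(m(x)), m)
    end
end

@btime map_grad_like(m, xs);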