Gradient loop vs batch gradient speed

I am training a neural network using a manual custom gradient update.

I have a loop that goes through every example and computes the gradient of the loss on that example, which I then use in my custom update. This loop is much slower than evaluating the batch gradient on all the examples at once, even though the batch computation presumably also involves evaluating all of the individual gradients. Example code is below.

import Flux
import Random

function main()
    L = 10
    net = Flux.Chain(Flux.Dense(L, L, Flux.relu),
                     Flux.Dense(L, 1))

    N_examples = 1000
    examples = Random.randn(Float64, L, N_examples)
    labels = Random.randn(Float64, 1, N_examples)

    loss(x, y) = sum((net(x) .- y).^2) / length(y)

    params = Flux.params(net)
    println("Batch gradient is fast")
    # Implicit-parameter gradient with respect to all weights and biases of `net`
    @time grad = Flux.Tracker.gradient(() -> loss(examples, labels), params)

    println("Gradient of each example is slow")
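    @time for i in 1:N_examples
        # One Tracker call per example; each call copies a column and records a fresh tape
        grad = Flux.Tracker.gradient(() -> loss(examples[:, i], labels[i]), params)
    end
end

main()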

Example output:

Batch gradient is fast
0.006367 seconds (632 allocations: 531.516 KiB)
Gradient of each example is slow
0.166312 seconds (505.22 k allocations: 15.944 MiB, 3.84% gc time)

Is there any way to make my loop as fast as the batch evaluation?

You can land somewhere in between by using mini-batches.
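One way to land in between, sketched under loud assumptions: the hypothetical helpers `linear_grad` and `minibatch_grads` below replace the Flux network and the Tracker call with a linear model `W * x` whose squared-error gradient is written out by hand, so the block runs in plain Julia. Each mini-batch is a `view` of a block of columns, and its gradient is one matrix product rather than a loop over examples.

```julia
# Sketch: mini-batch gradients for a linear least-squares model.
# ASSUMPTION: a linear model W*x with a hand-derived gradient stands in for
# the Flux network and the Tracker AD call from the question.
using Random

# Gradient of sum((W*X .- Y).^2) / size(Y, 2) with respect to W.
function linear_grad(W, X, Y)
    R = W * X .- Y                        # residuals, one column per example
    return 2 .* (R * X') ./ size(Y, 2)    # one matrix product covers the whole batch
end

# One gradient per mini-batch; each batch is a view (no copy) of a block of columns.
function minibatch_grads(W, X, Y; batchsize = 100)
    grads = Matrix{Float64}[]
    for idx in Iterators.partition(1:size(X, 2), batchsize)
        push!(grads, linear_grad(W, view(X, :, idx), view(Y, :, idx)))
    end
    return grads
end

Random.seed!(0)
L, N = 10, 1000                           # sizes from the question
X = randn(L, N)                           # stands in for `examples`
Y = randn(1, N)                           # stands in for `labels`
W = randn(1, L)                           # hypothetical weights
gs = minibatch_grads(W, X, Y; batchsize = 100)
avg = sum(gs) / length(gs)                # average of the mini-batch gradients
println(avg ≈ linear_grad(W, X, Y))
```

With equal-sized batches, the average of the mini-batch gradients equals the full-batch gradient; the trade-off is that you get one gradient per batch rather than one per example.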

You probably can't reach the speed of the full-batch version, since it can exploit highly optimized matrix multiplication and SIMD to a greater extent.
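To make the matrix-multiply point concrete: with `loss` as defined in the question, the batch loss is the mean of the per-example losses, so the batch gradient is the mean of the per-example gradients. For the first `Dense` layer, each per-example weight gradient is an outer product δᵢxᵢᵀ, where δᵢ is the error backpropagated to that layer's pre-activation. A batched backward pass can therefore form the whole sum as one product Δ Xᵀ instead of N separate outer products. This sketch checks that equivalence; `grad_loop` and `grad_batched` are hypothetical names, and Δ is filled with random stand-in values rather than real backpropagated errors. It covers only the linear-algebra part, not the AD tracker overhead.

```julia
# Sketch: sum of per-example outer products vs one batched matrix product.
# ASSUMPTION: Δ holds random stand-in values for the backpropagated errors
# δ_i (one column per example); a real backward pass would compute them.
using Random

# Per-example route: accumulate δ_i * x_i' one example at a time.
function grad_loop(Δ, X)
    G = zeros(size(Δ, 1), size(X, 1))
    for i in axes(X, 2)
        G .+= Δ[:, i] * X[:, i]'          # outer product for example i (slices and result allocate)
    end
    return G ./ size(X, 2)
end

# Batched route: the same sum as a single matrix-matrix product.
grad_batched(Δ, X) = (Δ * X') ./ size(X, 2)

Random.seed!(1)
L, N = 10, 1000                           # sizes from the question
X = randn(L, N)                           # stands in for `examples`
Δ = randn(L, N)                           # stand-in backpropagated errors
@assert grad_loop(Δ, X) ≈ grad_batched(Δ, X)

grad_loop(Δ, X); grad_batched(Δ, X)       # run once so timings exclude compilation
println("loop:    ", @elapsed(grad_loop(Δ, X)), " s")
println("batched: ", @elapsed(grad_batched(Δ, X)), " s")
```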

I can’t use mini-batches for what I’m doing. I need the individual gradients. Do you have any suggestions for speeding up the loop itself?

You can try to reduce the allocations caused by slicing (`examples[:, i]` copies each column), but you'll never escape the overhead of the AD tracker. You can also try Zygote.jl instead of Tracker; check out the Flux#zygote branch.
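A minimal plain-Julia sketch of the slicing point, using hypothetical helpers `sum_sq_slices` and `sum_sq_views` rather than the Flux code: `X[:, i]` allocates a new vector on every iteration, while `view(X, :, i)` wraps the existing column without copying. Exact byte counts depend on the Julia version, and removing these allocations does nothing about the tracker overhead mentioned above.

```julia
# Sketch: X[:, i] copies each column; view(X, :, i) wraps it without copying.
# `sum_sq_slices` / `sum_sq_views` are hypothetical helpers for illustration.
function sum_sq_slices(X)
    s = 0.0
    for i in axes(X, 2)
        s += sum(abs2, X[:, i])           # allocates a new Vector per column
    end
    return s
end

function sum_sq_views(X)
    s = 0.0
    for i in axes(X, 2)
        s += sum(abs2, view(X, :, i))     # no column copy
    end
    return s
end

X = randn(10, 1000)                       # same shape as `examples` in the question
sum_sq_slices(X); sum_sq_views(X)         # compile first so @allocated measures the loops only
a_slice = @allocated sum_sq_slices(X)
a_view = @allocated sum_sq_views(X)
println("slices: ", a_slice, " bytes, views: ", a_view, " bytes")
```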