I am training a neural network using a manual custom gradient update.
I have a loop that goes through every example, computes the gradient of the loss on that example, and then uses that gradient in my custom update. This loop is much slower than evaluating the batch gradient on all the examples at once, even though the batch computation presumably also evaluates all of the individual gradients. Example code is below.
import Flux
import Random

function main()
    L = 10
    net = Flux.Chain(Flux.Dense(L, L, Flux.relu),
                     Flux.Dense(L, 1))
    N_examples = 1000
    examples = Random.randn(Float64, L, N_examples)
    labels = Random.randn(Float64, 1, N_examples)
    loss(x, y) = sum((net(x) .- y).^2) / length(y)
    params = Flux.params(net)

    println("Batch gradient is fast")
    # One gradient over the whole batch of examples.
    @time grad = Flux.Tracker.gradient(() -> loss(examples, labels), params)

    println("Gradient of each example is slow")
    # One gradient per example.
    @time for i in 1:N_examples
        grad = Flux.Tracker.gradient(() -> loss(examples[:, i], labels[i]), params)
    end
end

main()
Example output:
Batch gradient is fast
0.006367 seconds (632 allocations: 531.516 KiB)
Gradient of each example is slow
0.166312 seconds (505.22 k allocations: 15.944 MiB, 3.84% gc time)
Is there any way to make my loop as fast as the batch evaluation?
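For context, this is roughly what my real loop looks like (simplified: η and the plain gradient-descent step are placeholders for my actual custom rule):

function train!(net, examples, labels; η = 0.01)
    loss(x, y) = sum((net(x) .- y).^2) / length(y)
    ps = Flux.params(net)
    for i in 1:size(examples, 2)
        # Per-example gradient with respect to all parameters.
        gs = Flux.Tracker.gradient(() -> loss(examples[:, i], labels[i]), ps)
        for p in ps
            # Custom update goes here; plain SGD shown as a stand-in.
            Flux.Tracker.update!(p, -η .* gs[p])
        end
    end
end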