Automatic gradient ∼10x slower to evaluate than the primal computation

I need to evaluate a basic dense neural network with a few layers on up to tens of thousands of inputs at once. The outputs are then forwarded to another function, which returns a scalar. However, the performance of the automatic gradient is lacking even in the following simple example, where mapall merely returns the sum of the squared neural network outputs:

using Random
using Flux
using ChainRulesCore
using LinearAlgebra
using BenchmarkTools

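# Field names and order follow the call operator and the constructor calls below.
struct NeuralNet{F, W<:AbstractMatrix, B}
    weight::W
    bias::B
    σ::F
end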

Flux.@functor NeuralNet
Flux.trainable(a::NeuralNet) = (weight = a.weight, bias = a.bias,)

function (a::NeuralNet)(x::AbstractVecOrMat)
    W, b, σ = a.weight, a.bias, a.σ
    return σ.(W*x .+ b)
end

# Evaluate nn on each scalar input separately, then sum the squared outputs.
function mapall(nn,x)
    return sum(map(t->nn([t,])[1],x) .^ 2.0)
end


ra = range(-1,1,10000)

N1 = 1000
W1 = randn(N1,1)
b1 = randn(N1)
l1 = NeuralNet(W1,b1,identity)

N2 = 20
W2 = randn(N2,N1)
b2 = randn(N2)
l2 = NeuralNet(W2,b2,identity)

W3 = randn(1,N2)
b3 = randn(1)
l3 = NeuralNet(W3,b3,identity)

model = Chain(l1,l2,l3)

@btime mapall(model,ra)
@btime Flux.gradient(m->mapall(m,ra),model)

Evaluating mapall takes approximately 90 milliseconds (which is very good), whereas its gradient with respect to the weights and biases of the neural network layers takes over 1.1 seconds. By comparison, a single forward pass,

@btime model([ra[1]])

takes about 8 microseconds to evaluate, while

@btime Flux.gradient(m->m([ra[1]])[1],model)

takes approximately 36 microseconds, only about four times longer than the forward pass. Is there any hope of bringing the gradient of mapall down from roughly 10x the cost of the primal to something closer to 4x, or even better? My understanding is that the gradient usually costs 2-3x as much as the primal pass in models like this.

The biggest thing missing here is batched input. Instead of 10_000 separate evaluations (each with its own gradient, which the rule for map then accumulates), you can run just one, on a 1×10000 matrix whose columns are the inputs.

# As above:

julia> @btime mapall($model, $ra)
  min 872.443 ms, mean 1.088 s (70004 allocations, 161.29 MiB)

julia> @btime Flux.gradient(m->mapall(m,ra),model)
  min 21.929 s, mean 21.929 s (250086 allocations, 3.47 GiB)
((layers = ((weight = [-2.540862078993e6; -539730.7332964173; … ;

# Batched input:

julia> x2 = collect(ra')
1×10000 Matrix{Float64}:
 -1.0  -0.9998  -0.9996  -0.9994  -0.9992  -0.999  …  0.999  0.9992  0.9994  0.9996  0.9998  1.0

julia> m2 = Chain(Dense(W1,b1), Dense(W2,b2), Dense(W3,b3));

julia> @btime sum(abs2, $m2($x2))
  min 493.252 ms, mean 494.542 ms (12 allocations, 155.79 MiB)

julia> @btime gradient(m -> sum(abs2, m($x2)), m2)
  min 767.347 ms, mean 775.584 ms (66 allocations, 233.94 MiB)
((layers = ((weight = [-2.5408620789929945e6; -539730.733296421; …

# One more optimisation pending in Flux:

julia> (a::NeuralNet{typeof(identity)})(x::AbstractVecOrMat) = muladd(a.weight, x, a.bias)

julia> @btime sum(abs2, $model($x2));
  min 258.594 ms, mean 261.037 ms (6 allocations, 77.90 MiB)

julia> @btime gradient(m -> sum(abs2, m($x2)), $model);
  min 567.586 ms, mean 573.953 ms (42 allocations, 156.04 MiB)
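
To check that batching computes the same quantity as mapall, here is a minimal sketch that needs no Flux. The helper `affine`, the function `net`, and the small layer sizes are assumptions made up for illustration; they only mimic the identity-activation layers from the question.

```julia
# Hypothetical stand-in for one layer from the question: a plain affine map.
affine(W, b, x) = W * x .+ b

# Small illustrative sizes (assumptions, not the sizes used in the question).
W1, b1 = randn(5, 1), randn(5)
W2, b2 = randn(3, 5), randn(3)
W3, b3 = randn(1, 3), randn(1)

# Three stacked affine layers, playing the role of Chain(l1, l2, l3).
net(x) = affine(W3, b3, affine(W2, b2, affine(W1, b1, x)))

xs = range(-1, 1, length=100)

# One call per sample, as in mapall: each input is wrapped in a 1-element vector.
per_sample = sum(map(t -> net([t])[1], xs) .^ 2)

# One call for all samples: the inputs form the columns of a 1×100 matrix.
X = collect(xs')
batched = sum(abs2, net(X))

# Compare the two results up to floating-point rounding.
println(per_sample ≈ batched)
```

Because the matrix product treats every column of X independently, the two sums are expected to agree up to rounding. The batched form simply replaces many small matrix-vector products with a few matrix-matrix products.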

Many thanks! As always, the Julia community found the solution quickly.