I need to evaluate a basic dense, few-layer neural network on up to tens of thousands of inputs at once. The outputs are then forwarded to another function, which returns a scalar. However, the performance of the automatic gradient is lacking even in the following simple example, where the function mapall merely returns the sum of the squared neural-network outputs:
using Random
using Flux
using ChainRulesCore
using LinearAlgebra
using BenchmarkTools
# Minimal dense layer: y = σ.(W*x .+ b); accepts a vector (one sample)
# or a matrix with one sample per column.
struct NeuralNet{F, W<:AbstractMatrix, B}
    weight::W
    bias::B
    σ::F
end

Flux.@functor NeuralNet
Flux.trainable(a::NeuralNet) = (weight = a.weight, bias = a.bias)

function (a::NeuralNet)(x::AbstractVecOrMat)
    W, b, σ = a.weight, a.bias, a.σ
    return σ.(W*x .+ b)
end
# Applies the network to each scalar input separately (one 1-element
# vector per call) and sums the squared outputs. A batched alternative
# is sketched below.
function mapall(nn, x)
    return sum(map(t -> nn([t])[1], x) .^ 2)
end
ra = range(-1, 1, 10000)   # 10_000 scalar inputs

# Layer sizes: 1 → 1000 → 20 → 1, all with identity activation
N1 = 1000
W1 = randn(N1, 1)
b1 = randn(N1)
l1 = NeuralNet(W1, b1, identity)

N2 = 20
W2 = randn(N2, N1)
b2 = randn(N2)
l2 = NeuralNet(W2, b2, identity)

W3 = randn(1, N2)
b3 = randn(1)
l3 = NeuralNet(W3, b3, identity)

model = Chain(l1, l2, l3)
@btime mapall($model, $ra)
@btime Flux.gradient(m -> mapall(m, $ra), $model)
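Since each layer's forward pass `σ.(W*x .+ b)` already accepts a matrix with one sample per column, the same quantity could presumably also be computed in batched form, with one large matrix multiply per layer instead of 10,000 small ones. A minimal sketch (mapall_batched is a hypothetical helper, not something I have benchmarked):

# Sketch: batched equivalent of mapall. All inputs enter as a single
# 1×N matrix, so each layer does one big matrix multiply per call.
function mapall_batched(nn, x)
    X = reshape(collect(x), 1, :)   # 1×N matrix, one sample per column
    return sum(nn(X) .^ 2)
end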
With the per-sample mapall, evaluation takes approximately 90 milliseconds (which is very good), whereas its gradient with respect to the weights and biases of the network layers takes over 1.1 seconds. For instance,
@btime $model([$(ra[1])])
takes about 8 microseconds to evaluate, while
@btime Flux.gradient(m -> m([$(ra[1])])[1], $model)
takes approximately 36 microseconds, only about four times the cost of the forward pass. Is there any hope of improving the gradient of mapall from roughly 10x the primal cost down to 4x, or even better? My understanding is that a reverse-mode gradient usually costs only 2-3x the primal pass for models like this.
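For completeness, this is how I would compare the batched sketch from above against the per-sample version (I have not verified that it closes the gap):

@btime mapall_batched($model, $ra)
@btime Flux.gradient(m -> mapall_batched(m, $ra), $model)

But even if batching helps, I would still like to understand why the per-sample gradient is so much slower than its primal.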