Understanding the performance of Zygote

I remember reading somewhere that computing a derivative with automatic differentiation usually costs less than twice as much as evaluating the function itself.

I benchmarked a small piece of code, and the results were a little confusing:

```julia
using BenchmarkTools
using Zygote

function simple(x)
    return sin(x)
end

function complicated(x)
    return sin(x)^2*exp(x)*x^-2 + cos(x)^3
end

x = 1:0.01:10

@btime simple.($x);
@btime simple'.($x);

@btime complicated.($x);
@btime complicated'.($x);
```

Output, one line per `@btime` call in the order above:

```
  14.942 μs (1 allocation: 7.19 KiB)
  21.299 μs (1 allocation: 7.19 KiB)
  47.380 μs (1 allocation: 7.19 KiB)
  672.962 μs (14417 allocations: 1006.73 KiB)
```

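Why does differentiating the simple function cost only about 1.4× as much as evaluating it, while differentiating the complicated function costs about 14× as much?
Why does taking the derivative of the complicated function make so many allocations (14417, versus 1 for every other call)?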
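For context on the allocation count: `1:0.01:10` has 901 points, and 14417 = 16 × 901 + 1, so each scalar call to `complicated'` appears to make about 16 heap allocations on top of the one for the output array, plausibly for the pullback closures that reverse mode records. For scalar-in, scalar-out derivatives, forward mode (as implemented in ForwardDiff.jl) is often the cheaper tool, because it carries the derivative alongside the value and records nothing for a backward pass. The sketch below is only a toy illustration of that idea: the `Dual` type and the `derivative` function are hypothetical names, not ForwardDiff's API, and they cover just the operations `simple` and `complicated` use.

```julia
# Toy forward-mode AD with dual numbers (hypothetical `Dual`, not ForwardDiff.jl's API).
# Each Dual carries a value and its derivative with respect to the input.
struct Dual
    val::Float64
    der::Float64
end

# Derivative rules for the primitives used by `simple` and `complicated`.
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)  # product rule
Base.:^(a::Dual, n::Integer) = Dual(a.val^n, n * a.val^(n - 1) * a.der)        # power rule, also covers x^-2
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)
Base.cos(a::Dual) = Dual(cos(a.val), -sin(a.val) * a.der)
Base.exp(a::Dual) = (e = exp(a.val); Dual(e, e * a.der))

# Seed the input's derivative with 1 and read the derivative off the output.
derivative(f, x::Real) = f(Dual(x, 1.0)).der

simple(x) = sin(x)
complicated(x) = sin(x)^2 * exp(x) * x^-2 + cos(x)^3

x = 1:0.01:10
dsimple = derivative.(simple, x)        # equals cos.(x)
dcomplicated = derivative.(complicated, x)
```

Because `Dual` is an immutable struct of two `Float64`s, the compiler can usually keep it off the heap; timing `derivative.(complicated, $x)` with `@btime` next to the Zygote version would show whether that holds on a given machine.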



It might have to do with broadcasting. The pullback is faster when `complicated` is written to operate on the whole vector directly:

```julia
julia> using BenchmarkTools, Zygote

julia> complicated_wb(x) = sin.(x) .^ 2 .* exp.(x) .* x .^ (-2) .+ cos.(x) .^ 3;

julia> x = 1:0.01:10;

julia> @btime complicated_wb(x);
  104.899 μs (1 allocation: 7.19 KiB)

julia> @btime pullback(complicated_wb, x)[2](ones(length(x)))[1];
  331.299 μs (8230 allocations: 380.52 KiB)

julia> @btime complicated'.(x);  # `complicated` as defined in the original post
  759.201 μs (14420 allocations: 1006.92 KiB)
```
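Because `complicated_wb` acts elementwise, its Jacobian is diagonal, so seeding the pullback with `ones(length(x))` returns the vector of pointwise derivatives in a single reverse pass. The sketch below writes such a pullback by hand to show the mechanics. `complicated_wb_pullback` is a hypothetical name; it only mimics the `(y, back)` shape of what Zygote's `pullback` returns (with `back` giving a tuple of gradients), and the `dydx` formula is the product and chain rules applied by hand.

```julia
complicated_wb(x) = sin.(x) .^ 2 .* exp.(x) .* x .^ (-2) .+ cos.(x) .^ 3

# Hypothetical hand-written pullback for complicated_wb (mimics the (y, back) shape only).
function complicated_wb_pullback(x::AbstractVector)
    s, c, e = sin.(x), cos.(x), exp.(x)   # shared subexpressions, computed once
    y = s .^ 2 .* e ./ x .^ 2 .+ c .^ 3
    # d/dx[sin^2 * exp * x^-2] = exp * x^-2 * (2 sin cos + sin^2 - 2 sin^2 / x)
    # d/dx[cos^3]              = -3 cos^2 sin
    dydx = e ./ x .^ 2 .* (2 .* s .* c .+ s .^ 2 .- 2 .* s .^ 2 ./ x) .- 3 .* c .^ 2 .* s
    # The Jacobian is diagonal, so the vector-Jacobian product is elementwise.
    back(Δ) = (Δ .* dydx,)
    return y, back
end

x = 1:0.01:10
y, back = complicated_wb_pullback(x)
dx = back(ones(length(x)))[1]   # pointwise derivatives from one reverse pass
```

Zygote builds an equivalent closure automatically; the timings above suggest that building it once for the whole vector is cheaper than building one per scalar call, and a hand-written rule like this additionally avoids recomputing `sin`, `cos`, and `exp` in the backward pass.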