`Zygote.gradient` is 54000 TIMES slower than `jax.grad`

Here is the slightly modified code with Enzyme and DifferentiationInterface:

```julia
using Statistics; import Random
using DifferentiationInterface
import Zygote, Enzyme

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))

function loss(
	Wmid::AV{<:Real}, Wctx::AV{<:Real},
	nvocab::Integer, ndim::Integer,
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
)
	Wmid, Wctx = reshape(Wmid, nvocab, ndim), reshape(Wctx, nvocab, ndim)

	nll_one((i, j, Xij)) = begin
		pij = @views logistic_sigmoid(Wmid[i, :]' * Wctx[j, :])
		Xij * log(pij) + (1 - Xij) * log(1 - pij)
	end

	nll = -mean(nll_one.(zip(tok_mid, tok_ctx, x)))
	@info "loss: $nll"
	nll
end

function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
	ntrues = nsamples ÷ 2

	tok_mid = rand(rng, 1:nvocab, nsamples)
	tok_ctx = rand(rng, 1:nvocab, nsamples)
	x = [trues(ntrues); falses(nsamples - ntrues)]

	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
	@info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))

	dweights_mid = similar(weights_mid)
	@info "Computing gradient..." backend
	@timev gradient!(
		loss,
		dweights_mid, backend,
		weights_mid, Constant(weights_ctx), Constant(nvocab), Constant(ndim),
		Constant(tok_mid), Constant(tok_ctx), Constant(x)
	)
	dweights_mid
end

grad = train(Random.Xoshiro(5), 100277, 100, 2, AutoZygote())
@show size(grad)
@info "Gradient sum:" sum(grad)

println("\n\n2nd run!!")
grad = train(Random.Xoshiro(5), 100277, 100, 2, AutoZygote())
@show size(grad)
@info "Gradient sum:" sum(grad)
  1. This line now uses the `Float32` literal `0.1f0`: `weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)`. As pointed out by @yolhan_mannes, my original code used `0.1`, which switched all computations to `Float64` (see the first sketch after this list).
  2. I removed all closures except `nll_one` inside `loss` and passed the non-differentiated arguments as `Constant` contexts in the call to `gradient!` (see the second sketch after this list).
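
A minimal sketch of the element-type issue from point 1 (the values here are made up for illustration): multiplying a `Float32` array by the `Float64` literal `0.1` promotes the result to `Float64`, whereas `0.1f0` keeps everything in `Float32`.

```julia
julia> eltype(0.1 .* randn(Float32, 3))    # Float64 literal promotes the result
Float64

julia> eltype(0.1f0 .* randn(Float32, 3))  # Float32 literal preserves Float32
Float32
```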
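
A minimal sketch of point 2, using a made-up function `f` and made-up arrays: wrapping an argument in `Constant` tells DifferentiationInterface to treat it as non-differentiated context, so nothing has to be captured in a closure.

```julia
using DifferentiationInterface
import Zygote

f(x, A) = sum(abs2, A * x)   # differentiate w.r.t. x only; A is held constant

x = randn(Float32, 3)
A = randn(Float32, 3, 3)

# Instead of `gradient(x -> f(x, A), AutoZygote(), x)`, pass A as a context:
grad = gradient(f, AutoZygote(), x, Constant(A))
```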

Results

```
> julia --project test_grad.jl
┌ Info: Number of parameters:
│   size(weights_mid) = (200554,)
│   size(weights_ctx) = (200554,)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
[ Info: loss: 0.69258314
231.698408 seconds (7.15 M allocations: 656.995 MiB, 0.15% gc time, 99.95% compilation time)
elapsed time (ns):  2.31698408198e11
gc time (ns):       352916979
bytes allocated:    688909032
pool allocs:        7142087
non-pool GC allocs: 346
malloc() calls:     4144
free() calls:       5145
minor collections:  42
full collections:   1
size(grad) = (200554,)
┌ Info: Gradient sum:
└   sum(grad) = -0.018946057f0


2nd run!!
┌ Info: Number of parameters:
│   size(weights_mid) = (200554,)
│   size(weights_ctx) = (200554,)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
[ Info: loss: 0.69258314
  0.090889 seconds (5.10 k allocations: 305.902 MiB, 49.59% gc time)
elapsed time (ns):  9.0888894e7
gc time (ns):       45070100
bytes allocated:    320761112
pool allocs:        4598
non-pool GC allocs: 104
malloc() calls:     402
free() calls:       717
minor collections:  43
full collections:   0
size(grad) = (200554,)
┌ Info: Gradient sum:
└   sum(grad) = -0.018946057f0
```
  1. The first run takes 231.698408 seconds (99.95% compilation time).
  2. The second run takes 0.090889 seconds. This is about 2549 times faster than the first run, but still 0.090889 / 0.0064 ≈ 14 times slower than the 0.0064-second JAX timing.
  3. Thus, the vast majority of the time is spent compiling the gradient code when it is evaluated for the first time; the sketch below shows one way to pay that cost on a small warm-up call instead.
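
A minimal sketch of that warm-up idea (the tiny problem sizes are made up): Julia specializes compiled code on argument types, not sizes, so a first call on a tiny problem with the same argument types compiles the gradient path, and a subsequent full-size call reports essentially pure runtime.

```julia
# Warm-up: tiny vocabulary and sample count, but the same argument types,
# so this call triggers (and caches) compilation of the Zygote gradient code.
train(Random.Xoshiro(0), 10, 10, 2, AutoZygote())

# The full-size call now measures runtime rather than compilation.
grad = train(Random.Xoshiro(5), 100277, 100, 2, AutoZygote())
```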