`Zygote.gradient` is 54000 TIMES slower than `jax.grad`

Benchmark time!

| Loss function | Autograd | 1st grad | Subsequent grads (mean) |
|---|---|---|---|
| loss_mcabbott | Zygote | 4.4 s | 2.6 ms |
| loss_mcabbott | Mooncake | 82.3 s | 4.2 ms |
| loss_yolhan_mannes | Zygote | 345.3 s | 78.5 ms |
| loss_yolhan_mannes | Mooncake | 63.3 s | 3.7 ms |
| loss_mcabbott_mean | Zygote | 4.2 s | 79.97 ms |
| loss_mcabbott_mean | Mooncake | 79.4 s | 4.2 ms |

I ran the entire code from scratch (`julia code.jl`) for each implementation of the loss function. Below, “startup time” means “time-to-first-gradient”, i.e. the “1st grad” column of the table.

Observations:

  • Mooncake’s startup times aren’t great.
  • Mooncake’s other timings are impressive.
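  • Accumulation like `nll += stuff` in loss_yolhan_mannes seems to hurt Zygote’s startup time a lot (345.3 s versus about 4 s for the other two losses); Mooncake’s startup time does not show the same penalty.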

Loss functions

"""
    loss_mcabbott_mean(Wmid, Wctx, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy over the samples, built with a comprehension.

Each sample pairs one row of `Wmid` (selected by `tok_mid`) with one row of `Wctx`
(selected by `tok_ctx`). The predicted probability is `logistic_sigmoid` of the dot
product of the two rows, and the matching entry of `x` is the Boolean label.
"""
function loss_mcabbott_mean(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	# Log-likelihood of a single labelled pair of rows.
	loglik(u, v, label) = begin
		p = logistic_sigmoid(dot(u, v))
		label * log(p) + (1 - label) * log(1 - p)
	end

	# Gather the selected rows; `@inbounds` skips the bounds checks on the gather.
	mid_rows = @inbounds eachrow(Wmid)[tok_mid]
	ctx_rows = @inbounds eachrow(Wctx)[tok_ctx]

	logliks = [loglik(u, v, label) for (u, v, label) in zip(mid_rows, ctx_rows, x)]
	return -mean(logliks)
end

"""
    loss_mcabbott(Wmid, Wctx, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy over the samples, written with array operations.

Rows of `Wmid` and `Wctx` are gathered with `tok_mid` and `tok_ctx`, multiplied
elementwise and summed along the second dimension to give one score per sample.
The score is passed through `logistic_sigmoid` and compared with the Boolean label
in `x`.
"""
function loss_mcabbott(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	# One dot product per sample, flattened from an N×1 matrix to a vector.
	rowdots = vec(sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2))
	# Kept as a single dotted expression so the whole broadcast stays fused.
	logliks =
		x .* log.(logistic_sigmoid.(rowdots)) .+
		(1 .- x) .* log.(1 .- logistic_sigmoid.(rowdots))
	return -mean(logliks)
end

"""
    loss_yolhan_mannes(Wmid, Wctx, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy over the samples, accumulated in an explicit loop.

For each index `c` of `x`, the dot product of row `tok_mid[c]` of `Wmid` and row
`tok_ctx[c]` of `Wctx` is passed through `logistic_sigmoid`; the Boolean label `x[c]`
selects which log term contributes. The accumulated log-likelihood is negated and
divided by `length(x)`, so the result has the same sign as `loss_mcabbott` and
`loss_mcabbott_mean`.
"""
function loss_yolhan_mannes(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	loglik = zero(T)
	# NOTE(review): `@inbounds` removes bounds checks, so this assumes `tok_mid` and
	# `tok_ctx` are indexable by `eachindex(x)` and hold valid row indices — confirm at callers.
	@inbounds @fastmath for c in eachindex(x)
		Xij = x[c]
		dotprod = dot(@view(Wmid[tok_mid[c], :]), @view(Wctx[tok_ctx[c], :]))
		sigm = logistic_sigmoid(dotprod)
		loglik += Xij * log(sigm) + (1 - Xij) * log(1 - sigm)
	end
	# The loop sums log-likelihoods; negate so a *loss* is returned (previously the
	# sign was missing, which flipped the sign of the gradient).
	return -loglik / length(x)
end
Full code
using Statistics, LinearAlgebra; import Random
using DifferentiationInterface
import Zygote, Mooncake

# Shorthand aliases for the abstract array types used in the loss-function signatures.
const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

"""
    logistic_sigmoid(x::Real)

Return the logistic function `1 / (1 + exp(-x))` evaluated at `x`.
"""
function logistic_sigmoid(x::Real)
	denom = 1 + exp(-x)
	return 1 / denom
end

"""
    loss_mcabbott_mean(Wmid, Wctx, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy over the samples, built with a comprehension.

Each sample pairs one row of `Wmid` (selected by `tok_mid`) with one row of `Wctx`
(selected by `tok_ctx`). The predicted probability is `logistic_sigmoid` of the dot
product of the two rows, and the matching entry of `x` is the Boolean label.
"""
function loss_mcabbott_mean(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	# Log-likelihood of a single labelled pair of rows.
	loglik(u, v, label) = begin
		p = logistic_sigmoid(dot(u, v))
		label * log(p) + (1 - label) * log(1 - p)
	end

	# Gather the selected rows; `@inbounds` skips the bounds checks on the gather.
	mid_rows = @inbounds eachrow(Wmid)[tok_mid]
	ctx_rows = @inbounds eachrow(Wctx)[tok_ctx]

	logliks = [loglik(u, v, label) for (u, v, label) in zip(mid_rows, ctx_rows, x)]
	return -mean(logliks)
end

"""
    loss_mcabbott(Wmid, Wctx, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy over the samples, written with array operations.

Rows of `Wmid` and `Wctx` are gathered with `tok_mid` and `tok_ctx`, multiplied
elementwise and summed along the second dimension to give one score per sample.
The score is passed through `logistic_sigmoid` and compared with the Boolean label
in `x`.
"""
function loss_mcabbott(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	# One dot product per sample, flattened from an N×1 matrix to a vector.
	rowdots = vec(sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2))
	# Kept as a single dotted expression so the whole broadcast stays fused.
	logliks =
		x .* log.(logistic_sigmoid.(rowdots)) .+
		(1 .- x) .* log.(1 .- logistic_sigmoid.(rowdots))
	return -mean(logliks)
end

"""
    loss_yolhan_mannes(Wmid, Wctx, tok_mid, tok_ctx, x)

Return the mean binary cross-entropy over the samples, accumulated in an explicit loop.

For each index `c` of `x`, the dot product of row `tok_mid[c]` of `Wmid` and row
`tok_ctx[c]` of `Wctx` is passed through `logistic_sigmoid`; the Boolean label `x[c]`
selects which log term contributes. The accumulated log-likelihood is negated and
divided by `length(x)`, so the result has the same sign as `loss_mcabbott` and
`loss_mcabbott_mean`.
"""
function loss_yolhan_mannes(
	Wmid::AM{T}, Wctx::AM{T},
	tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
	loglik = zero(T)
	# NOTE(review): `@inbounds` removes bounds checks, so this assumes `tok_mid` and
	# `tok_ctx` are indexable by `eachindex(x)` and hold valid row indices — confirm at callers.
	@inbounds @fastmath for c in eachindex(x)
		Xij = x[c]
		dotprod = dot(@view(Wmid[tok_mid[c], :]), @view(Wctx[tok_ctx[c], :]))
		sigm = logistic_sigmoid(dotprod)
		loglik += Xij * log(sigm) + (1 - Xij) * log(1 - sigm)
	end
	# The loop sums log-likelihoods; negate so a *loss* is returned (previously the
	# sign was missing, which flipped the sign of the gradient).
	return -loglik / length(x)
end

"""
    train(rng, loss, nvocab, nsamples, ndim, backend, quiet)

Generate random inputs from `rng` and return the gradient of `loss` with respect to
the first weight matrix.

Two token vectors of length `nsamples` are drawn from `1:nvocab`, the labels are
`nsamples ÷ 2` `true`s followed by `false`s, and two `nvocab × ndim` `Float32` weight
matrices are drawn from `randn` scaled by `0.1`. The gradient is written by `gradient!`
into a buffer shaped like the first weight matrix; every other argument of `loss` is
wrapped in `Constant`. When `quiet` is `false`, progress is logged with `@info` and the
gradient call is timed with `@timev`.
"""
function train(rng, loss, nvocab::Integer, nsamples::Integer, ndim::Integer, backend, quiet::Bool)
	# The order of the RNG draws is fixed: both token vectors first, then both weight matrices.
	vocab = 1:nvocab
	mid_tokens = rand(rng, vocab, nsamples)
	ctx_tokens = rand(rng, vocab, nsamples)

	npositive = nsamples ÷ 2
	labels = vcat(trues(npositive), falses(nsamples - npositive))

	# These two names appear verbatim as keys in the log record below.
	weights_mid = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
	weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
	if !quiet
		@info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))
	end

	grad_mid = similar(weights_mid)
	if quiet
		gradient!(
			loss,
			grad_mid, backend,
			weights_mid, Constant(weights_ctx),
			Constant(mid_tokens), Constant(ctx_tokens), Constant(labels)
		)
	else
		@info "Computing gradient..." backend
		@timev gradient!(
			loss,
			grad_mid, backend,
			weights_mid, Constant(weights_ctx),
			Constant(mid_tokens), Constant(ctx_tokens), Constant(labels)
		)
	end
	return grad_mid
end

using BenchmarkTools

"""
    main(loss)

Run `loss` through the Zygote backend and then the Mooncake backend.

For each backend, one verbose `train` call is made (logging and timing the gradient),
the size and sum of the resulting gradient are logged, and the mean of a `@benchmark`
over quiet `train` calls is shown. The benchmarked expression is the whole `train`
call, so it covers the random input generation as well as the gradient.
"""
function main(loss)
	backends = (("Zygote", AutoZygote()), ("Mooncake", AutoMooncake(config=nothing)))

	local estimate
	for (k, (label, backend)) in enumerate(backends)
		# Blank lines separate the sections, so nothing is printed before the first one.
		k == 1 || println("\n\n")
		@info "CODE: $loss; backend: $label"
		grad = train(Random.Xoshiro(5), loss, 100277, 100, 2, backend, false)
		@info "Gradient info:" size(grad) sum(grad)

		estimate = @show mean(
			@benchmark let
				train(Random.Xoshiro(5), $loss, 100277, 100, 2, $backend, true)
			end
		)
	end
	return estimate
end

# Benchmark a single loss implementation per run; change the argument to benchmark another.
main(loss_mcabbott_mean)
Full output
[ Info: CODE: loss_mcabbott; backend: Zygote
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
  4.433219 seconds (8.67 M allocations: 440.679 MiB, 1.83% gc time, 99.98% compilation time)
elapsed time (ns):  4.433218935e9
gc time (ns):       81026940
bytes allocated:    462085232
pool allocs:        8661685
non-pool GC allocs: 332
malloc() calls:     9061
free() calls:       7113
minor collections:  2
full collections:   0
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(2.613 ms)


[ Info: CODE: loss_mcabbott; backend: Mooncake
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoMooncake{Nothing}(nothing)
 82.335135 seconds (164.97 M allocations: 8.256 GiB, 2.64% gc time, 84.53% compilation time: <1% of which was recompilation)
elapsed time (ns):  8.2335135303e10
gc time (ns):       2172342058
bytes allocated:    8865191920
pool allocs:        164814503
non-pool GC allocs: 5073
malloc() calls:     155178
free() calls:       154675
minor collections:  37
full collections:   2
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(4.221 ms)


[ Info: CODE: loss_yolhan_mannes; backend: Zygote
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
345.317342 seconds (8.39 M allocations: 719.310 MiB, 0.04% gc time, 99.96% compilation time)
elapsed time (ns):  3.45317341779e11
gc time (ns):       152304804
bytes allocated:    754251240
pool allocs:        8381193
non-pool GC allocs: 381
malloc() calls:     6019
free() calls:       8507
minor collections:  27
full collections:   0
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = 0.018946057f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(78.547 ms)



[ Info: CODE: loss_yolhan_mannes; backend: Mooncake
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoMooncake{Nothing}(nothing)
 63.335114 seconds (127.02 M allocations: 6.374 GiB, 2.67% gc time, 82.60% compilation time: <1% of which was recompilation)
elapsed time (ns):  6.3335113863e10
gc time (ns):       1690957639
bytes allocated:    6844269736
pool allocs:        126886207
non-pool GC allocs: 3808
malloc() calls:     126373
free() calls:       123244
minor collections:  28
full collections:   2
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = 0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(3.659 ms)

[ Info: CODE: loss_mcabbott_mean; backend: Zygote
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoZygote()
  4.200216 seconds (9.54 M allocations: 483.351 MiB, 2.87% gc time, 95.91% compilation time)
elapsed time (ns):  4.2002158319999995e9
gc time (ns):       120481781
bytes allocated:    506830280
pool allocs:        9530497
non-pool GC allocs: 284
malloc() calls:     6166
free() calls:       8660
minor collections:  3
full collections:   0
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946057f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(79.968 ms)



[ Info: CODE: loss_mcabbott_mean; backend: Mooncake
┌ Info: Number of parameters:
│   size(weights_mid) = (100277, 2)
│   size(weights_ctx) = (100277, 2)
└   total = 401108
┌ Info: Computing gradient...
└   backend = AutoMooncake{Nothing}(nothing)
 79.360692 seconds (154.90 M allocations: 7.758 GiB, 2.89% gc time, 84.79% compilation time: <1% of which was recompilation)
elapsed time (ns):  7.9360692289e10
gc time (ns):       2289888095
bytes allocated:    8329889720
pool allocs:        154749525
non-pool GC allocs: 4811
malloc() calls:     146434
free() calls:       140564
minor collections:  34
full collections:   3
┌ Info: Gradient info:
│   size(grad) = (100277, 2)
└   sum(grad) = -0.018946059f0
mean(@benchmark(let
            train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
        end)) = TrialEstimate(4.223 ms)
4 Likes