Benchmark time!
| Loss function | Autograd | 1st grad | Subsequent grads (mean) |
|---|---|---|---|
| `loss_mcabbott` | Zygote | 4.4 s | 2.6 ms |
| `loss_mcabbott` | Mooncake | 82.3 s | 4.2 ms |
| `loss_yolhan_mannes` | Zygote | 345.3 s | 78.5 ms |
| `loss_yolhan_mannes` | Mooncake | 63.3 s | 3.7 ms |
| `loss_mcabbott_mean` | Zygote | 4.2 s | 79.97 ms |
| `loss_mcabbott_mean` | Mooncake | 79.4 s | 4.2 ms |
I ran the entire code from scratch (`julia code.jl`) for each implementation of the loss function. Below, “startup time” means “time-to-first-gradient”.
Observations:

- Mooncake’s startup times aren’t great: time-to-first-gradient is over a minute for every loss implementation.
- Mooncake’s other timings are impressive: subsequent gradients take roughly 4 ms regardless of how the loss is written.
- Accumulation like `nll += ...` in `loss_yolhan_mannes` seems to hurt Zygote’s startup time a lot (345.3 s vs roughly 4 s for the other two losses); an untested accumulation-free variant is sketched after the loss functions below.
Loss functions
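All three implementations compute the same quantity: the mean binary cross-entropy of the labels `x` over the sampled word pairs,

$$
\mathcal{L} = -\frac{1}{N} \sum_{k=1}^{N} \left[ x_k \log \sigma\big(w^{\text{mid}}_{i_k} \cdot w^{\text{ctx}}_{j_k}\big) + (1 - x_k) \log\big(1 - \sigma\big(w^{\text{mid}}_{i_k} \cdot w^{\text{ctx}}_{j_k}\big)\big) \right],
$$

where $\sigma$ is the logistic sigmoid, $w^{\text{mid}}_i$ and $w^{\text{ctx}}_j$ are rows of `Wmid` and `Wctx`, and $i_k$, $j_k$ are the token indices `tok_mid[k]`, `tok_ctx[k]`. The only difference is the sign convention: `loss_yolhan_mannes` returns the mean log-likelihood without the leading minus, which is why its `sum(grad)` in the output below has the opposite sign to the other two.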
```julia
function loss_mcabbott_mean(
    Wmid::AM{T}, Wctx::AM{T},
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
    Wmidrows = @inbounds eachrow(Wmid)[tok_mid]
    Wctxrows = @inbounds eachrow(Wctx)[tok_ctx]
    -mean([
        let
            pij = logistic_sigmoid(dot(wi, wj))
            Xij * log(pij) + (1 - Xij) * log(1 - pij)
        end
        for (wi, wj, Xij) in zip(Wmidrows, Wctxrows, x)
    ])
end

function loss_mcabbott(
    Wmid::AM{T}, Wctx::AM{T},
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
    tmp = sum(
        Wmid[tok_mid, :] .* Wctx[tok_ctx, :], dims=2
    ) |> vec
    -mean(
        @. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp))
    )
end

function loss_yolhan_mannes(
    Wmid::AM{T}, Wctx::AM{T},
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
    nll = zero(T)
    @inbounds @fastmath for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid[tok_mid[c], :]), @view Wctx[tok_ctx[c], :])
        sigm = logistic_sigmoid(dotprod)
        nll += Xij * log(sigm) + (1 - Xij) * log(1 - sigm)
    end
    nll / length(x)
end
```
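Since the scalar accumulation seems to be what hurts Zygote’s compile time, here is a minimal, untested sketch of the same per-element computation with the `nll += ...` loop replaced by a single `sum` over the indices (the name `loss_yolhan_mannes_noaccum` is mine; I haven’t benchmarked whether this actually improves time-to-first-gradient):

```julia
# Hypothetical variant of loss_yolhan_mannes: same per-element computation,
# but reduced with `sum` over the indices instead of scalar accumulation.
function loss_yolhan_mannes_noaccum(
    Wmid::AM{T}, Wctx::AM{T},
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
    ll = sum(eachindex(x)) do c
        dotprod = dot(@view(Wmid[tok_mid[c], :]), @view(Wctx[tok_ctx[c], :]))
        sigm = logistic_sigmoid(dotprod)
        x[c] * log(sigm) + (1 - x[c]) * log(1 - sigm)
    end
    ll / length(x)
end
```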
Full code
```julia
using Statistics, LinearAlgebra; import Random
using DifferentiationInterface
import Zygote, Mooncake

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}

logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))

function loss_mcabbott_mean(
    Wmid::AM{T}, Wctx::AM{T},
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
    Wmidrows = @inbounds eachrow(Wmid)[tok_mid]
    Wctxrows = @inbounds eachrow(Wctx)[tok_ctx]
    -mean([
        let
            pij = logistic_sigmoid(dot(wi, wj))
            Xij * log(pij) + (1 - Xij) * log(1 - pij)
        end
        for (wi, wj, Xij) in zip(Wmidrows, Wctxrows, x)
    ])
end

function loss_mcabbott(
    Wmid::AM{T}, Wctx::AM{T},
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
    tmp = sum(
        Wmid[tok_mid, :] .* Wctx[tok_ctx, :], dims=2
    ) |> vec
    -mean(
        @. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp))
    )
end

function loss_yolhan_mannes(
    Wmid::AM{T}, Wctx::AM{T},
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
) where T<:Real
    nll = zero(T)
    @inbounds @fastmath for c in eachindex(x)
        Xij = x[c]
        dotprod = dot(@view(Wmid[tok_mid[c], :]), @view Wctx[tok_ctx[c], :])
        sigm = logistic_sigmoid(dotprod)
        nll += Xij * log(sigm) + (1 - Xij) * log(1 - sigm)
    end
    nll / length(x)
end

function train(rng, loss, nvocab::Integer, nsamples::Integer, ndim::Integer, backend, quiet::Bool)
    ntrues = nsamples ÷ 2
    tok_mid = rand(rng, 1:nvocab, nsamples)
    tok_ctx = rand(rng, 1:nvocab, nsamples)
    x = [trues(ntrues); falses(nsamples - ntrues)]

    weights_mid = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
    weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab, ndim)
    quiet || @info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))

    dweights_mid = similar(weights_mid)
    quiet || @info "Computing gradient..." backend
    if !quiet
        @timev gradient!(
            loss,
            dweights_mid, backend,
            weights_mid, Constant(weights_ctx),
            Constant(tok_mid), Constant(tok_ctx), Constant(x)
        )
    else
        gradient!(
            loss,
            dweights_mid, backend,
            weights_mid, Constant(weights_ctx),
            Constant(tok_mid), Constant(tok_ctx), Constant(x)
        )
    end
    dweights_mid
end

using BenchmarkTools

function main(loss)
    @info "CODE: $loss; backend: Zygote"
    backend = AutoZygote()
    grad = train(Random.Xoshiro(5), loss, 100277, 100, 2, backend, false)
    @info "Gradient info:" size(grad) sum(grad)
    @show mean(
        @benchmark let
            train(Random.Xoshiro(5), $loss, 100277, 100, 2, $backend, true)
        end
    )
    println("\n\n")

    @info "CODE: $loss; backend: Mooncake"
    backend = AutoMooncake(config=nothing)
    grad = train(Random.Xoshiro(5), loss, 100277, 100, 2, backend, false)
    @info "Gradient info:" size(grad) sum(grad)
    @show mean(
        @benchmark let
            train(Random.Xoshiro(5), $loss, 100277, 100, 2, $backend, true)
        end
    )
end

main(loss_mcabbott_mean)
```
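To reproduce the other table rows, re-run the script from scratch (a fresh Julia session each time, as described above) with the final call replaced:

```julia
main(loss_mcabbott)
# or
main(loss_yolhan_mannes)
```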
Full output
```
[ Info: CODE: loss_mcabbott; backend: Zygote
┌ Info: Number of parameters:
│ size(weights_mid) = (100277, 2)
│ size(weights_ctx) = (100277, 2)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoZygote()
4.433219 seconds (8.67 M allocations: 440.679 MiB, 1.83% gc time, 99.98% compilation time)
elapsed time (ns): 4.433218935e9
gc time (ns): 81026940
bytes allocated: 462085232
pool allocs: 8661685
non-pool GC allocs: 332
malloc() calls: 9061
free() calls: 7113
minor collections: 2
full collections: 0
┌ Info: Gradient info:
│ size(grad) = (100277, 2)
└ sum(grad) = -0.018946059f0
mean(@benchmark(let
train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
end)) = TrialEstimate(2.613 ms)
[ Info: CODE: loss_mcabbott; backend: Mooncake
┌ Info: Number of parameters:
│ size(weights_mid) = (100277, 2)
│ size(weights_ctx) = (100277, 2)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoMooncake{Nothing}(nothing)
82.335135 seconds (164.97 M allocations: 8.256 GiB, 2.64% gc time, 84.53% compilation time: <1% of which was recompilation)
elapsed time (ns): 8.2335135303e10
gc time (ns): 2172342058
bytes allocated: 8865191920
pool allocs: 164814503
non-pool GC allocs: 5073
malloc() calls: 155178
free() calls: 154675
minor collections: 37
full collections: 2
┌ Info: Gradient info:
│ size(grad) = (100277, 2)
└ sum(grad) = -0.018946059f0
mean(@benchmark(let
train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
end)) = TrialEstimate(4.221 ms)
[ Info: CODE: loss_yolhan_mannes; backend: Zygote
┌ Info: Number of parameters:
│ size(weights_mid) = (100277, 2)
│ size(weights_ctx) = (100277, 2)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoZygote()
345.317342 seconds (8.39 M allocations: 719.310 MiB, 0.04% gc time, 99.96% compilation time)
elapsed time (ns): 3.45317341779e11
gc time (ns): 152304804
bytes allocated: 754251240
pool allocs: 8381193
non-pool GC allocs: 381
malloc() calls: 6019
free() calls: 8507
minor collections: 27
full collections: 0
┌ Info: Gradient info:
│ size(grad) = (100277, 2)
└ sum(grad) = 0.018946057f0
mean(@benchmark(let
train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
end)) = TrialEstimate(78.547 ms)
[ Info: CODE: loss_yolhan_mannes; backend: Mooncake
┌ Info: Number of parameters:
│ size(weights_mid) = (100277, 2)
│ size(weights_ctx) = (100277, 2)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoMooncake{Nothing}(nothing)
63.335114 seconds (127.02 M allocations: 6.374 GiB, 2.67% gc time, 82.60% compilation time: <1% of which was recompilation)
elapsed time (ns): 6.3335113863e10
gc time (ns): 1690957639
bytes allocated: 6844269736
pool allocs: 126886207
non-pool GC allocs: 3808
malloc() calls: 126373
free() calls: 123244
minor collections: 28
full collections: 2
┌ Info: Gradient info:
│ size(grad) = (100277, 2)
└ sum(grad) = 0.018946059f0
mean(@benchmark(let
train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
end)) = TrialEstimate(3.659 ms)
[ Info: CODE: loss_mcabbott_mean; backend: Zygote
┌ Info: Number of parameters:
│ size(weights_mid) = (100277, 2)
│ size(weights_ctx) = (100277, 2)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoZygote()
4.200216 seconds (9.54 M allocations: 483.351 MiB, 2.87% gc time, 95.91% compilation time)
elapsed time (ns): 4.2002158319999995e9
gc time (ns): 120481781
bytes allocated: 506830280
pool allocs: 9530497
non-pool GC allocs: 284
malloc() calls: 6166
free() calls: 8660
minor collections: 3
full collections: 0
┌ Info: Gradient info:
│ size(grad) = (100277, 2)
└ sum(grad) = -0.018946057f0
mean(@benchmark(let
train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
end)) = TrialEstimate(79.968 ms)
[ Info: CODE: loss_mcabbott_mean; backend: Mooncake
┌ Info: Number of parameters:
│ size(weights_mid) = (100277, 2)
│ size(weights_ctx) = (100277, 2)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoMooncake{Nothing}(nothing)
79.360692 seconds (154.90 M allocations: 7.758 GiB, 2.89% gc time, 84.79% compilation time: <1% of which was recompilation)
elapsed time (ns): 7.9360692289e10
gc time (ns): 2289888095
bytes allocated: 8329889720
pool allocs: 154749525
non-pool GC allocs: 4811
malloc() calls: 146434
free() calls: 140564
minor collections: 34
full collections: 3
┌ Info: Gradient info:
│ size(grad) = (100277, 2)
└ sum(grad) = -0.018946059f0
mean(@benchmark(let
train(Random.Xoshiro(5), $(Expr(:$, :loss)), 100277, 100, 2, $(Expr(:$, :backend)), true)
end)) = TrialEstimate(4.223 ms)
```