Slightly modified code with Enzyme and DifferentiationInterface
using Statistics; import Random
using DifferentiationInterface
import Zygote, Enzyme
const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}
logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))
function loss(
    Wmid::AV{<:Real}, Wctx::AV{<:Real},
    nvocab::Integer, ndim::Integer,
    tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
)
    Wmid, Wctx = reshape(Wmid, nvocab, ndim), reshape(Wctx, nvocab, ndim)
    nll_one((i, j, Xij)) = begin
        pij = @views logistic_sigmoid(Wmid[i, :]' * Wctx[j, :])
        Xij * log(pij) + (1 - Xij) * log(1 - pij)
    end
    nll = -mean(nll_one.(zip(tok_mid, tok_ctx, x)))
    @info "loss: $nll"
    nll
end
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
    ntrues = nsamples ÷ 2
    tok_mid = rand(rng, 1:nvocab, nsamples)
    tok_ctx = rand(rng, 1:nvocab, nsamples)
    x = [trues(ntrues); falses(nsamples - ntrues)]
    weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
    @info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))
    dweights_mid = similar(weights_mid)
    @info "Computing gradient..." backend
    @timev gradient!(
        loss,
        dweights_mid, backend,
        weights_mid, Constant(weights_ctx), Constant(nvocab), Constant(ndim),
        Constant(tok_mid), Constant(tok_ctx), Constant(x)
    )
    dweights_mid
end
grad = train(Random.Xoshiro(5), 100277, 100, 2, AutoZygote())
@show size(grad)
@info "Gradient sum:" sum(grad)
println("\n\n2nd run!!")
grad = train(Random.Xoshiro(5), 100277, 100, 2, AutoZygote())
@show size(grad)
@info "Gradient sum:" sum(grad)
- I changed the line `weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)` to use a `Float32` literal, `0.1f0`. As pointed out by @yolhan_mannes, my original code used `0.1`, which switched all computations to `Float64` (see the first sketch after this list).
- I removed all closures except `nll_one` in `loss` and used `Constant` in the call to `gradient!` (see the second sketch after this list).
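To see why the literal matters, here is a small sketch (the array `w` is made up for illustration): multiplying a `Float32` array by the `Float64` literal `0.1` promotes the whole broadcast to `Float64`, while `0.1f0` keeps it in `Float32`.

```julia
w = randn(Float32, 4)

@show eltype(0.1 .* w)    # Float64: the Float64 literal promotes the broadcast result
@show eltype(0.1f0 .* w)  # Float32: the Float32 literal preserves the element type
```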
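And to illustrate the `Constant` wrapper in isolation, a minimal sketch (the toy function `f` and its inputs are hypothetical, and this assumes a DifferentiationInterface version that supports context arguments, as used in the code above): only the first argument is differentiated, while `Constant(a)` marks `a` as fixed data.

```julia
using DifferentiationInterface
import Zygote

# Toy objective: `x` is the differentiated argument, `a` is fixed data.
f(x, a) = sum(a .* x .^ 2)

x = Float32[1, 2, 3]
a = Float32[10, 20, 30]
grad = similar(x)

# `Constant(a)` tells DifferentiationInterface not to differentiate with respect to `a`,
# so the gradient is taken with respect to `x` only.
gradient!(f, grad, AutoZygote(), x, Constant(a))

@show grad  # expected: 2 .* a .* x
```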
Results
> julia --project test_grad.jl
┌ Info: Number of parameters:
│ size(weights_mid) = (200554,)
│ size(weights_ctx) = (200554,)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoZygote()
[ Info: loss: 0.69258314
231.698408 seconds (7.15 M allocations: 656.995 MiB, 0.15% gc time, 99.95% compilation time)
elapsed time (ns): 2.31698408198e11
gc time (ns): 352916979
bytes allocated: 688909032
pool allocs: 7142087
non-pool GC allocs: 346
malloc() calls: 4144
free() calls: 5145
minor collections: 42
full collections: 1
size(grad) = (200554,)
┌ Info: Gradient sum:
└ sum(grad) = -0.018946057f0
2nd run!!
┌ Info: Number of parameters:
│ size(weights_mid) = (200554,)
│ size(weights_ctx) = (200554,)
└ total = 401108
┌ Info: Computing gradient...
└ backend = AutoZygote()
[ Info: loss: 0.69258314
0.090889 seconds (5.10 k allocations: 305.902 MiB, 49.59% gc time)
elapsed time (ns): 9.0888894e7
gc time (ns): 45070100
bytes allocated: 320761112
pool allocs: 4598
non-pool GC allocs: 104
malloc() calls: 402
free() calls: 717
minor collections: 43
full collections: 0
size(grad) = (200554,)
┌ Info: Gradient sum:
└ sum(grad) = -0.018946057f0
- The first run takes 231.698408 seconds (99.95% compilation time).
- The second run takes 0.090889 seconds. This is about 2549 times faster than the first run, but still about 14 times slower than JAX (0.090889 / 0.0064 ≈ 14).
- Thus, the vast majority of the time is spent compiling something when evaluating the gradient for the first time.
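For repeated gradient evaluations within one session, part of the setup can be paid once via DifferentiationInterface's preparation mechanism. This is only a sketch under assumptions: it assumes a recent DifferentiationInterface version where `prepare_gradient` accepts the same `Constant` contexts as `gradient!`, and the toy function `f` and its inputs are hypothetical. Preparation amortizes the backend setup across calls; Julia still compiles the differentiated code on the first execution of a session.

```julia
using DifferentiationInterface
import Zygote

f(x, a) = sum(a .* x .^ 2)

x = Float32[1, 2, 3]
a = Float32[10, 20, 30]
grad = similar(x)
backend = AutoZygote()

# Pay the preparation cost once...
prep = prepare_gradient(f, backend, x, Constant(a))

# ...then reuse it for every subsequent gradient evaluation.
for _ in 1:3
    gradient!(f, grad, prep, backend, x, Constant(a))
end

@show grad  # expected: 2 .* a .* x
```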