I’m almost unable to compute ant gradients (EDIT 9 hrs later: meant to say “any gradients”, but got gradients for ants) in Zygote. I initially attempted to optimize my model using around a million parameters, but it got stuck computing the first gradient. I simplified the model so now it had only 401108 parameters. Still couldn’t get a single gradient. Finally, I used just 100 data points, waited some more and got my gradient after about 6 minutes.
Six minutes to compute one gradient of a logistic regression?? Also the gradient is with respect to only half of the parameters.
Julia code
using Statistics
import Random, Zygote
const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}
logistic_sigmoid(x::Real) = 1 / (1 + exp(-x))
function loss(
Wmid::AM{<:Real}, Wctx::AM{<:Real},
tok_mid::AV{<:Integer}, tok_ctx::AV{<:Integer}, x::AV{Bool}
)
nll = -mean(
let
pij = @views logistic_sigmoid(Wmid[i, :]' * Wctx[j, :])
nll = Xij * log(pij) + (1 - Xij) * log(1 - pij)
end
for (i, j, Xij) in zip(tok_mid, tok_ctx, x)
)
@info "loss: $nll"
nll
end
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer)
_loss(Wmid::AV{<:Real}, Wctx::AV{<:Real}) = loss(
reshape(Wmid, nvocab, ndim), reshape(Wctx, nvocab, ndim),
tok_mid, tok_ctx, x
)
ntrues = nsamples ÷ 2
tok_mid = rand(rng, 1:nvocab, nsamples)
tok_ctx = rand(rng, 1:nvocab, nsamples)
x = [trues(ntrues); falses(nsamples - ntrues)]
weights_mid = 0.1randn(rng, Float32, nvocab * ndim)
weights_ctx = 0.1randn(rng, Float32, nvocab * ndim)
@info "Number of parameters:" size(weights_mid) size(weights_ctx) total=(length(weights_mid) + length(weights_ctx))
_loss(weights_mid, weights_ctx)
@timev _loss(weights_mid, weights_ctx)
Zygote.gradient(
Wmid -> _loss(Wmid, weights_ctx), weights_mid
)[1]
end
grad = @timev train(Random.Xoshiro(5), 100277, 100, 2)
@show size(grad)
@info "Gradient sum:" sum(grad)
Julia output
> julia --project test_grad.jl
┌ Info: Number of parameters:
│ size(weights_mid) = (200554,)
│ size(weights_ctx) = (200554,)
└ total = 401108
[ Info: loss: 0.6925831785105732
[ Info: loss: 0.6925831785105732
0.000092 seconds (213 allocations: 9.227 KiB)
elapsed time (ns): 91884.0
gc time (ns): 0
bytes allocated: 9448
pool allocs: 213
non-pool GC allocs: 0
minor collections: 0
full collections: 0
[ Info: loss: 0.6925831785105732
344.948917 seconds (42.56 M allocations: 2.558 GiB, 0.18% gc time, 99.93% compilation time)
elapsed time (ns): 3.44948917302e11
gc time (ns): 608519881
bytes allocated: 2746926240
pool allocs: 42536936
non-pool GC allocs: 967
malloc() calls: 24705
free() calls: 24872
minor collections: 99
full collections: 1
size(grad) = (200554,)
┌ Info: Gradient sum:
└ sum(grad) = -0.018946058381392624
Time: 344.948917 seconds
Python code
import time
import jax, jax.numpy as np, jax.random as rnd
def logistic_sigmoid(x):
return 1 / (1 + np.exp(-x))
@jax.jit
def loss(Wmid, Wctx, tok_mid, tok_ctx, x):
def nll(i, j, Xij):
wi, wj = Wmid[i, :], Wctx[j, :]
pij = logistic_sigmoid(wi.dot(wj))
return Xij * np.log(pij) + (1 - Xij) * np.log(1 - pij)
return -jax.vmap(nll)(tok_mid, tok_ctx, x).mean()
def train(rng, nvocab: int, nsamples: int, ndim: int):
def _loss(Wmid, Wctx):
return loss(
np.reshape(Wmid, (nvocab, ndim)),
np.reshape(Wctx, (nvocab, ndim)),
tok_mid, tok_ctx, x
)
ntrues = nsamples // 2
tok_mid = rnd.choice(rng, np.arange(nvocab), (nsamples, ))
tok_ctx = rnd.choice(rng, np.arange(nvocab), (nsamples, ))
x = np.concatenate([
np.ones(ntrues), np.zeros(nsamples - ntrues)
])
weights_mid = 0.1 * rnd.normal(rng, (nvocab * ndim, ))
weights_ctx = 0.1 * rnd.normal(rng, (nvocab * ndim, ))
print(
"Number of parameters:\n"
f"\t{weights_mid.shape=}\n",
f"\t{weights_ctx.shape=}\n",
f"\ttotal={weights_mid.size + weights_ctx.size}"
)
nll = _loss(weights_mid,weights_ctx,)
print("Loss:", nll)
dloss_Wmid = jax.value_and_grad(lambda Wmid: _loss(Wmid, weights_ctx))
dloss_Wctx = jax.value_and_grad(lambda Wctx: _loss(weights_mid, Wctx))
return dloss_Wmid(weights_mid)
train_jitted = jax.jit(lambda key: train(key, 100277, 100, 2))
tbegin = time.time()
value, grad = jax.block_until_ready(train_jitted(rnd.key(5)))
print(f"{grad.shape=}")
print("Done in", time.time() - tbegin, "seconds (with compilation)")
tbegin = time.time()
value, grad = jax.block_until_ready(train_jitted(rnd.key(5)))
print(f"{grad.shape=}")
print("Done in", time.time() - tbegin, "seconds (2nd run)")
print("Gradient sum:", grad.sum())
Python output
> python test_grad.py
Number of parameters:
weights_mid.shape=(200554,)
weights_ctx.shape=(200554,)
total=401108
Loss: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
grad.shape=(200554,)
Done in 0.4434378147125244 seconds (with compilation)
grad.shape=(200554,)
Done in 0.006359100341796875 seconds (2nd run)
Gradient sum: 0.0016649812
Time after compilation: 0.0064 seconds
Question
So Julia seems to be 344.948917 / 0.0064 == 53898
times slower than JAX. What am I doing wrong in my Julia code?
I tried ForwardDiff as well, it’s just as slow.