Somehow, I seem to be unable to speed up gradient calculations of a simple logpdf. As a minimal working example, consider the following computation:
using Distributions
import AbstractDifferentiation as AD
import ForwardDiff
import ReverseDiff
N = 1000
c = rand(Poisson(1.2), N)
λ₀ = rand(N)
function loss(c, λ)
    sum(logpdf.(Poisson.(λ), c))
end
# Same, but non-allocating
function loss2(c, λ)
    sum(((c, λ),) -> logpdf(Poisson(λ), c), zip(c, λ))
end
function adtest(ad, loss, c, λ)
    AD.gradient(ad, Base.Fix1(loss, c), λ)
end
Benchmarking gives
julia> using BenchmarkTools
julia> @btime loss2($c, $λ₀);
73.047 μs (0 allocations: 0 bytes)
julia> @btime adtest($(AD.ReverseDiffBackend()), loss2, $c, $λ₀);
1.082 ms (24012 allocations: 968.33 KiB)
with loss2 slightly faster than loss, and ForwardDiff or Zygote not competitive at N = 1000. Scaling with N seems as expected, but there is a huge gap between direct and gradient calculations, with ReverseDiff being about 15-30 times slower than computing the loss, depending on N.
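For reference, ReverseDiff also supports prerecording and compiling a gradient tape, which should be applicable here since the control flow does not depend on λ. A sketch (tape, ctape and result are my names; this assumes Distributions' logpdf records cleanly onto the tape):
# record the computation once, compile the tape, then reuse it for every gradient call
tape = ReverseDiff.GradientTape(Base.Fix1(loss, c), λ₀)
ctape = ReverseDiff.compile(tape)
result = similar(λ₀)
ReverseDiff.gradient!(result, ctape, λ₀)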
Trying the same in JAX, its scaling seems almost unreal:
In [57]: %timeit fwd(c, lam, jit = True).block_until_ready()
27 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [58]: %timeit fwd(c, lam, jit = False).block_until_ready()
1.83 ms ± 603 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [59]: %timeit bwd(c, lam, jit = True).block_until_ready()
9.01 µs ± 25.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [60]: %timeit bwd(c, lam, jit = False).block_until_ready()
11.9 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
with the gradient calculation even faster than computing the loss when using jit!
Any ideas on how to speed this up, in particular how to close the gap between direct and gradient calculations?
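For context on how small the gap could in principle be: the gradient of the summed Poisson logpdf with respect to λ is analytic, since d/dλᵢ (cᵢ log λᵢ - λᵢ - log(cᵢ!)) = cᵢ/λᵢ - 1. A hand-written baseline is therefore just an elementwise expression; a sketch (manual_grad is my name for it):
# analytic gradient of loss(c, λ) with respect to λ
manual_grad(c, λ) = c ./ λ .- 1
manual_grad(c, λ₀)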
In any case, for completeness, here is the Python code:
import jax.numpy as jnp
import jax.random as random
import jax.scipy as jsp
from jax import jit, grad
N = 1000
key = random.PRNGKey(0)
key, subkey = random.split(key)
c = random.poisson(subkey, 1.2, shape = (N,))
key, subkey = random.split(key)
lam = random.uniform(key, shape = (N,))
def loss(c, lam):
    return jnp.sum(jsp.stats.poisson.logpmf(c, lam))
jloss = jit(loss)
def fwd(c, lam, jit=False):
    if jit:
        return jloss(c, lam)
    else:
        return loss(c, lam)
dloss = grad(loss, argnums = 1)
jdloss = jit(dloss)
def bwd(c, lam, jit = False):
    if jit:
        return jdloss(c, lam)
    else:
        return dloss(c, lam)