Speeding up gradient of logpdf

Somehow, I seem to be unable to speed up gradient calculations of a simple logpdf. As a minimal working example, consider the following computation:

using Distributions
import AbstractDifferentiation as AD
import ForwardDiff
import ReverseDiff

N = 1000
c = rand(Poisson(1.2), N)
λ₀ = rand(N)

function loss(c, λ)
    sum(logpdf.(Poisson.(λ), c))
end

# Same, but non-allocating
function loss2(c, λ)
    sum(((c, λ),) -> logpdf(Poisson(λ), c), zip(c, λ))
end

function adtest(ad, loss, c, λ)
    AD.gradient(ad, Base.Fix1(loss, c), λ)
end

Benchmarking gives

julia> using BenchmarkTools

julia> @btime loss2($c, $λ₀);
  73.047 μs (0 allocations: 0 bytes)

julia> @btime adtest($(AD.ReverseDiffBackend()), loss2, $c, $λ₀);
  1.082 ms (24012 allocations: 968.33 KiB)

with loss2 slightly faster than loss, and with ForwardDiff or Zygote not competitive at N = 1000. Scaling with N seems as expected, but there is a huge gap between the direct and the gradient calculation: ReverseDiff is about 15-30 times slower than computing the loss, depending on N.
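For reference, the ForwardDiff and Zygote comparisons can be run through the same adtest helper; a sketch of the calls, assuming AbstractDifferentiation's ForwardDiffBackend and ZygoteBackend (which require ForwardDiff and Zygote to be loaded):

import Zygote

adtest(AD.ForwardDiffBackend(), loss2, c, λ₀)
adtest(AD.ZygoteBackend(), loss2, c, λ₀)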

Trying the same in JAX, its scaling seems almost unreal:

In [57]: %timeit fwd(c, lam, jit = True).block_until_ready()
27 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [58]: %timeit fwd(c, lam, jit = False).block_until_ready()
1.83 ms ± 603 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [59]: %timeit bwd(c, lam, jit = True).block_until_ready()
9.01 µs ± 25.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [60]: %timeit bwd(c, lam, jit = False).block_until_ready()
11.9 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

with the gradient calculation even faster than computing the loss when using jit!
Any ideas on how to speed this up a bit, in particular, closing the gap between direct and gradient calculations?

In any case, for completeness, here is the Python code:

import jax.numpy as jnp
import jax.random as random
import jax.scipy as jsp
from jax import jit, grad

N = 1000
key = random.PRNGKey(0)
key, subkey = random.split(key)
c = random.poisson(subkey, 1.2, shape = (N,))
key, subkey = random.split(key)
lam = random.uniform(key, shape = (N,))

def loss(c, lam):
    return jnp.sum(jsp.stats.poisson.logpmf(c, lam))

jloss = jit(loss)
def fwd(c, lam, jit=False):
    if jit:
        return jloss(c, lam)
    else:
        return loss(c, lam)
    
dloss = grad(loss, argnums = 1)
jdloss = jit(dloss)
def bwd(c, lam, jit = False):
    if jit:
        return jdloss(c, lam)
    else:
        return dloss(c, lam)

Hey there! A few remarks:

  • AbstractDifferentiation.jl introduces closures, which can slow things down a lot (the price to pay for generality). If you care about performance, right now I would recommend calling each AD backend directly (see the sketch after this list).
  • For gradients, don’t expect forward mode autodiff (like ForwardDiff.jl) to scale well with N. You need to bet on reverse mode anyway.
  • To close the gap I would either try Enzyme.jl or ReverseDiff.jl with a precomputed tape.
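For instance, a minimal sketch (reusing loss2 from above) of calling ReverseDiff.jl directly, both one-shot and with a precomputed tape; the tape assumes the input length stays fixed:

import ReverseDiff

# One-shot reverse-mode gradient, no AbstractDifferentiation wrapper:
ReverseDiff.gradient(Base.Fix1(loss2, c), λ₀)

# Precomputed (and compiled) tape, reusable across calls with same-sized inputs:
tape = ReverseDiff.compile(ReverseDiff.GradientTape(Base.Fix1(loss2, c), ones(N)))
dλ = similar(λ₀)
ReverseDiff.gradient!(dλ, tape, λ₀)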

Thanks, will experiment a bit:

  1. That’s a bit unfortunate as I like how easily different autodiff algorithms can be tried when using AbstractDifferentiation.
  2. Yes, reverse mode is the way to go when N gets large.
  3. Indeed, precompiling the tape gets it down to 5x slower … when using mini-batches I will need to make sure that all batches are the same size, right? What else do I need to keep fixed in order to get correct gradients with a precompiled tape?

Setup:

using BenchmarkTools
using Distributions
using Enzyme
using ReverseDiff

N = 1000;
c₀ = rand(Poisson(1.2), N);
λ₀ = rand(N);

function loss(c, λ)
    l = zero(promote_type(eltype(c), eltype(λ)))
    @inbounds @simd for i in eachindex(c, λ)
        l += logpdf(Poisson(λ[i]; check_args=false), c[i])
    end
    return l
end

f = Base.Fix1(loss, c₀)
tape = ReverseDiff.GradientTape(f, ones(N));

Results

julia> @btime f($λ₀);
  30.151 μs (1 allocation: 16 bytes)

julia> @btime ReverseDiff.gradient!(dλ₀, tape, $λ₀) setup = (dλ₀ = zeros($N));
  257.153 μs (0 allocations: 0 bytes)

julia> @btime Enzyme.autodiff(Reverse, f, Active, Duplicated(λ₀, dλ₀)) evals = 1 setup = (dλ₀ = zeros($N));
  142.595 μs (5 allocations: 144 bytes)

Great, will need to see if I can get Enzyme to work on my actual and somewhat larger code. Still impressive what JAX is doing.

My bad, the issue I quoted involved some bad benchmarking, and things might also have improved since then.
Still, that's part of the reason why @hill and I have been working on a very simple prototype for high-performance cross-backend gradients. Feel free to take a look and tell us what you think!

Yes, and in general avoid any control flow that depends on the precise values of the inputs you give the function. Basically, make sure that the same code path as in the tape is taken every time.
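To make that concrete, here is a small sketch with a made-up branchy function (not from this thread) showing what goes wrong when the recorded tape replays a branch the new input would not take:

using ReverseDiff

branchy(λ) = sum(λ) > 1 ? sum(abs2, λ) : sum(λ)

# Recording at ones(3) bakes the sum(abs2, λ) branch into the tape.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(branchy, ones(3)))

dλ = zeros(3)
ReverseDiff.gradient!(dλ, tape, fill(0.1, 3))
# dλ == [0.2, 0.2, 0.2], the gradient of sum(abs2, λ), but the true gradient at this
# input is [1.0, 1.0, 1.0] since sum(λ) = 0.3 would take the other branch.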

Thanks, will have a look.
Seems like I found another slight edge:

loss5(c, λ) = mapreduce(ReverseDiff.@forward((c,λ) -> logpdf(Poisson(λ), c)), +, c, λ)
f5 = Base.Fix1(loss5, c₀)
tape5 = ReverseDiff.GradientTape(f5, ones(N));
# Running your loss function and gradient for comparison
julia> @btime f($λ₀);
  37.671 μs (1 allocation: 16 bytes)

julia> @btime ReverseDiff.gradient!(dλ₀, $tape, $λ₀) setup = (dλ₀ = zeros($N));
  264.666 μs (0 allocations: 0 bytes)

julia> @btime f5($λ₀);
  39.995 μs (4 allocations: 8.00 KiB)

julia> @btime ReverseDiff.gradient!(dλ₀, $tape5, $λ₀) setup = (dλ₀ = zeros($N));
  158.477 μs (0 allocations: 0 bytes)

From what I understand, ChainRules has similar rules, but I could not (yet) figure out how to use them with ReverseDiff. Further, @forward seems to be limited to at most two arguments, and my actual code is slightly more complicated …


Here the function is so simple that you shouldn't need to define custom rules. Those are mostly useful when manual specification gives a clear performance boost, or when the derivative computation fails (e.g. due to mutation).

Sure, but that’s not what I was after. Basically, when using ReverseDiff on sum(f.(x)) or mapreduce(f, +, x) it ends up with lots of scalar operations on the tape. Instead, using forward-mode for the applications of f only stores a single sum or mapreduce call on the tape and handles the scalar calls of f much better (that’s why the @forward declaration gave a nice speedup in my example).
The ChainRules.jl docs have a similar example in the section "Writing rules that call back into AD", and I was wondering whether such rules already exist for mapreduce, or even better for sum(broadcasted, ...), and if so, how to tell ReverseDiff to make use of them?

I suppose there is indeed a rule that takes care of broadcasting, but I’m not 100% sure. In any case you can tell ReverseDiff.jl to use ChainRules.jl with a macro:

However ReverseDiff.jl only accepts one argument, so it won’t save you if you need to manage several. Maybe ComponentArrays.jl + the @forward macro would be the solution?
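For reference, a hedged sketch of that route, assuming the macro meant above is ReverseDiff.@grad_from_chainrules (the scalar helper poisson_ll below is made up for illustration, and whether the registered rule actually gets used inside a broadcast is a separate question):

using ChainRulesCore
using Distributions
using ReverseDiff
using ReverseDiff: TrackedReal

poisson_ll(c, λ) = logpdf(Poisson(λ; check_args=false), c)

function ChainRulesCore.rrule(::typeof(poisson_ll), c, λ)
    y = poisson_ll(c, λ)
    # ∂/∂λ [c*log(λ) - λ - log(c!)] = c/λ - 1; the count c gets no tangent
    poisson_ll_pullback(ȳ) = (NoTangent(), NoTangent(), ȳ * (c / λ - 1))
    return y, poisson_ll_pullback
end

# Register the rrule so ReverseDiff records a single poisson_ll node for tracked λ:
ReverseDiff.@grad_from_chainrules poisson_ll(c::Integer, λ::TrackedReal)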

ReverseDiff will selectively use ForwardDiff for certain operations where it can be more efficient. Broadcasting is one of them, so we can make use of that in the loss function:

poiss_logpdf(c, λ) = logpdf(Poisson(λ; check_args=false), c)

function loss_bc(c, λ)
    l = sum(poiss_logpdf.(c, λ))
    return l
end

Although this does incur some overhead, it makes ReverseDiff gradients faster.

My timings with loss from https://discourse.julialang.org/t/speeding-up-gradient-of-logpdf/109708/4:

  18.848 μs (0 allocations: 0 bytes) # forward
  139.533 μs (0 allocations: 0 bytes) # ReverseDiff gradient
  52.640 μs (0 allocations: 0 bytes) # Enzyme gradient

And with loss_bc:

  19.942 μs (1 allocation: 7.94 KiB)
  59.501 μs (0 allocations: 0 bytes)
  81.432 μs (2 allocations: 15.88 KiB)

Checking the tape for loss_bc shows two operations (broadcast + sum) as expected.

Thanks, that’s what I was hoping for. Still a bit tricky to end up on the right code path, though, i.e.,

function loss_bc1(c, λ)
    sum(logpdf.(Poisson.(λ), c))
end
function loss_bc2(c, λ)
    sum(poiss_logpdf.(c, λ))
end

makes a huge difference:

julia> tape_bc1 = ReverseDiff.GradientTape(λ -> loss_bc1(c₀, λ), ones(N));

julia> tape_bc2 = ReverseDiff.GradientTape(λ -> loss_bc2(c₀, λ), ones(N));

julia> @btime ReverseDiff.gradient!(dλ₀, $tape_bc1, $λ₀) setup = (dλ₀ = zeros($N));
  4.315 ms (16489 allocations: 523.27 KiB)

julia> @btime ReverseDiff.gradient!(dλ₀, $tape_bc2, $λ₀) setup = (dλ₀ = zeros($N));
  99.035 μs (0 allocations: 0 bytes)

Yet another example of a function barrier?

Not quite, but a similar idea. The choice to pull poiss_logpdf into its own function was intentional. It avoids the issue described in Performance regression for BernoulliLogit · Issue #1934 · TuringLang/Turing.jl · GitHub, where broadcasting a type’s constructor (without its type parameters, so Julia sees a UnionAll rather than a fully parameterized type) can be extremely slow with ADs such as ReverseDiff.
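A quick way to see the distinction (a sketch, not from the linked issue):

using Distributions

Poisson isa UnionAll            # true: the bare constructor still has a free type parameter
Poisson{Float64} isa UnionAll   # false: fully parameterized, concrete type
poiss_logpdf isa Function       # an ordinary function wrapping the concrete call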

Thanks, this helps a lot. Now I can also understand all the red parts in the profile flame graphs much better and hopefully fix some more of them in my actual code.
Also Zygote is now very fast on your function:

julia> @btime Zygote.gradient(λ -> loss_bc2($c₀, λ), $λ₀);
  71.454 μs (48 allocations: 48.98 KiB)

The following repo contains all the logpdfs you may want for this trick:


On current Enzyme#main I get the following timings (rerunning all of the above on the same system to more easily compare):

The critical thing for performance on Enzyme just now was marking logabsgamma as having correct memory semantics/custom rule. There are also several internal buffers Enzyme creates that can (and should) be optimized away, which should further improve performance, but Enzyme computing the derivative as fast as the primal seems a good enough start for now.

julia> using Enzyme

julia> using BenchmarkTools

julia> using Distributions

julia> using Enzyme

julia> using ReverseDiff

julia> N = 1000;

julia> c₀ = rand(Poisson(1.2), N);

julia> λ₀ = rand(N);

julia> function loss(c, λ)
           l = zero(promote_type(eltype(c), eltype(λ)))
           @inbounds @simd for i in eachindex(c, λ)
               l += logpdf(Poisson(λ[i]; check_args=false), c[i])
           end
           return l
       end
loss (generic function with 1 method)

julia> f = Base.Fix1(loss, c₀)
(::Base.Fix1{typeof(loss), Vector{Int64}}) (generic function with 1 method)

julia> tape = ReverseDiff.GradientTape(f, ones(N));

julia> @btime f($λ₀);
  25.021 μs (1 allocation: 16 bytes)

julia> @btime ReverseDiff.gradient!(dλ₀, tape, $λ₀) setup = (dλ₀ = zeros($N));
  193.504 μs (0 allocations: 0 bytes)

julia> @btime Enzyme.autodiff(Reverse, f, Active, Duplicated($λ₀, dλ₀)) evals = 1 setup = (dλ₀ = zeros($N));
  25.530 μs (5 allocations: 144 bytes)

julia> @btime Enzyme.autodiff(Reverse, loss, Active, Const($c₀), Duplicated($λ₀, dλ₀)) evals = 1 setup = (dλ₀ = zeros($N));
  25.110 μs (0 allocations: 0 bytes)

julia> function loss_bc(c, λ)
           l = sum(poiss_logpdf.(c, λ))
           return l
       end
loss_bc (generic function with 1 method)

julia> f_bc = Base.Fix1(loss_bc, c₀)
(::Base.Fix1{typeof(loss_bc), Vector{Int64}}) (generic function with 1 method)

julia> tape_bc = ReverseDiff.GradientTape(f_bc, ones(N));

julia> @btime ReverseDiff.gradient!(dλ₀, tape, $λ₀) setup = (dλ₀ = zeros($N));
  131.613 μs (0 allocations: 0 bytes)
julia> function loss_bc1(c, λ)
           sum(logpdf.(Poisson.(λ), c))
       end
loss_bc1 (generic function with 1 method)

julia> function loss_bc2(c, λ)
           sum(poiss_logpdf.(c, λ))
       end
loss_bc2 (generic function with 1 method)

julia> tape_bc1 = ReverseDiff.GradientTape(λ -> loss_bc1(c₀, λ), ones(N));

julia> tape_bc2 = ReverseDiff.GradientTape(λ -> loss_bc2(c₀, λ), ones(N));

julia> @btime ReverseDiff.gradient!(dλ₀, $tape_bc1, $λ₀) setup = (dλ₀ = zeros($N));
  2.908 ms (16489 allocations: 523.27 KiB)

julia> @btime ReverseDiff.gradient!(dλ₀, $tape_bc2, $λ₀) setup = (dλ₀ = zeros($N));
  49.711 μs (0 allocations: 0 bytes)

julia> @btime Zygote.gradient(λ -> loss_bc2($c₀, λ), $λ₀);
  58.461 μs (49 allocations: 49.00 KiB)

The speedups are impressive! But I don’t understand what you mean by

The critical thing for performance on Enzyme just now was marking logabsgamma as having correct memory semantics/custom rule

Is this something you fixed in Enzyme yourself before rerunning? If not, what else is different on Enzyme#main?

The impactful change was Enzyme.jl/ext/EnzymeSpecialFunctionsExt.jl at 3db5fe49a20855cb78e6b7c9f825a4dbe631768e · EnzymeAD/Enzyme.jl · GitHub, which allows Enzyme to rewrite the calling convention Julia emits for the function in a way that makes calling and optimizing it much faster. In particular, logabsgamma had an sret (i.e. it returned multiple values via memory), which meant it could not be optimized effectively. The marking allowed our compiler/rewriting system to (correctly) rewrite it in a way that lets it be marked as not touching memory, potentially be deleted, not alias other allocations, etc.


I see that this is the first Enzyme extension; does that mean many packages would need similar tweaks?

So we use that Enzyme.jl-internal table as a way of specifying that a Julia function behaves like a function Enzyme already knows how to handle, and of marking it as read-only. For example, the table also contains sin.

The only thing that was necessary here was marking it as read-only, but while I was at it I also let it use the Enzyme-internal rule.

This is an alternate way, compared to Julia-level rules, of specifying the differential behavior of something. Specifically, Enzyme proper (in contrast to Enzyme.jl) has a table of functions it knows how to handle that are common across languages, so they do not need to be re-implemented (Enzyme/enzyme/Enzyme/InstructionDerivatives.td at main · EnzymeAD/Enzyme · GitHub). This is one such function.

This function does need a custom rule because it is implemented as an approximation of the true value, and differentiating an approximation is not the same as approximating the derivative: e.g. for mysin(x) = x - x^3/3!, the fourth derivative of mysin is identically zero, whereas the true fourth derivative of sin is sin itself. For this reason it should have a custom derivative.
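A small standalone illustration of that point (nothing Enzyme-specific):

mysin(x) = x - x^3/6         # truncated Taylor series for sin (3! = 6)
mysin_d4(x) = 0.0            # 4th derivative of the truncation is identically zero
true_d4(x) = sin(x)          # but the true 4th derivative of sin is sin itself
mysin_d4(0.5), true_d4(0.5)  # (0.0, 0.479...): the two disagree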

However, again, the custom derivative behavior here wasn’t the necessary part (since, if I recall correctly, this function isn’t actually differentiated with respect to: Enzyme’s activity analysis proved it wasn’t part of the differentiable graph). The only necessary part was marking it as read-none and using a faster calling convention than Julia normally emits. We already do that for all internal rules, so I just did that (and added the rule) while I was at it.
