Forward- and reverse-mode AD comparisons with JAX

I have been trying to compute the gradients of a covariance function (kernel) in Julia. I was initially using Zygote, which was a poor choice since forward mode is much better suited to my task. However, this led to some interesting observations, particularly when comparing Zygote.jl/ForwardDiff.jl with JAX in Python.

Running the following with ForwardDiff gives good performance:

using ForwardDiff

function rbf_kernel(X::AbstractVector, X2::AbstractVector = X)
    r2 = (-2X * X2') + (X .+ X2')
    return exp.(-0.5 * r2)
end

function grad_kernel(kernel, x::AbstractVector)
    j = ForwardDiff.derivative(x_ -> kernel([x_], x), x[1])  # compute the Jacobian of k(x_1, x) w.r.t. x_1
    return j
end

function grad_kernel_map(kernel, x::AbstractVector)
    j = ForwardDiff.derivative.(x_ -> kernel([x_], x), x)  # compute the Jacobian of k(x_j, x) w.r.t. x_j for all j
    J = vcat(j...)  # j is a Vector of 1×length(x) matrices, so J is length(x)×length(x)
    return sum(J, dims=2)
end

X = collect(range(-1.0, 1.0, length=100));

print("kernel 1:"); @time rbf_kernel(X, X);
print("kernel 2:"); @time rbf_kernel(X, X);

print("kernel grad 1:"); @time grad_kernel(rbf_kernel, X);
print("kernel grad 2:"); @time grad_kernel(rbf_kernel, X);

print("kernel grad map 1:"); @time grad_kernel_map(rbf_kernel, X);
print("kernel grad map 2:"); @time grad_kernel_map(rbf_kernel, X);

kernel 1:  0.287593 seconds (1.50 M allocations: 84.883 MiB, 8.11% gc time, 99.73% compilation time)
kernel 2:  0.000191 seconds (11 allocations: 391.891 KiB)
kernel grad 1:  0.331328 seconds (1.75 M allocations: 98.633 MiB, 4.17% gc time, 99.81% compilation time)
kernel grad 2:  0.000012 seconds (8 allocations: 9.891 KiB)
kernel grad map 1:  0.616478 seconds (3.00 M allocations: 182.207 MiB, 3.39% gc time, 97.09% compilation time)
kernel grad map 2:  0.000298 seconds (813 allocations: 1.046 MiB)
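
For intuition about why grad_kernel is so cheap: ForwardDiff.jl propagates dual numbers (a value plus a derivative part) through the function, so differentiating a function of one scalar input with length(X) outputs takes a single forward pass. Below is a minimal pure-Python sketch of that idea, assuming the same kernel and grid as above. The Dual class and the dexp, kernel_row and derivative helpers are hypothetical illustrations, not ForwardDiff's actual implementation.

```python
import math


class Dual:
    """Hypothetical minimal dual number val + der*eps, with eps**2 == 0."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def dexp(d):
    """exp lifted to dual numbers, using d/dx exp(x) = exp(x)."""
    e = math.exp(d.val)
    return Dual(e, e * d.der)


def kernel_row(x1, xs):
    """[k(x1, x_i) for each x_i], with k written as in rbf_kernel above."""
    return [dexp(-0.5 * (-2 * x1 * xi + (x1 + xi))) for xi in xs]


def derivative(f, x):
    """Forward mode: seed the single input's tangent with 1, read every output's tangent."""
    return [d.der for d in f(Dual(x, 1.0))]


# Same grid (up to rounding) as collect(range(-1.0, 1.0, length=100))
xs = [-1.0 + 2.0 * i / 99 for i in range(100)]

# One forward pass yields all 100 entries of dk(x_1, x_i)/dx_1, like grad_kernel
row = derivative(lambda x1: kernel_row(x1, xs), xs[0])
```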

Compare this with JAX:

import jax.numpy as np
from jax import vmap, jacrev, jacfwd, jit
from functools import partial  # jax.partial has been removed from recent JAX releases
from time import time

@partial(jit, static_argnums=1)
def grad_kernel(x, kernel):
    J = jacfwd(kernel)(x[:1], x)
    return J

@partial(jit, static_argnums=1)
def grad_kernel_map(x, kernel):
    J = vmap(
        jacfwd(kernel),  # jacfwd computes the Jacobian of k(x_j, x) w.r.t. x_j
        in_axes=(0, None)  # use vmap's in_axes to specify vectorisation across the first input dimension only.
    )(x, x)
    return np.sum(J, axis=1)

def rbf_kernel(X, X2):
    X = X.reshape(-1, 1)
    X2 = X2.reshape(-1, 1)
    r2 = (-2*X @ X2.T) + (X + X2.T)
    return np.exp(-0.5 * np.squeeze(r2))

X = np.linspace(-1.0, 1.0, num=100)

t0 = time(); rbf_kernel(X, X).block_until_ready(); t1 = time()  # block_until_ready() waits for JAX's asynchronous dispatch to finish
print('kernel 1: %2.6f secs' % (t1-t0))
t2 = time(); rbf_kernel(X, X).block_until_ready(); t3 = time()
print('kernel 2: %2.6f secs' % (t3-t2))

t0 = time(); grad_kernel(X, rbf_kernel).block_until_ready(); t1 = time()
print('kernel grad 1: %2.6f secs' % (t1-t0))
t2 = time(); grad_kernel(X, rbf_kernel).block_until_ready(); t3 = time()
print('kernel grad 2: %2.6f secs' % (t3-t2))

t0 = time(); grad_kernel_map(X, rbf_kernel).block_until_ready(); t1 = time()
print('kernel grad map 1: %2.6f secs' % (t1-t0))
t2 = time(); grad_kernel_map(X, rbf_kernel).block_until_ready(); t3 = time()
print('kernel grad map 2: %2.6f secs' % (t3-t2))

kernel 1: 0.055686 secs
kernel 2: 0.000107 secs
kernel grad 1: 0.050636 secs
kernel grad 2: 0.000051 secs
kernel grad map 1: 0.077634 secs
kernel grad map 2: 0.000091 secs

Interestingly, ForwardDiff seems to be faster than JAX at computing a single gradient, but slower when this operation is broadcast across the elements of X.
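
For this particular kernel the derivative also has a closed form, which gives an independent check on either AD backend. Writing k(x, x') = exp(-0.5(-2xx' + x + x')), the derivative with respect to the first argument is ∂k/∂x = k(x, x')(x' - 0.5). The sketch below (plain Python, hypothetical helper names) builds the row sums that grad_kernel_map returns from that formula and compares them with central finite differences. As an aside, the usual RBF kernel uses (x - x')² = x² + x'² - 2xx', whereas the snippets here add X and X2' without squaring them; that should not change the AD timing comparison much, but it does mean the kernel is not the standard RBF.

```python
import math


def k(x, y):
    """The kernel exactly as written in the snippets: exp(-0.5 * (-2xy + x + y))."""
    return math.exp(-0.5 * (-2 * x * y + x + y))


def dk_dx(x, y):
    """Closed-form derivative w.r.t. the first argument: k(x, y) * (y - 0.5)."""
    return k(x, y) * (y - 0.5)


def grad_kernel_map_ref(xs):
    """Row sums of J[j][i] = dk(x_j, x_i)/dx_j, i.e. what grad_kernel_map computes."""
    return [sum(dk_dx(xj, xi) for xi in xs) for xj in xs]


def central_diff(f, x, h=1e-6):
    """Second-order central finite difference of a scalar function."""
    return (f(x + h) - f(x - h)) / (2 * h)


xs = [-1.0 + 2.0 * i / 99 for i in range(100)]
ref = grad_kernel_map_ref(xs)
fd = [sum(central_diff(lambda t: k(t, xi), xj) for xi in xs) for xj in xs]
```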

Another interesting point is that, when I initially implemented the method using Zygote, I got the following.

using Zygote

function rbf_kernel(X::AbstractVector, X2::AbstractVector = X)
    r2 = (-2X * X2') + (X .+ X2')
    return exp.(-0.5 * r2)
end

function grad_kernel(kernel, x::AbstractVector)
    j = jacobian(x_ -> kernel([x_], x), x[1])  # compute the Jacobian of k(x_1, x) w.r.t. x_1
    return only(j)
end

function grad_kernel_map(kernel, x::AbstractVector)
    j = jacobian.(x_ -> kernel([x_], x), x)  # compute the Jacobian of k(x_j, x) w.r.t. x_j for all j
    J = hcat(only.(j)...)  # awkwardly, j is a Vector of 1-tuples, each holding a Vector
    return sum(J', dims=2)
end

X = collect(range(-1.0, 1.0, length=100));

print("kernel 1:"); @time rbf_kernel(X, X);
print("kernel 2:"); @time rbf_kernel(X, X);

print("kernel grad 1:"); @time grad_kernel(rbf_kernel, X);
print("kernel grad 2:"); @time grad_kernel(rbf_kernel, X);

print("kernel grad map 1:"); @time grad_kernel_map(rbf_kernel, X);
print("kernel grad map 2:"); @time grad_kernel_map(rbf_kernel, X);

kernel 1:  0.256675 seconds (1.50 M allocations: 84.772 MiB, 5.18% gc time, 99.68% compilation time)
kernel 2:  0.000178 seconds (11 allocations: 391.891 KiB)
kernel grad 1: 10.803935 seconds (36.53 M allocations: 2.108 GiB, 4.41% gc time, 99.93% compilation time)
kernel grad 2:  0.001086 seconds (11.16 k allocations: 817.922 KiB)
kernel grad map 1:  0.792268 seconds (3.76 M allocations: 237.517 MiB, 7.43% gc time, 84.25% compilation time)
kernel grad map 2:  0.087710 seconds (1.11 M allocations: 79.812 MiB, 7.08% gc time)

whereas simply replacing jacfwd with jacrev in the JAX snippet gives the following timings:

kernel 1: 0.055157 secs
kernel 2: 0.000098 secs
kernel grad 1: 0.110341 secs
kernel grad 2: 0.000061 secs
kernel grad map 1: 0.192222 secs
kernel grad map 2: 0.000453 secs

which suggests that JAX was able to do a decent job of computing the gradient even in reverse-mode, whereas Zygote really struggled with this task.
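
The forward/reverse gap is what the shape of the problem predicts: grad_kernel differentiates a function with one scalar input and length(X) outputs. Forward mode gets the whole Jacobian from one Jacobian-vector product (JVP), while reverse mode needs one vector-Jacobian product (VJP) per output. JAX's jacrev batches those VJPs with vmap, so how well each framework handles that batching is one plausible contributor to the difference (an assumption; these timings do not isolate it). A toy Python sketch with hand-written jvp/vjp functions (hypothetical, specific to this kernel) that counts the passes each mode needs:

```python
import math

xs = [-1.0 + 2.0 * i / 99 for i in range(100)]
calls = {"jvp": 0, "vjp": 0}  # how many forward / reverse passes were made


def g(x1):
    """Scalar input -> 100 outputs: g_i(x1) = k(x1, x_i)."""
    return [math.exp(-0.5 * (-2 * x1 * xi + x1 + xi)) for xi in xs]


def jvp(x1, tangent):
    """Forward mode: one pass pushes the input tangent to ALL outputs."""
    calls["jvp"] += 1
    return [gi * (xi - 0.5) * tangent for gi, xi in zip(g(x1), xs)]


def vjp(x1, cotangent):
    """Reverse mode: one pass pulls an output cotangent back to the scalar input."""
    calls["vjp"] += 1
    return sum(c * gi * (xi - 0.5) for c, gi, xi in zip(cotangent, g(x1), xs))


def jacobian_fwd(x1):
    return jvp(x1, 1.0)  # one pass, because there is one input


def jacobian_rev(x1):
    n = len(xs)
    # one pass per output: seed the i-th cotangent with 1, all others with 0
    return [vjp(x1, [1.0 if j == i else 0.0 for j in range(n)]) for i in range(n)]


Jf = jacobian_fwd(xs[0])
Jr = jacobian_rev(xs[0])
```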

Does anyone have any insight into why:

  1. broadcasting here seems roughly 20 times slower than JAX?
  2. Zygote struggled so much to compute the gradient efficiently?

I think the problem is vcat(j...). Splatting a vector is pretty slow.


I just tried removing that line (simply returning j rather than applying vcat and summing), but this didn’t result in any discernible speedup. It does seem to be the broadcasting that’s the issue there. Perhaps JAX is multi-threading and Julia is using a single thread?

I haven’t looked into all of your code, but the first example, for instance, can be sped up a bit like this:

function rbf_kernel(X::AbstractVector, X2::AbstractVector = X)
    return exp.((X * X2') .- 0.5 .* (X .+ X2'))
end
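
The rewrite relies on the identity -0.5(-2xx' + x + x') = xx' - 0.5(x + x'), which reduces the number of temporary arrays: after the X * X2' outer product, the rest runs as one fused broadcast. A quick pure-Python check on the same grid that the two forms agree element by element (the function names are illustrative):

```python
import math


def k_original(x, y):
    # one element of exp.(-0.5 * ((-2X * X2') + (X .+ X2')))
    return math.exp(-0.5 * ((-2 * x * y) + (x + y)))


def k_simplified(x, y):
    # one element of exp.((X * X2') .- 0.5 .* (X .+ X2'))
    return math.exp((x * y) - 0.5 * (x + y))


xs = [-1.0 + 2.0 * i / 99 for i in range(100)]
max_err = max(abs(k_original(a, b) - k_simplified(a, b)) for a in xs for b in xs)
```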

You can also use BenchmarkTools.jl for more reliable benchmarks.


Edit: never mind, I just realized that you are running each one twice for exactly that reason.

You are probably mixing compile and run times. Using BenchmarkTools.jl like this:

using BenchmarkTools

print("kernel 1:"); @btime rbf_kernel($X, $X);
print("kernel 2:"); @btime rbf_kernel($X, $X);

print("kernel grad 1:"); @btime grad_kernel($rbf_kernel, $X);
print("kernel grad 2:"); @btime grad_kernel($rbf_kernel, $X);

print("kernel grad map 1:"); @btime grad_kernel_map($rbf_kernel, $X);
print("kernel grad map 2:"); @btime grad_kernel_map($rbf_kernel, $X);


kernel 1:  71.400 μs (11 allocations: 391.73 KiB)
kernel 2:  68.300 μs (11 allocations: 391.73 KiB)
kernel grad 1:  2.167 μs (8 allocations: 9.86 KiB)
kernel grad 2:  2.289 μs (8 allocations: 9.86 KiB)
kernel grad map 1:  260.200 μs (813 allocations: 1.04 MiB)
kernel grad map 2:  228.800 μs (813 allocations: 1.04 MiB)
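
The JAX timings in the question come from a single time() call per run, whereas @btime reports the minimum over many samples. A comparable measurement on the Python side could use timeit.repeat and keep the best result, after one warm-up call so that JIT compilation is excluded. A minimal sketch; the helper name bench and the stand-in workload are hypothetical:

```python
import timeit


def bench(fn, *args, number=100, repeat=5):
    """Best-of-`repeat` mean time per call, in seconds (hypothetical helper).

    Calls fn once first so one-off costs such as JIT compilation are excluded.
    For JAX, fn should call .block_until_ready() on its result so that
    asynchronous dispatch does not hide the real run time.
    """
    fn(*args)  # warm-up call
    times = timeit.repeat(lambda: fn(*args), number=number, repeat=repeat)
    return min(times) / number


# Stand-in workload; for the JAX snippet this might instead be
# lambda X: grad_kernel_map(X, rbf_kernel).block_until_ready()
t = bench(lambda n: sum(i * i for i in range(n)), 1000)
```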

Here are the btime numbers for your Zygote kernels:

kernel 1:  65.700 μs (11 allocations: 391.73 KiB)
kernel 2:  74.500 μs (11 allocations: 391.73 KiB)
kernel grad 1:  782.800 μs (8682 allocations: 747.48 KiB)
kernel grad 2:  781.600 μs (8681 allocations: 747.47 KiB)
kernel grad map 1:  87.945 ms (866626 allocations: 73.00 MiB)
kernel grad map 2:  88.353 ms (866626 allocations: 73.00 MiB)

Checking the ForwardDiff version for type instabilities with JET.jl:

using JET

const X = collect(range(-1.0, 1.0, length=100));

@report_opt rbf_kernel(X, X);

println("kernel grad:");
@report_opt grad_kernel(rbf_kernel, X);

println("kernel grad map:");
@report_opt grad_kernel_map(rbf_kernel, X);


kernel grad:
kernel grad map:
═════ 36 possible errors found ═════

For Zygote we get:

kernel grad:
kernel grad map:
═════ 50 possible errors found ═════

The running time of ForwardDiff is reasonable.
The elapsed time of grad_kernel_map is about 100 (= length(X)) times that of grad_kernel.
I don’t know why JAX achieves a 20x speedup. Maybe it fuses the loop over elements during JIT compilation? Looking at the generated code could provide more insight.
I added Threads.@threads to grad_kernel_map and the time did not change significantly, so I don’t think multithreading is the root cause here.
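
On the 100x factor: grad_kernel_map calls ForwardDiff.derivative once per element of X, so it performs length(X) separate forward passes over a function that itself costs O(length(X)). Since row j of k(X, X) depends on the first argument only through x_j, a single forward pass that seeds every x_j with tangent 1 produces the whole matrix J at once. One plausible (unverified) explanation for JAX's advantage is that vmap turns the per-element Jacobians into one batched computation of this kind. The pure-Python sketch below illustrates the single-pass trick with a hypothetical minimal Dual class; in Julia the analogous call would presumably be something like ForwardDiff.derivative(t -> kernel(x .+ t, x), 0.0), though that variant has not been benchmarked here.

```python
import math


class Dual:
    """Hypothetical minimal dual number val + der*eps, for illustration only."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def dexp(d):
    """exp on dual numbers: the tangent is scaled by exp(val)."""
    e = math.exp(d.val)
    return Dual(e, e * d.der)


def kernel(X, X2):
    """K[j][i] = exp(-0.5 * (-2*X[j]*X2[i] + X[j] + X2[i])), as in rbf_kernel."""
    return [[dexp(-0.5 * (-2 * xj * xi + (xj + xi))) for xi in X2] for xj in X]


xs = [-1.0 + 2.0 * i / 99 for i in range(100)]

# ONE forward pass: seed every x_j in the first argument with tangent 1.
# Row j of K depends on the first argument only through x_j, so the
# tangent of K[j][i] is exactly dk(x_j, x_i)/dx_j.
K = kernel([Dual(xj, 1.0) for xj in xs], xs)
J = [[kji.der for kji in row] for row in K]
grad_map = [sum(row) for row in J]  # same row-sum reduction as grad_kernel_map
```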

Zygote seems very suboptimal here, judging by the allocation counts. 🙂

Could it be a missing adjoint?