# Forward- and reverse-mode AD comparisons with JAX

I have been attempting to compute the gradients of a covariance function (kernel) in Julia. I initially used Zygote, which was the wrong choice, since forward mode is much better suited to my task. However, this led to some interesting observations, particularly when comparing Zygote.jl / ForwardDiff.jl with JAX in Python.
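To make the forward- vs reverse-mode point concrete: for a single input x_j, the map x_j ↦ k(x_j, x) has one input and n outputs. One forward-mode pass (a JVP) therefore yields the whole Jacobian column, while reverse mode needs one VJP per output. Below is a minimal JAX sketch. It assumes a unit-lengthscale RBF kernel written as `rbf_row`, a hypothetical helper that is not from the post.

```python
import jax
import jax.numpy as jnp

def rbf_row(xj, x):
    # k(x_j, x_i) = exp(-0.5 * (x_j - x_i)^2) for all i: scalar in, vector out
    return jnp.exp(-0.5 * (xj - x) ** 2)

x = jnp.linspace(-1.0, 1.0, 100)
xj = x[0]
f = lambda s: rbf_row(s, x)  # x is held fixed; only s is differentiated

# Forward mode: a single JVP with a unit tangent gives all n partial derivatives
_, col_fwd = jax.jvp(f, (xj,), (jnp.ones_like(xj),))

# Reverse mode: one VJP per output element, i.e. n backward passes
_, f_vjp = jax.vjp(f, xj)
col_rev = jnp.stack([f_vjp(e)[0] for e in jnp.eye(x.size)])

print(jnp.allclose(col_fwd, col_rev))  # True
```

JAX's `jacfwd` and `jacrev` batch these passes with `vmap`, but the asymmetry in the number of passes (1 vs n) is the same.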

Forward mode (ForwardDiff.jl): running the following gives good performance:

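```julia
using ForwardDiff

# Unit-lengthscale RBF kernel: k(x, x') = exp(-0.5 * (x - x')^2)
function rbf_kernel(X::AbstractVector, X2::AbstractVector=X)
    r2 = (-2X * X2') + (X .^ 2 .+ X2' .^ 2)  # pairwise squared distances
    return exp.(-0.5 * r2)
end

function kernel_grad(x, kernel)
    j = ForwardDiff.derivative(x_ -> kernel([x_], x), x[1])  # compute the Jacobian of k(x_1, x) w.r.t. x_1
    return j
end

function kernel_grad_map(x, kernel)
    j = ForwardDiff.derivative.(x_ -> kernel([x_], x), x)  # compute the Jacobian of k(x_j, x) w.r.t. x_j for all j
    J = vcat(j...)  # j is a Vector of 1×n matrices
    return sum(J, dims=2)
end

X = collect(range(-1.0, 1.0, length=100));

print("kernel 1:"); @time rbf_kernel(X, X);
print("kernel 2:"); @time rbf_kernel(X, X);
print("kernel grad 1:"); @time kernel_grad(X, rbf_kernel);
print("kernel grad 2:"); @time kernel_grad(X, rbf_kernel);
print("kernel grad map 1:"); @time kernel_grad_map(X, rbf_kernel);
print("kernel grad map 2:"); @time kernel_grad_map(X, rbf_kernel);
```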
```
kernel 1:  0.287593 seconds (1.50 M allocations: 84.883 MiB, 8.11% gc time, 99.73% compilation time)
kernel 2:  0.000191 seconds (11 allocations: 391.891 KiB)
kernel grad 1:  0.331328 seconds (1.75 M allocations: 98.633 MiB, 4.17% gc time, 99.81% compilation time)
kernel grad 2:  0.000012 seconds (8 allocations: 9.891 KiB)
kernel grad map 1:  0.616478 seconds (3.00 M allocations: 182.207 MiB, 3.39% gc time, 97.09% compilation time)
kernel grad map 2:  0.000298 seconds (813 allocations: 1.046 MiB)
```

Compare this with JAX:

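```python
from functools import partial  # jax.partial has been removed from recent JAX releases
from time import time

import jax.numpy as np
from jax import vmap, jacrev, jacfwd, jit

@partial(jit, static_argnums=1)  # argument 1 (the kernel function) must be static
def kernel_grad(x, kernel):
    J = jacfwd(kernel)(x[:1], x)  # compute the Jacobian of k(x_1, x) w.r.t. x_1
    return J

@partial(jit, static_argnums=1)
def kernel_grad_map(x, kernel):
    J = vmap(
        jacfwd(kernel),  # jacfwd computes the Jacobian of k(x_j, x) w.r.t. x_j
        in_axes=(0, None)  # use vmap's in_axes to vectorise across the first input only
    )(x, x)
    return np.sum(J, axis=1)

@jit
def rbf_kernel(X, X2):
    X = X.reshape(-1, 1)
    X2 = X2.reshape(-1, 1)
    r2 = (-2*X @ X2.T) + (X**2 + X2.T**2)  # pairwise squared distances
    return np.exp(-0.5 * np.squeeze(r2))

X = np.linspace(-1.0, 1.0, num=100)

# JAX dispatches asynchronously; block_until_ready() waits for the result so the timing is meaningful
t0 = time(); rbf_kernel(X, X).block_until_ready(); t1 = time()
print('kernel 1: %2.6f secs' % (t1-t0))
t2 = time(); rbf_kernel(X, X).block_until_ready(); t3 = time()
print('kernel 2: %2.6f secs' % (t3-t2))

t0 = time(); kernel_grad(X, rbf_kernel).block_until_ready(); t1 = time()
print('kernel grad 1: %2.6f secs' % (t1-t0))
t2 = time(); kernel_grad(X, rbf_kernel).block_until_ready(); t3 = time()
print('kernel grad 2: %2.6f secs' % (t3-t2))

t0 = time(); kernel_grad_map(X, rbf_kernel).block_until_ready(); t1 = time()
print('kernel grad map 1: %2.6f secs' % (t1-t0))
t2 = time(); kernel_grad_map(X, rbf_kernel).block_until_ready(); t3 = time()
print('kernel grad map 2: %2.6f secs' % (t3-t2))
```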
```
kernel 1: 0.055686 secs
kernel 2: 0.000107 secs
kernel grad map 1: 0.077634 secs
kernel grad map 2: 0.000091 secs
```

Interestingly, ForwardDiff seems to be faster at computing the gradient, but slower when this operation is broadcast across the elements of `X` (0.000298 s vs 0.000091 s on the second run).
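One likely contributor, stated here as an assumption rather than something the timings prove, is structural. JAX's `vmap` traces the per-element Jacobian once and emits a single batched computation. Broadcasting `ForwardDiff.derivative` in Julia instead makes `length(X)` separate calls, each allocating its own result. The JAX sketch below contrasts the two structures. `rbf_kernel` mirrors the post's kernel, and `grad_loop` / `grad_vmap` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def rbf_kernel(X, X2):
    X = X.reshape(-1, 1)
    X2 = X2.reshape(-1, 1)
    r2 = X**2 + X2.T**2 - 2 * X @ X2.T  # pairwise squared distances
    return jnp.exp(-0.5 * jnp.squeeze(r2))

def grad_loop(x):
    # n separate Jacobian calls, one per x_j (similar in shape to broadcasting derivative. in Julia)
    rows = [jax.jacfwd(rbf_kernel)(xj, x) for xj in x]
    return jnp.stack(rows).sum(axis=1)

@jax.jit
def grad_vmap(x):
    # one batched program: vmap adds a batch dimension instead of looping in Python
    J = jax.vmap(jax.jacfwd(rbf_kernel), in_axes=(0, None))(x, x)
    return J.sum(axis=1)

x = jnp.linspace(-1.0, 1.0, 100)
print(jnp.allclose(grad_loop(x), grad_vmap(x), atol=1e-5))
```

Both functions compute the same numbers. Only the second lets XLA see (and fuse) all `length(x)` Jacobians as one computation.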

Another interesting point is that, when I initially implemented the method using Zygote, I got the following.

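```julia
using Zygote

# Unit-lengthscale RBF kernel: k(x, x') = exp(-0.5 * (x - x')^2)
function rbf_kernel(X::AbstractVector, X2::AbstractVector=X)
    r2 = (-2X * X2') + (X .^ 2 .+ X2' .^ 2)  # pairwise squared distances
    return exp.(-0.5 * r2)
end

function kernel_grad(x, kernel)
    j = jacobian(x_ -> kernel([x_], x), x[1])  # compute the Jacobian of k(x_1, x) w.r.t. x_1
    return only(j)  # jacobian returns a tuple with one entry per argument
end

function kernel_grad_map(x, kernel)
    j = jacobian.(x_ -> kernel([x_], x), x)  # compute the Jacobian of k(x_j, x) w.r.t. x_j for all j
    J = hcat(only.(j)...)  # awkwardly, each element of j is a 1-tuple wrapping a vector
    return sum(J', dims=2)
end

X = collect(range(-1.0, 1.0, length=100));

print("kernel 1:"); @time rbf_kernel(X, X);
print("kernel 2:"); @time rbf_kernel(X, X);
print("kernel grad 1:"); @time kernel_grad(X, rbf_kernel);
print("kernel grad 2:"); @time kernel_grad(X, rbf_kernel);
print("kernel grad map 1:"); @time kernel_grad_map(X, rbf_kernel);
print("kernel grad map 2:"); @time kernel_grad_map(X, rbf_kernel);
```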
```
kernel 1:  0.256675 seconds (1.50 M allocations: 84.772 MiB, 5.18% gc time, 99.68% compilation time)
kernel 2:  0.000178 seconds (11 allocations: 391.891 KiB)
kernel grad 1: 10.803935 seconds (36.53 M allocations: 2.108 GiB, 4.41% gc time, 99.93% compilation time)
kernel grad 2:  0.001086 seconds (11.16 k allocations: 817.922 KiB)
kernel grad map 1:  0.792268 seconds (3.76 M allocations: 237.517 MiB, 7.43% gc time, 84.25% compilation time)
kernel grad map 2:  0.087710 seconds (1.11 M allocations: 79.812 MiB, 7.08% gc time)
```

whereas simply replacing `jacfwd` with `jacrev` in the JAX snippet gives the following timings,

```
kernel 1: 0.055157 secs
kernel 2: 0.000098 secs
kernel grad map 1: 0.192222 secs
kernel grad map 2: 0.000453 secs
```

which suggests that JAX was able to do a decent job of computing the gradient even in reverse-mode, whereas Zygote really struggled with this task.
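For this particular kernel, the quantity being computed also has a closed form, which gives a cheap way to check both the `jacfwd` and `jacrev` versions. With k(x_j, x_i) = exp(-(x_j - x_i)^2 / 2), differentiating in the first argument gives ∂k/∂x_j = -(x_j - x_i) k(x_j, x_i). The row sums are therefore Σ_i -(x_j - x_i) k(x_j, x_i). The minimal sketch below assumes the squared-distance form of the kernel. The names `rbf`, `grad_map` and `grad_closed_form` are illustrative.

```python
import jax
import jax.numpy as jnp

def rbf(a, b):
    # unit-lengthscale RBF kernel, k[j, i] = exp(-0.5 * (a_j - b_i)^2); squeezed to a vector for scalar a
    a = jnp.atleast_1d(a)
    return jnp.squeeze(jnp.exp(-0.5 * (a[:, None] - b[None, :]) ** 2))

def grad_map(x, jac):
    # jac is jax.jacfwd or jax.jacrev; vectorise over the first argument only
    J = jax.vmap(jac(rbf), in_axes=(0, None))(x, x)
    return J.sum(axis=1)

def grad_closed_form(x):
    d = x[:, None] - x[None, :]  # d[j, i] = x_j - x_i
    return jnp.sum(-d * jnp.exp(-0.5 * d**2), axis=1)

x = jnp.linspace(-1.0, 1.0, 100)
print(jnp.allclose(grad_map(x, jax.jacfwd), grad_closed_form(x), atol=1e-5))
print(jnp.allclose(grad_map(x, jax.jacrev), grad_closed_form(x), atol=1e-5))
```

Swapping `jax.jacfwd` for `jax.jacrev` here mirrors the experiment above. Both should agree with the closed form, so any timing difference comes from how the derivative is computed, not what is computed.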

Does anyone have any insight into why:

1. broadcasting here seems roughly 20 times slower than JAX?
2. Zygote struggled so much to compute the gradient efficiently?

I think the problem is `vcat(j...)`: splatting a vector is pretty slow.


I just tried removing that line (simply returning `j` rather than applying `vcat` and summing), but this didn’t result in any discernible speedup. It does seem to be the broadcasting that’s the issue there. Perhaps JAX is multi-threading and Julia is using a single thread?

I haven’t looked into all of your code, but the first example, for instance, can be sped up a bit like this:

```julia
function rbf_kernel(X::AbstractVector, X2::AbstractVector = X)
    # one matrix product, then a single fused broadcast: exp(x x' - 0.5 (x^2 + x'^2))
    return exp.((X * X2') .- 0.5 .* (X .^ 2 .+ X2' .^ 2))
end
```
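The one-liner works by expanding the squared distance inside the exponent. `X * X2'` then supplies the cross term x_j x'_i, and a single fused broadcast handles the rest. This assumes the intended kernel is the unit-lengthscale RBF, i.e. that the exponent uses the squared inputs:

$$
k(x_j, x'_i) = \exp\!\left(-\tfrac{1}{2}\,(x_j - x'_i)^2\right)
= \exp\!\left(x_j x'_i - \tfrac{1}{2}\left(x_j^2 + {x'_i}^2\right)\right)
$$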

You can also use BenchmarkTools.jl for more reliable benchmarks.


Edit: never mind, I just realized that you are running each one twice for exactly that reason.

You are probably mixing compile and run times. Using `BenchmarkTools.jl` like this:

```julia
using BenchmarkTools

# `$` interpolates the global X so that only the call itself is timed
print("kernel 1:"); @btime rbf_kernel($X, $X);
print("kernel 2:"); @btime rbf_kernel($X, $X);
```

yields

```
kernel 1:  71.400 μs (11 allocations: 391.73 KiB)
kernel 2:  68.300 μs (11 allocations: 391.73 KiB)
kernel grad 1:  2.167 μs (8 allocations: 9.86 KiB)
kernel grad 2:  2.289 μs (8 allocations: 9.86 KiB)
kernel grad map 1:  260.200 μs (813 allocations: 1.04 MiB)
kernel grad map 2:  228.800 μs (813 allocations: 1.04 MiB)
```

Here are the `btime` numbers for your `Zygote` kernels:

```
kernel 1:  65.700 μs (11 allocations: 391.73 KiB)
kernel 2:  74.500 μs (11 allocations: 391.73 KiB)
kernel grad 1:  782.800 μs (8682 allocations: 747.48 KiB)
kernel grad 2:  781.600 μs (8681 allocations: 747.47 KiB)
kernel grad map 1:  87.945 ms (866626 allocations: 73.00 MiB)
kernel grad map 2:  88.353 ms (866626 allocations: 73.00 MiB)
```
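On the JAX side, the equivalent precaution is to run each jitted function once to trigger compilation and then time several further calls. Each call needs `block_until_ready()` because JAX dispatches work asynchronously. Below is a minimal sketch using the standard-library `timeit` module. `bench` is a hypothetical helper name, and the kernel is a compact rewrite of the post's `rbf_kernel`.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def rbf_kernel(X, X2):
    r2 = (X[:, None] - X2[None, :]) ** 2  # pairwise squared distances
    return jnp.exp(-0.5 * r2)

def bench(f, *args, number=100):
    f(*args).block_until_ready()  # first call compiles; keep it out of the timing
    # average seconds per call, waiting for each asynchronous result
    total = timeit.timeit(lambda: f(*args).block_until_ready(), number=number)
    return total / number

X = jnp.linspace(-1.0, 1.0, 100)
print(f"kernel: {bench(rbf_kernel, X, X) * 1e6:.1f} us per call")
```

Averaging over many calls plays roughly the same role as `@btime`, which reports the minimum over many samples.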

Changing the test setup for `ForwardDiff` to use `JET.jl`:

```julia
using JET

const X = collect(range(-1.0, 1.0, length=100));

println("kernel:");
@report_opt rbf_kernel(X, X);
```

yields

```
kernel:
═════ 36 possible errors found ═════
[...]
```

For `Zygote` we get

```
kernel:
[...]
```

The elapsed time of `kernel_grad_map` is 100 (`length(X)`) times that of `kernel_grad`. I added `Threads.@threads` in `kernel_grad_map` and the time did not change significantly, so I don't think multithreading is the root cause here.
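Since the time grows with `length(X)`, one observation (hedged, and specific to this kernel) is that the per-element loop is not needed at all. Row j of K(x, x') depends on x only through x_j. A single forward-mode JVP of x ↦ K(x, x') with an all-ones tangent therefore returns every ∂k(x_j, x'_i)/∂x_j at once, at roughly the cost of one kernel evaluation. This relies on that row-wise structure and would not hold for a general function. The JAX sketch below checks the idea against a `vmap(jacfwd)` version. The function names are illustrative.

```python
import jax
import jax.numpy as jnp

def rbf_kernel(X, X2):
    return jnp.exp(-0.5 * (X[:, None] - X2[None, :]) ** 2)  # K[j, i] = k(X_j, X2_i)

def grad_map_vmap(x):
    # per-element Jacobians of k(x_j, x) w.r.t. x_j, batched with vmap (as in the post)
    row = lambda xj: jax.jacfwd(lambda s: rbf_kernel(jnp.atleast_1d(s), x)[0])(xj)
    return jax.vmap(row)(x).sum(axis=1)

def grad_map_jvp(x):
    # one JVP: a tangent of ones perturbs every x_j at once; the second argument stays fixed
    _, dK = jax.jvp(lambda a: rbf_kernel(a, x), (x,), (jnp.ones_like(x),))
    return dK.sum(axis=1)  # dK[j, i] = d k(x_j, x_i) / d x_j

x = jnp.linspace(-1.0, 1.0, 100)
print(jnp.allclose(grad_map_vmap(x), grad_map_jvp(x), atol=1e-5))
```

The same trick should carry over to ForwardDiff. Differentiate the scalar shift `t -> rbf_kernel(x .+ t, x)` at `t = 0.0` with `ForwardDiff.derivative`, then sum over `dims=2`. That replaces `length(x)` derivative calls with one.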