I have been attempting to compute the gradients of a covariance function (kernel) in Julia. I initially used Zygote, which was the wrong choice since forward-mode is much better suited to my task. However, this led to some interesting observations, particularly when comparing Zygote.jl / ForwardDiff.jl with JAX in Python.
Forward-mode:
Running the following gives good performance:
using ForwardDiff

function rbf_kernel(X::AbstractVector, X2::AbstractVector=X)
    r2 = (-2X * X2') + (X .+ X2')
    return exp.(-0.5 * r2)
end

function grad_kernel(kernel, x::AbstractVector)
    j = ForwardDiff.derivative(x_ -> kernel([x_], x), x[1])  # compute the Jacobian of k(x_1, x) w.r.t. x_1
    return j
end

function grad_kernel_map(kernel, x::AbstractVector)
    j = ForwardDiff.derivative.(x_ -> kernel([x_], x), x)  # compute the Jacobian of k(x_j, x) w.r.t. x_j for all j
    J = hcat(j...)  # j is of type Vector{Vector}; stack into a 100×100 matrix
    return sum(J', dims=2)
end
X = collect(range(-1.0, 1.0, length=100));
print("kernel 1:"); @time rbf_kernel(X, X);
print("kernel 2:"); @time rbf_kernel(X, X);
print("kernel grad 1:"); @time grad_kernel(rbf_kernel, X);
print("kernel grad 2:"); @time grad_kernel(rbf_kernel, X);
print("kernel grad map 1:"); @time grad_kernel_map(rbf_kernel, X);
print("kernel grad map 2:"); @time grad_kernel_map(rbf_kernel, X);
kernel 1: 0.287593 seconds (1.50 M allocations: 84.883 MiB, 8.11% gc time, 99.73% compilation time)
kernel 2: 0.000191 seconds (11 allocations: 391.891 KiB)
kernel grad 1: 0.331328 seconds (1.75 M allocations: 98.633 MiB, 4.17% gc time, 99.81% compilation time)
kernel grad 2: 0.000012 seconds (8 allocations: 9.891 KiB)
kernel grad map 1: 0.616478 seconds (3.00 M allocations: 182.207 MiB, 3.39% gc time, 97.09% compilation time)
kernel grad map 2: 0.000298 seconds (813 allocations: 1.046 MiB)
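As a sanity check on what grad_kernel is computing (this snippet is only my own check, not part of the benchmark): for the r2 as defined in rbf_kernel above, the derivative has the closed form d/dx_1 k(x_1, x_i) = k(x_1, x_i) * (x_i - 0.5), so the ForwardDiff output can be compared against it directly.
K = rbf_kernel(X, X)
analytic = K[1, :] .* (X .- 0.5)              # closed-form d k(x_1, x_i)/d x_1 for the r2 above
vec(grad_kernel(rbf_kernel, X)) ≈ analytic    # should be true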
For comparison, here is the equivalent code in JAX:
import jax.numpy as np
from jax import vmap, jacrev, jacfwd, jit
from functools import partial
from time import time

@partial(jit, static_argnums=1)
def grad_kernel(x, kernel):
    J = jacfwd(kernel)(x[:1], x)
    return J

@partial(jit, static_argnums=1)
def grad_kernel_map(x, kernel):
    J = vmap(
        jacfwd(kernel),    # jacfwd computes the Jacobian of k(x_j, x) w.r.t. x_j
        in_axes=(0, None)  # use vmap's in_axes to specify vectorisation across the first input dimension only
    )(x, x)
    return np.sum(J, axis=1)

@jit
def rbf_kernel(X, X2):
    X = X.reshape(-1, 1)
    X2 = X2.reshape(-1, 1)
    r2 = (-2*X @ X2.T) + (X + X2.T)
    return np.exp(-0.5 * np.squeeze(r2))
X = np.linspace(-1.0, 1.0, num=100)
t0 = time(); rbf_kernel(X, X).block_until_ready(); t1 = time()  # block_until_ready() waits for JAX's asynchronous dispatch to finish
print('kernel 1: %2.6f secs' % (t1-t0))
t2 = time(); rbf_kernel(X, X).block_until_ready(); t3 = time()
print('kernel 2: %2.6f secs' % (t3-t2))
t0 = time(); grad_kernel(X, rbf_kernel).block_until_ready(); t1 = time()
print('kernel grad 1: %2.6f secs' % (t1-t0))
t2 = time(); grad_kernel(X, rbf_kernel).block_until_ready(); t3 = time()
print('kernel grad 2: %2.6f secs' % (t3-t2))
t0 = time(); grad_kernel_map(X, rbf_kernel).block_until_ready(); t1 = time()
print('kernel grad map 1: %2.6f secs' % (t1-t0))
t2 = time(); grad_kernel_map(X, rbf_kernel).block_until_ready(); t3 = time()
print('kernel grad map 2: %2.6f secs' % (t3-t2))
kernel 1: 0.055686 secs
kernel 2: 0.000107 secs
kernel grad 1: 0.050636 secs
kernel grad 2: 0.000051 secs
kernel grad map 1: 0.077634 secs
kernel grad map 2: 0.000091 secs
Interestingly, ForwardDiff seems to be faster at computing a single gradient, but slower at broadcasting this operation across the dimensions of X.
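A possible way around the broadcast (just a sketch, and it relies on the fact that row j of rbf_kernel(Xd, X) depends only on Xd[j]) is to seed ForwardDiff's Dual numbers by hand, so that a single vectorised kernel evaluation carries all of the per-row partials at once:
using ForwardDiff: Dual, partials

Xd = [Dual(xj, one(xj)) for xj in X]    # sketch only: one seeded partial per x_j
Kd = rbf_kernel(Xd, X)                  # 100×100 matrix of Duals
J  = partials.(Kd, 1)                   # J[j, i] = d k(x_j, x_i) / d x_j
sum(J, dims=2)                          # should match grad_kernel_map(rbf_kernel, X)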
Another interesting point is that, when I initially implemented the method using Zygote, I got the following.
using Zygote

function rbf_kernel(X::AbstractVector, X2::AbstractVector=X)
    r2 = (-2X * X2') + (X .+ X2')
    return exp.(-0.5 * r2)
end

function grad_kernel(kernel, x::AbstractVector)
    j = jacobian(x_ -> kernel([x_], x), x[1])  # compute the Jacobian of k(x_1, x) w.r.t. x_1
    return only(j)
end

function grad_kernel_map(kernel, x::AbstractVector)
    j = jacobian.(x_ -> kernel([x_], x), x)  # compute the Jacobian of k(x_j, x) w.r.t. x_j for all j
    J = hcat(only.(j)...)  # awkwardly, j is a Vector of single-element Tuples wrapping each Jacobian
    return sum(J', dims=2)
end
X = collect(range(-1.0, 1.0, length=100));
print("kernel 1:"); @time rbf_kernel(X, X);
print("kernel 2:"); @time rbf_kernel(X, X);
print("kernel grad 1:"); @time grad_kernel(rbf_kernel, X);
print("kernel grad 2:"); @time grad_kernel(rbf_kernel, X);
print("kernel grad map 1:"); @time grad_kernel_map(rbf_kernel, X);
print("kernel grad map 2:"); @time grad_kernel_map(rbf_kernel, X);
kernel 1: 0.256675 seconds (1.50 M allocations: 84.772 MiB, 5.18% gc time, 99.68% compilation time)
kernel 2: 0.000178 seconds (11 allocations: 391.891 KiB)
kernel grad 1: 10.803935 seconds (36.53 M allocations: 2.108 GiB, 4.41% gc time, 99.93% compilation time)
kernel grad 2: 0.001086 seconds (11.16 k allocations: 817.922 KiB)
kernel grad map 1: 0.792268 seconds (3.76 M allocations: 237.517 MiB, 7.43% gc time, 84.25% compilation time)
kernel grad map 2: 0.087710 seconds (1.11 M allocations: 79.812 MiB, 7.08% gc time)
whereas simply replacing jacfwd with jacrev in the JAX snippet gives the following timings:
kernel 1: 0.055157 secs
kernel 2: 0.000098 secs
kernel grad 1: 0.110341 secs
kernel grad 2: 0.000061 secs
kernel grad map 1: 0.192222 secs
kernel grad map 2: 0.000453 secs
This suggests that JAX was able to do a decent job of computing the gradient even in reverse mode, whereas Zygote really struggled with this task.
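Since the quantity I ultimately want is the sum over the second index, it can also be phrased as the gradient of a scalar, which is the setting reverse mode is actually built for. A minimal sketch (again relying on row j of kernel(y, x) depending only on y[j]), which avoids broadcasting jacobian over every element:
# sketch only: sum_i d k(x_j, x_i)/d x_j equals d sum(kernel(y, x))/d y[j] evaluated at y = x,
# because row j of the kernel matrix depends only on y[j]
grad_map_onesweep(kernel, x) = Zygote.gradient(y -> sum(kernel(y, x)), x)[1]
# grad_map_onesweep(rbf_kernel, X) should match vec(grad_kernel_map(rbf_kernel, X))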
Does anyone have any insight into why:
- broadcasting the gradient across X carries so much more overhead with ForwardDiff (roughly 25x the cost of a single gradient call in the timings above) than with JAX's vmap (roughly 2x)?
- Zygote struggled so much to compute the gradient efficiently?