Hi all,
we are working on a nested sampling algorithm whose proposals are assisted by HMC, and we want to leverage Enzyme.jl to autodiff our loglikelihood kernel.

We have a MWE on a simple 4D Gaussian likelihood (4 parameters, 2D data). We cannot make it simpler, since we want the loglikelihood to run batched on multiple 4D points (to retain GPU efficiency). Our current implementation is the following:
```julia
using CUDA
using Enzyme

function log_likelihood_kernel!(logLs::CuDeviceVector{Float32},
                                thetas::CuDeviceMatrix{Float32},
                                x::CuDeviceMatrix{Float32},
                                N::Int,
                                batch_size::Int)
    # compute the global thread index (1D grid)
    idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # only compute if this thread corresponds to a sample in the batch
    if idx <= batch_size
        # extract the parameters for this thread/sample
        mu1 = thetas[1, idx]
        mu2 = thetas[2, idx]
        sigma_1 = thetas[3, idx]
        sigma_2 = thetas[4, idx]
        # compute the inverses of the standard deviations once to avoid repeated division
        inv_sigma1 = 1f0 / sigma_1
        inv_sigma2 = 1f0 / sigma_2
        # accumulate the quadratic term of the Gaussian log-likelihood
        quad = 0.0f0
        @inbounds for i in 1:N
            # standardized residuals for both data dimensions
            d1 = (x[1, i] - mu1) * inv_sigma1
            d2 = (x[2, i] - mu2) * inv_sigma2
            quad += d1*d1 + d2*d2
        end
        # write the log-likelihood for this thread/sample
        # formula: -0.5*quad - N*(log(sigma1) + log(sigma2)) - N*log(2π)  (2D data, so 2 * N * 0.5*log(2π))
        logLs[idx] = -0.5f0 * quad -
                     N * (log(sigma_1) + log(sigma_2)) -
                     N * log(2f0 * Float32(π))
    end
    return
end

function compute_logLs!(logLs::CuArray{Float32,1},
                        thetas::CuArray{Float32,2},
                        x::CuArray{Float32,2})
    N = size(x, 2)
    batch_size = size(thetas, 2)
    threads = 256
    blocks = cld(batch_size, threads)
    @cuda threads=threads blocks=blocks log_likelihood_kernel!(logLs, thetas, x, N, batch_size)
    return nothing  # mutating style: the result is in `logLs`
end

function per_sample_gradients(thetas::CuArray{Float32,2}, x::CuArray{Float32,2})
    batch_size = size(thetas, 2)
    logLs = CUDA.zeros(Float32, batch_size)
    # allocate the gradient buffers
    dlogLs = CUDA.ones(Float32, batch_size)
    dthetas = CUDA.zeros(Float32, size(thetas))
    # differentiate compute_logLs! w.r.t. thetas
    Enzyme.autodiff(
        Enzyme.Reverse,
        compute_logLs!,                                        # function to differentiate
        Enzyme.Duplicated(logLs, dlogLs),                      # we need to differentiate only wrt thetas
        Enzyme.Duplicated(thetas, dthetas),                    # input and gradient buffer
        Enzyme.Duplicated(x, CUDA.zeros(eltype(x), size(x)))   # x is constant
    )
    return dthetas  # 4×batch_size, same shape as thetas
end

### Example usage on GPU
# 3 samples, 4 parameters each
thetas = cu([0.5f0 0.6f0 0.7f0;
             0.5f0 0.4f0 0.3f0;
             0.1f0 0.2f0 0.15f0;
             0.1f0 0.2f0 0.25f0])
# mock data (2 dimensions, N = 5 points)
x = cu([0.0f0 0.5f0 0.3f0 0.4f0 0.5f0;
        0.6f0 0.7f0 0.8f0 0.9f0 1.0f0])
# compute gradients
dthetas = per_sample_gradients(thetas, x)
println(Array(dthetas))

### Cross-check with the analytical gradient on CPU
# rows of x and the parameters as regular Arrays
x1 = Array(x[1, :])
x2 = Array(x[2, :])
cputhetas = Array(thetas)
for i in 1:size(cputhetas, 2)
    mu1 = cputhetas[1, i]
    mu2 = cputhetas[2, i]
    sigma1 = cputhetas[3, i]
    sigma2 = cputhetas[4, i]
    N = size(x, 2)
    # analytical gradients of the log-likelihood
    dmu1 = sum((x1 .- mu1) ./ (sigma1^2))
    dmu2 = sum((x2 .- mu2) ./ (sigma2^2))
    dsigma1 = sum(((x1 .- mu1).^2) ./ (sigma1^3)) - N / sigma1
    dsigma2 = sum(((x2 .- mu2).^2) ./ (sigma2^3)) - N / sigma2
    println("Analytical gradients for sample $i:")
    println("dmu1: $dmu1, dmu2: $dmu2, dsigma1: $dsigma1, dsigma2: $dsigma2")
end
```
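For reference, the closed-form expressions that the CPU loop above checks against are (per sample, with $k = 1, 2$ indexing the two data dimensions):

$$
\log L = -\frac{1}{2}\sum_{i=1}^{N}\left[\left(\frac{x_{1,i}-\mu_1}{\sigma_1}\right)^2 + \left(\frac{x_{2,i}-\mu_2}{\sigma_2}\right)^2\right] - N\left(\log\sigma_1 + \log\sigma_2\right) - N\log 2\pi
$$

$$
\frac{\partial \log L}{\partial \mu_k} = \sum_{i=1}^{N}\frac{x_{k,i}-\mu_k}{\sigma_k^2},
\qquad
\frac{\partial \log L}{\partial \sigma_k} = \sum_{i=1}^{N}\frac{(x_{k,i}-\mu_k)^2}{\sigma_k^3} - \frac{N}{\sigma_k}.
$$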
The analytical and autodiff'ed results match, so we trust our implementation, though we do not understand exactly the signature of

```julia
Enzyme.autodiff(
    Enzyme.Reverse,
    compute_logLs!,                                        # function to differentiate
    Enzyme.Duplicated(logLs, dlogLs),                      # we need to differentiate only wrt thetas
    Enzyme.Duplicated(thetas, dthetas),                    # input and gradient buffer
    Enzyme.Duplicated(x, CUDA.zeros(eltype(x), size(x)))   # x is constant
)
```
In fact:

- We believe we are wasting resources by also computing derivatives of `compute_logLs!` with respect to `logLs`.
- Inspecting the documentation for reverse-mode AD, we have not found another variant that raises no errors on CUDA kernels, and `Enzyme.Const(logLs)` does not seem to play well with CUDA (a sketch of the kind of call we tried is right after this list).
- The gradient buffers are allocated as `dlogLs = CUDA.ones(Float32, batch_size)` and `dthetas = CUDA.zeros(Float32, size(thetas))`, although we would have expected Reverse mode to want the opposite allocations for `dlogLs` and `dthetas`, respectively (a minimal CPU-only analogue of this call pattern is sketched at the end of the post).
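For concreteness, the variant we would have expected to be cheaper is roughly the following. This is a hypothetical reconstruction (not our working code) of the kind of call that does not seem to play well with CUDA for us:

```julia
# hypothetical sketch: mark logLs as Const so that no shadow buffer is needed for the output
Enzyme.autodiff(
    Enzyme.Reverse,
    compute_logLs!,
    Enzyme.Const(logLs),                                   # no shadow for the output buffer
    Enzyme.Duplicated(thetas, dthetas),                    # input and gradient buffer
    Enzyme.Duplicated(x, CUDA.zeros(eltype(x), size(x)))   # x is constant
)
```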
We would appreciate some guidance from more experienced Enzyme.jl and CUDA.jl users.
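For completeness, here is a minimal CPU-only analogue of the `Duplicated` call pattern above, with a toy mutating function standing in for our kernel (the toy function and all names below are just for illustration, they are not part of our code). It shows the seeding convention we are assuming, namely that the shadow of the output carries the reverse-mode seed while the shadow of the input accumulates the gradient:

```julia
using Enzyme

# toy mutating "kernel": writes a scalar result into y and returns nothing
function toy!(y::Vector{Float64}, x::Vector{Float64})
    y[1] = x[1]^2 + 3 * x[2]
    return nothing
end

x  = [2.0, 5.0]
dx = zeros(2)   # shadow of the input: starts at zero, accumulates the gradient
y  = zeros(1)
dy = ones(1)    # shadow of the output: the reverse-mode seed

Enzyme.autodiff(Enzyme.Reverse, toy!, Enzyme.Duplicated(y, dy), Enzyme.Duplicated(x, dx))

# if our reading of the seeding is correct, dx should now be [2*x[1], 3] = [4.0, 3.0]
println(dx)
```

We would like to confirm whether this is the right mental model, and whether the work spent on the `logLs` shadow can be avoided in the CUDA case.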