Understanding and optimizing Enzyme.jl Reverse AD on CUDA

Hi all,
we are working on a nested sampling algorithm whose proposals are assisted by HMC, and we want to leverage Enzyme.jl to autodiff our log-likelihood kernel.

We have an MWE of a simple 4D Gaussian likelihood (4 parameters, 2D data).
We cannot make it simpler, since we want the log-likelihood to run batched over multiple 4D points (to retain GPU efficiency).

Our current implementation is the following:

using CUDA
using Enzyme

function log_likelihood_kernel!(logLs::CuDeviceVector{Float32},
                                thetas::CuDeviceMatrix{Float32},
                                x::CuDeviceMatrix{Float32},
                                N::Int,
                                batch_size::Int)
    # compute the global thread index (1D grid)
    idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # only compute if thread corresponds to a sample in the batch
    if idx <= batch_size
        # extract parameters for this thread/sample
        mu1     = thetas[1, idx]
        mu2     = thetas[2, idx]
        sigma_1 = thetas[3, idx]
        sigma_2 = thetas[4, idx]

        # compute inverses of standard deviations once to avoid repeated division
        inv_sigma1 = 1f0 / (sigma_1)
        inv_sigma2 = 1f0 / (sigma_2)

        # accumulate quadratic term of Gaussian log-likelihood
        quad = 0.0f0
        @inbounds for i in 1:N
            # compute standardized residuals for both dimensions
            d1 = (x[1, i] - mu1) * inv_sigma1
            d2 = (x[2, i] - mu2) * inv_sigma2
            quad += d1*d1 + d2*d2
        end

        # write the log-likelihood for this thread/sample
        # formula: -0.5 * quad - N*log(sigma1*sigma2) - N*log(2π)
        logLs[idx] = -0.5f0 * quad -
                     N * (log(sigma_1) + log(sigma_2)) -
                     N * log(2f0 * Float32(π))
    end
    return
end

function compute_logLs!(logLs::CuArray{Float32,1},
                        thetas::CuArray{Float32,2},
                        x::CuArray{Float32,2})
    N = size(x, 2)
    batch_size = size(thetas, 2)
    threads = 256
    blocks = cld(batch_size, threads)
    @cuda threads=threads blocks=blocks log_likelihood_kernel!(logLs, thetas, x, N, batch_size)
    return nothing  # mutating style: result is in `logLs`
end

function per_sample_gradients(thetas::CuArray{Float32,2}, x::CuArray{Float32,2})
    batch_size = size(thetas, 2)
    logLs = CUDA.zeros(Float32, batch_size)

    # allocate gradients
    dlogLs = CUDA.ones(Float32, batch_size)
    dthetas = CUDA.zeros(Float32, size(thetas))

    # differentiate compute_logLs! w.r.t. thetas
    Enzyme.autodiff(
        Enzyme.Reverse,
        compute_logLs!,                                   # function to differentiate
        Enzyme.Duplicated(logLs, dlogLs),                 # we need to differentiate only wrt thetas
        Enzyme.Duplicated(thetas, dthetas),               # input and gradient buffer
        Enzyme.Duplicated(x, CUDA.zeros(eltype(x), size(x))) # x is constant
    )

    return dthetas  # this is 4×batch_size, same shape as thetas
end


### Example usage on GPU
# 3 samples, 4 parameters each
thetas = cu([0.5f0 0.6f0 0.7f0;
             0.5f0 0.4f0 0.3f0;
             0.1f0 0.2f0 0.15f0;
             0.1f0 0.2f0 0.25f0])

# mock data
x = cu([0.0f0 0.5f0 0.3f0 0.4f0 0.5f0;
        0.6f0 0.7f0 0.8f0 0.9f0 1.0f0])

# compute gradients
dthetas = per_sample_gradients(thetas, x)
println(Array(dthetas))

### Cross-check with analytical gradient on CPU
# First row of x as a regular Array
x1 = Array(x[1, :])
x2 = Array(x[2, :])
cputhetas = Array(thetas)
for i in 1:size(cputhetas, 2)
    mu1 = cputhetas[1, i]
    mu2 = cputhetas[2, i]
    sigma1 = cputhetas[3, i]
    sigma2 = cputhetas[4, i]
    N = size(x, 2)
    # compute analytical gradients
    dmu1 = sum((x1 .- mu1) ./ (sigma1^2))
    dmu2 = sum((x2 .- mu2) ./ (sigma2^2))
    dsigma1 = sum(((x1 .- mu1).^2) ./ (sigma1^3)) - N / sigma1
    dsigma2 = sum(((x2 .- mu2).^2) ./ (sigma2^3)) - N / sigma2
    println("Analytical gradients for sample $i:")
    println("dmu1: $dmu1, dmu2: $dmu2, dsigma1: $dsigma1, dsigma2: $dsigma2")
end
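
For completeness, the analytical gradients in this cross-check follow directly from the log-likelihood computed in the kernel:

$$
\log L = -\tfrac{1}{2}\sum_{i=1}^{N}\left[\left(\frac{x_{1,i}-\mu_1}{\sigma_1}\right)^2 + \left(\frac{x_{2,i}-\mu_2}{\sigma_2}\right)^2\right] - N\left(\log\sigma_1 + \log\sigma_2\right) - N\log(2\pi),
$$

$$
\frac{\partial \log L}{\partial \mu_1} = \sum_{i=1}^{N}\frac{x_{1,i}-\mu_1}{\sigma_1^2}, \qquad
\frac{\partial \log L}{\partial \sigma_1} = \sum_{i=1}^{N}\frac{(x_{1,i}-\mu_1)^2}{\sigma_1^3} - \frac{N}{\sigma_1},
$$

and analogously for $\mu_2$ and $\sigma_2$.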

Analytical and autodiff’ed results match, so we trust our implementation, though we don’t fully understand the signature of

Enzyme.autodiff(
    Enzyme.Reverse,
    compute_logLs!,                                       # function to differentiate
    Enzyme.Duplicated(logLs, dlogLs),                     # we need to differentiate only wrt thetas
    Enzyme.Duplicated(thetas, dthetas),                   # input and gradient buffer
    Enzyme.Duplicated(x, CUDA.zeros(eltype(x), size(x)))  # x is constant
)

In fact:

  1. We believe we are wasting resources by computing derivatives of compute_logLs! with respect to logLs.
  2. Inspecting the documentation for reverse-mode AD, we haven’t found another combination of activity annotations that raises no errors on CUDA kernels, and Enzyme.Const(logLs) doesn’t seem to play well with CUDA.
  3. The shadows we allocate are

dlogLs = CUDA.ones(Float32, batch_size)
dthetas = CUDA.zeros(Float32, size(thetas))

although we would have expected reverse mode to yield the desired derivatives with the opposite allocations, i.e. zeros for dlogLs and ones for dthetas (see the CPU sketch after this list).
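
To make the seeding question concrete, here is a minimal CPU analogue of the same call pattern (a rough sketch only; square!, y, dy, x, dx are illustrative names, not part of our actual code):

using Enzyme

# mutating "kernel": writes the primal result into y and returns nothing
function square!(y, x)
    y[1] = x[1]^2
    return nothing
end

x  = [3.0]
dx = zeros(1)   # input shadow: gradient accumulator, must start at zero
y  = zeros(1)
dy = ones(1)    # output shadow: the seed d(loss)/d(y), consumed by the reverse pass

Enzyme.autodiff(Enzyme.Reverse, square!,
                Enzyme.Duplicated(y, dy),
                Enzyme.Duplicated(x, dx))

dx  # == [6.0] == 2 * x[1]: the ones go into the output shadow, the zeros into the input shadow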

We would appreciate some guidance from more experienced Enzyme.jl and CUDA.jl users.

I think you could write it as

function per_sample_gradients(thetas::CuArray{Float32,2}, x::CuArray{Float32,2})
    batch_size = size(thetas, 2)
    logLs = CUDA.zeros(Float32, batch_size)

    # allocate shadows
    dlogLs = CUDA.ones(Float32, batch_size)
    dthetas = CUDA.zeros(Float32, size(thetas))

    # differentiate compute_logLs! w.r.t. thetas
    Enzyme.autodiff(Enzyme.Reverse, compute_logLs!,          # function to differentiate
                    Enzyme.DuplicatedNoNeed(logLs, dlogLs),  # primal logLs not needed, only the seed
                    Enzyme.Duplicated(thetas, dthetas),      # input and gradient buffer
                    Enzyme.Const(x))                         # x is constant
    return dthetas  # this is 4×batch_size, same shape as thetas
end
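
(As I understand it, DuplicatedNoNeed tells Enzyme that the primal values written into logLs are not needed after the call, so it may avoid computing them and only use the seed stored in dlogLs.)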

but it seems like there is a bug and we still need a shadow for x, so for now:

function per_sample_gradients(thetas::CuArray{Float32,2}, x::CuArray{Float32,2})
    batch_size = size(thetas, 2)
    logLs = CUDA.zeros(Float32, batch_size)

    # allocate shadows
    dlogLs = CUDA.ones(Float32, batch_size)
    dthetas = CUDA.zeros(Float32, size(thetas))

    # differentiate compute_logLs! w.r.t. thetas
    Enzyme.autodiff(Enzyme.Reverse, compute_logLs!,                              # function to differentiate
                    Enzyme.DuplicatedNoNeed(logLs, dlogLs),                      # primal logLs not needed, only the seed
                    Enzyme.Duplicated(thetas, dthetas),                          # input and gradient buffer
                    Enzyme.DuplicatedNoNeed(x, CUDA.zeros(eltype(x), size(x))))  # x is constant, but currently needs a shadow
    return dthetas  # this is 4×batch_size, same shape as thetas
end

It does work with KA (KernelAbstractions), and it's also nice that it's backend agnostic:

using CUDA
using Enzyme
using KernelAbstractions
import KernelAbstractions.Extras: @unroll

@kernel inbounds=true unsafe_indices=true function log_likelihood_kernel!(logLs,thetas,@Const(x),@Const(N),@Const(batch_size))
    idx = @index(Global)
    if idx <= batch_size
        mu1    = thetas[1, idx]
        mu2    = thetas[2, idx]
        sigma_1 = thetas[3, idx]
        sigma_2 = thetas[4, idx]
        inv_sigma1 = 1f0 / (sigma_1)
        inv_sigma2 = 1f0 / (sigma_2)
        quad = 0.0f0
        @unroll for i in 1:N
            d1 = (x[1, i] - mu1) * inv_sigma1
            d2 = (x[2, i] - mu2) * inv_sigma2
            quad += d1*d1 + d2*d2
        end
        logLs[idx] = -0.5f0 * quad - N * (log(sigma_1) + log(sigma_2)) - N * log(2f0 *  Float32(pi))
    end
end

function compute_logLs!(logLs,thetas,x)
    N = size(x, 2)
    batch_size = size(thetas, 2)
    threads = 256
    log_likelihood_kernel!(get_backend(x),threads)(logLs, thetas, x, N, batch_size;ndrange=batch_size)
    return nothing  
end

function per_sample_gradients(thetas, x)
    batch_size = size(thetas, 2)
    logLs = KernelAbstractions.zeros(get_backend(thetas),Float32,batch_size)
    dlogLs = KernelAbstractions.ones(get_backend(thetas),Float32,batch_size)
    dthetas = Enzyme.make_zero(thetas)
    Enzyme.autodiff(
        Enzyme.Reverse,
        compute_logLs!,                               
        Enzyme.DuplicatedNoNeed(logLs, dlogLs),               
        Enzyme.Duplicated(thetas, dthetas),             
        Enzyme.Const(x) 
    )
    return dthetas 
end

thetas = cu([0.5f0 0.6f0 0.7f0;
             0.5f0 0.4f0 0.3f0;
             0.1f0 0.2f0 0.15f0;
             0.1f0 0.2f0 0.25f0])

# mock data
x = cu([0.0f0 0.5f0 0.3f0 0.4f0 0.5f0;
        0.6f0 0.7f0 0.8f0 0.9f0 1.0f0])
# compute gradients
dthetas = per_sample_gradients(thetas, x);
dthetas
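
Since get_backend dispatches on the array type, the same functions should in principle also run on the CPU backend with plain Arrays. A rough, untested sketch reusing the thetas and x defined above:

# hypothetical CPU run of the same backend-agnostic code
thetas_cpu = Array(thetas)
x_cpu = Array(x)
logLs_cpu = KernelAbstractions.zeros(get_backend(thetas_cpu), Float32, size(thetas_cpu, 2))
compute_logLs!(logLs_cpu, thetas_cpu, x_cpu)           # primal pass on the CPU backend
dthetas_cpu = per_sample_gradients(thetas_cpu, x_cpu)  # same gradient call as on the GPU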

PS: you still need dlogLs; it is the seed from which the derivatives are propagated.


What is the error that arises if you try to use Enzyme.Const?

A weird one:

ERROR: type Const has no field dval
Stacktrace:
  [1] getproperty
    @ .\Base.jl:49 [inlined]
  [2] augmented_primal
    @ C:\Users\yolha\.julia\packages\CUDA\G7Cnt\ext\EnzymeCoreExt.jl:99 [inlined]
  [3] map
    @ .\tuple.jl:357 [inlined]
  [4] map (repeats 2 times)
    @ .\tuple.jl:358 [inlined]
  [5] macro expansion
    @ C:\Users\yolha\.julia\packages\CUDA\G7Cnt\src\compiler\execution.jl:110 [inlined]
  [6] compute_logLs!
    @ c:\Users\yolha\Desktop\juju_tests\Nouveau dossier\main2.jl:46 [inlined]
  [7] diffejulia_compute_logLs__28677wrap
    @ c:\Users\yolha\Desktop\juju_tests\Nouveau dossier\main2.jl:0
  [8] macro expansion
    @ C:\Users\yolha\.julia\packages\Enzyme\LJjsP\src\compiler.jl:5873 [inlined]
  [9] enzyme_call
    @ C:\Users\yolha\.julia\packages\Enzyme\LJjsP\src\compiler.jl:5407 [inlined]
 [10] CombinedAdjointThunk
    @ C:\Users\yolha\.julia\packages\Enzyme\LJjsP\src\compiler.jl:5293 [inlined]
 [11] autodiff
    @ C:\Users\yolha\.julia\packages\Enzyme\LJjsP\src\Enzyme.jl:521 [inlined]
 [12] autodiff
    @ C:\Users\yolha\.julia\packages\Enzyme\LJjsP\src\Enzyme.jl:562 [inlined]
 [13] autodiff
    @ C:\Users\yolha\.julia\packages\Enzyme\LJjsP\src\Enzyme.jl:534 [inlined]
 [14] per_sample_gradients(thetas::CuArray{Float32, 2, CUDA.DeviceMemory}, x::CuArray{Float32, 2, CUDA.DeviceMemory})
    @ Main c:\Users\yolha\Desktop\juju_tests\Nouveau dossier\main2.jl:59
 [15] top-level scope
    @ c:\Users\yolha\Desktop\juju_tests\Nouveau dossier\main2.jl:83

MWE:

using CUDA
using Enzyme

function log_likelihood_kernel!(logLs::CuDeviceVector{Float32},
                                thetas::CuDeviceMatrix{Float32},
                                x::CuDeviceMatrix{Float32},
                                N::Int,
                                batch_size::Int)
    # compute the global thread index (1D grid)
    idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # only compute if thread corresponds to a sample in the batch
    if idx <= batch_size
        # extract parameters for this thread/sample
        mu1     = thetas[1, idx]
        mu2     = thetas[2, idx]
        sigma_1 = thetas[3, idx]
        sigma_2 = thetas[4, idx]

        # compute inverses of standard deviations once to avoid repeated division
        inv_sigma1 = 1f0 / (sigma_1)
        inv_sigma2 = 1f0 / (sigma_2)

        # accumulate quadratic term of Gaussian log-likelihood
        quad = 0.0f0
        @inbounds for i in 1:N
            # compute standardized residuals for both dimensions
            d1 = (x[1, i] - mu1) * inv_sigma1
            d2 = (x[2, i] - mu2) * inv_sigma2
            quad += d1*d1 + d2*d2
        end

        # write the log-likelihood for this thread/sample
        # formula: -0.5 * quad - N*log(sigma1*sigma2) - N*log(2π)
        logLs[idx] = -0.5f0 * quad - N * (log(sigma_1) + log(sigma_2)) - N * log(2f0 * Float32(pi))
    end
    return
end

function compute_logLs!(logLs::CuArray{Float32,1},
                        thetas::CuArray{Float32,2},
                        x::CuArray{Float32,2})
    N = size(x, 2)
    batch_size = size(thetas, 2)
    threads = 256
    blocks = cld(batch_size, threads)
    @cuda threads=threads blocks=blocks log_likelihood_kernel!(logLs, thetas, x, N, batch_size)
    return nothing  # mutating style: result is in `logLs`
end

function per_sample_gradients(thetas::CuArray{Float32,2}, x::CuArray{Float32,2})
    batch_size = size(thetas, 2)
    logLs = CUDA.zeros(Float32, batch_size)

    # allocate gradients
    dlogLs = CUDA.ones(Float32, batch_size)
    dthetas = CUDA.zeros(Float32, size(thetas))

    # differentiate compute_logLs! w.r.t. thetas
    Enzyme.autodiff(
        Enzyme.Reverse,
        compute_logLs!,                                   # function to differentiate
        Enzyme.DuplicatedNoNeed(logLs, dlogLs),                 # we need to differentiate only wrt thetas
        Enzyme.Duplicated(thetas, dthetas),               # input and gradient buffer
        Enzyme.Const(x) # x is constant
    )

    return dthetas  # this is 4×batch_size, same shape as thetas
end


### Example usage on GPU
# 3 samples, 4 parameters each
thetas = cu([0.5f0 0.6f0 0.7f0;
             0.5f0 0.4f0 0.3f0;
             0.1f0 0.2f0 0.15f0;
             0.1f0 0.2f0 0.25f0])

# mock data
x = cu([0.0f0 0.5f0 0.3f0 0.4f0 0.5f0;
        0.6f0 0.7f0 0.8f0 0.9f0 1.0f0])

# compute gradients
dthetas = per_sample_gradients(thetas, x)