What is the correct way to autodiff a simple CUDA kernel using Reactant.jl and Enzyme.jl?

Thanks! That works for me.
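
For context, the simple case I have working follows roughly the pattern below (a simplified sketch: scale_kernel!, loss!, the array sizes, and the constant are purely illustrative, not part of my model):

using Reactant, Enzyme, CUDA

# Illustrative kernel: y[i] = a * x[i]
function scale_kernel!(y, x, a)
  i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
  @inbounds if i <= length(y)
    y[i] = a * x[i]
  end
  return nothing
end

function loss!(y, x, a)
  @cuda blocks = cld(length(y), 256) threads = 256 scale_kernel!(y, x, a)
  return sum(y) / length(y)
end

x  = Reactant.to_rarray(rand(Float32, 1024))
y  = Reactant.to_rarray(zeros(Float32, 1024))
∂x = Enzyme.make_zero(x)
∂y = Reactant.to_rarray(ones(Float32, 1024))  # seed on the shadow of the output array

# Forward pass:
@jit loss!(y, x, 2.0f0)

# Reverse pass through the kernel launch, asking Reactant to raise the kernel:
@jit raise=true Enzyme.autodiff(Reverse, loss!, Const,
  Duplicated(y, ∂y),
  Duplicated(x, ∂x),
  Const(2.0f0),
)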

But I am not sure whether there are limitations or instabilities when autodiffing raised code. I am developing a new atmosphere model in Julia, and when I try to autodiff the raised code for a divergence operator in my model, it fails with:

ERROR: LoadError: AssertionError: Invalid option raise
Stacktrace:
 [1] compile_call_expr(::Module, ::Function, ::Dict{Symbol, Any}, ::Expr, ::Vararg{Expr})
   @ Reactant.Compiler ~/.julia/packages/Reactant/7y9bj/src/Compiler.jl:673
 [2] var"@jit"(__source__::LineNumberNode, __module__::Module, args::Vararg{Any})
   @ Reactant.Compiler ~/.julia/packages/Reactant/7y9bj/src/Compiler.jl:653

This is part of my script:

using Reactant, Enzyme, CUDA, Test
import KernelAbstractions as KA

# Excerpt: `block` and the AT/INT/FLT type aliases are defined elsewhere in my code.
Reactant.set_default_backend("gpu")

h_e = Reactant.to_rarray(Array(block.diag.h_e.d))
h_pc = Reactant.to_rarray(Array(block.dstate[1].h_pc.d))
un_e = Reactant.to_rarray(Array(block.dstate[1].un_e.d))
tend_h_pc = Reactant.to_rarray(Array(block.dtend.h_pc.d))
plg_nnb = Reactant.to_rarray(Array(block.domain.full.plg_nnb))
plg_ed = Reactant.to_rarray(Array(block.domain.full.plg_ed))
plg_nr = Reactant.to_rarray(Array(block.domain.full.plg_nr))
edp_leng = Reactant.to_rarray(Array(block.domain.full.edp_leng))
plg_area = Reactant.to_rarray(Array(block.domain.full.plg_area))
r_earth = Reactant.ConcreteRNumber(block.consts.r_earth)
chunksizeval = Reactant.ConcreteRNumber(block.graph.ctx.chunksize)
nv_cmpt = Reactant.ConcreteRNumber(block.domain.full.nv_cmpt)

@inline function div_at_prime_cell!(
  iv               ::INT,
  div              ::AT{FLT, 1},
  normal_u_at_edge ::AT{FLT, 1},
  𝜘_at_edge        ::AT{FLT, 1},
  plg_nnb          ::AT{INT, 1},
  plg_ed           ::AT{INT, 2},
  plg_nr           ::AT{INT, 2},
  edp_leng         ::AT{FLT, 1},
  plg_area         ::AT{FLT, 1},
  r_earth          ::FLT,
) where {INT, FLT}
  local_sum = FLT(0.0)
  for inb = 1:plg_nnb[iv]
    ed_idx = plg_ed[iv, inb]
    nei = plg_nr[iv, inb]
    le  = edp_leng[ed_idx]
    local_sum += nei*𝜘_at_edge[ed_idx]*normal_u_at_edge[ed_idx]*le
  end
  div[iv] = local_sum / (-r_earth*plg_area[iv])
end

function calc_div_at_prime_cell_cuda!(
  div              ::AT{FLT, 1},
  normal_u_at_edge ::AT{FLT, 1},
  𝜘_at_edge        ::AT{FLT, 1},
  plg_nnb          ::AT{INT, 1},
  plg_ed           ::AT{INT, 2},
  plg_nr           ::AT{INT, 2},
  edp_leng         ::AT{FLT, 1},
  plg_area         ::AT{FLT, 1},
  r_earth          ::FLT,
  nv_cmpt          ::Integer
) where {INT, FLT}
  iv = (blockIdx().x - 1) * blockDim().x + threadIdx().x
  @inbounds if iv <= nv_cmpt
    div_at_prime_cell!(
      iv, div, normal_u_at_edge, 𝜘_at_edge,
      plg_nnb, plg_ed, plg_nr,
      edp_leng, plg_area, r_earth
    )
  end
  return nothing
end

function ∂h∂t_kernel(
  h_e,
  un_e,
  tend_h_pc, 
  plg_nnb, 
  plg_ed, 
  plg_nr, 
  edp_leng, 
  plg_area, 
  r_earth, 
  chunksizeval,
  nv_cmpt
)
  dev = KA.get_backend(h_e)
  # calc_div_at_prime_cell! is the KernelAbstractions @kernel version of the operator (defined elsewhere in my code)
  kernel! = calc_div_at_prime_cell!(dev, (64, 16))
  kernel!(
    tend_h_pc, un_e, h_e, 
    plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, 
    r_earth, 
    chunksizeval,
    nv_cmpt,
    ndrange=(cld(nv_cmpt, chunksizeval), chunksizeval)
  )
  KA.synchronize(dev)
  return sum(tend_h_pc)/length(tend_h_pc)
end

function ∂h∂t_cuda(
  h_e,
  un_e,
  tend_h_pc, 
  plg_nnb, 
  plg_ed, 
  plg_nr, 
  edp_leng, 
  plg_area, 
  r_earth, 
  nv_cmpt
)
  @cuda blocks = 16 threads = cld(length(h_e), 16) calc_div_at_prime_cell_cuda!(
    tend_h_pc, un_e, h_e, 
    plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, 
    r_earth, 
    nv_cmpt
  )
  return sum(tend_h_pc)/length(tend_h_pc)
end

@testset "Enzyme + KernelAbstractions" begin
  ∂z_∂u = Enzyme.make_zero(un_e)
  ∂z_∂he = Enzyme.make_zero(h_e)
  ∂z_∂tend = Reactant.to_rarray(rand(Base.eltype(tend_h_pc), size(tend_h_pc)...))
  ∂z_∂plg_ed = Enzyme.make_zero(plg_ed)
  ∂z_∂plg_nr = Enzyme.make_zero(plg_nr)
  ∂z_∂edp_leng = Enzyme.make_zero(edp_leng)
  ∂z_∂plg_area = Enzyme.make_zero(plg_area)

  @jit ∂h∂t_kernel(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, r_earth, chunksizeval, nv_cmpt)

  @jit raise=true raise_first=true Enzyme.autodiff(Reverse, ∂h∂t_kernel, Const,
    Duplicated(h_e, ∂z_∂he),
    Duplicated(un_e, ∂z_∂u),
    Duplicated(tend_h_pc, ∂z_∂tend),
    Const(plg_nnb),
    Duplicated(plg_ed, ∂z_∂plg_ed),
    Duplicated(plg_nr, ∂z_∂plg_nr),
    Duplicated(edp_leng, ∂z_∂edp_leng),
    Duplicated(plg_area, ∂z_∂plg_area),
    Const(r_earth),
    Const(chunksizeval),
    Const(nv_cmpt)
  )
  @test !all(iszero, Array(∂z_∂u))   # the gradient w.r.t. un_e should not be all zeros
  @test !all(iszero, Array(∂z_∂he))  # the gradient w.r.t. h_e should not be all zeros
  @show ∂z_∂u
end

@testset "Enzyme + CUDA" begin
  ∂z_∂u = Enzyme.make_zero(un_e)
  ∂z_∂he = Enzyme.make_zero(h_e)
  ∂z_∂tend = Reactant.to_rarray(rand(Base.eltype(tend_h_pc), size(tend_h_pc)...))
  ∂z_∂plg_ed = Enzyme.make_zero(plg_ed)
  ∂z_∂plg_nr = Enzyme.make_zero(plg_nr)
  ∂z_∂edp_leng = Enzyme.make_zero(edp_leng)
  ∂z_∂plg_area = Enzyme.make_zero(plg_area)
  
  @jit ∂h∂t_cuda(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, r_earth, nv_cmpt)

  @jit raise=true raise_first=true Enzyme.autodiff(Reverse, ∂h∂t_cuda, Const,
    Duplicated(h_e, ∂z_∂he),
    Duplicated(un_e, ∂z_∂u),
    Duplicated(tend_h_pc, ∂z_∂tend),
    Const(plg_nnb),
    Duplicated(plg_ed, ∂z_∂plg_ed),
    Duplicated(plg_nr, ∂z_∂plg_nr),
    Duplicated(edp_leng, ∂z_∂edp_leng),
    Duplicated(plg_area, ∂z_∂plg_area),
    Const(r_earth),
    Const(nv_cmpt),
  )
  @test !all(iszero, Array(∂z_∂u))   # the gradient w.r.t. un_e should not be all zeros
  @test !all(iszero, Array(∂z_∂he))  # the gradient w.r.t. h_e should not be all zeros
  @show ∂z_∂u
end