Thanks! That works for me.
But I’m not sure whether there are limitations or instabilities when trying to autodiff the raised code. I am developing a new atmosphere model in Julia. However, when I try to autodiff the raised code for a divergence operator in my model, it fails with:
ERROR: LoadError: AssertionError: Invalid option raise
Stacktrace:
[1] compile_call_expr(::Module, ::Function, ::Dict{Symbol, Any}, ::Expr, ::Vararg{Expr})
@ Reactant.Compiler ~/.julia/packages/Reactant/7y9bj/src/Compiler.jl:673
[2] var"@jit"(__source__::LineNumberNode, __module__::Module, args::Vararg{Any})
@ Reactant.Compiler ~/.julia/packages/Reactant/7y9bj/src/Compiler.jl:653
This is part of my script:
# Select "gpu" as Reactant's default backend before any array is created.
Reactant.set_default_backend("gpu")

# Fields whose data lives in a `.d` member: materialize as `Array`, then hand to Reactant.
h_e, h_pc, un_e, tend_h_pc = map(
    field -> Reactant.to_rarray(Array(field.d)),
    (block.diag.h_e, block.dstate[1].h_pc, block.dstate[1].un_e, block.dtend.h_pc),
)

# Arrays read directly from `block.domain.full`, converted the same way.
plg_nnb, plg_ed, plg_nr, edp_leng, plg_area = map(
    name -> Reactant.to_rarray(Array(getproperty(block.domain.full, name))),
    (:plg_nnb, :plg_ed, :plg_nr, :edp_leng, :plg_area),
)

# Scalars are wrapped as Reactant numbers so they can be passed to `@jit`-compiled calls.
r_earth      = Reactant.ConcreteRNumber(block.consts.r_earth)
chunksizeval = Reactant.ConcreteRNumber(block.graph.ctx.chunksize)
nv_cmpt      = Reactant.ConcreteRNumber(block.domain.full.nv_cmpt)
"""
    div_at_prime_cell!(iv, div, normal_u_at_edge, 𝜘_at_edge, plg_nnb, plg_ed, plg_nr,
                       edp_leng, plg_area, r_earth)

Compute one entry of `div`.

For `k = 1:plg_nnb[iv]`, with `e = plg_ed[iv, k]`, accumulate
`plg_nr[iv, k] * 𝜘_at_edge[e] * normal_u_at_edge[e] * edp_leng[e]`, then store the sum
divided by `-r_earth * plg_area[iv]` in `div[iv]`.

Only `div[iv]` is written. The stored quotient is also the return value.
"""
@inline function div_at_prime_cell!(
    iv               ::INT,
    div              ::AT{FLT, 1},
    normal_u_at_edge ::AT{FLT, 1},
    𝜘_at_edge        ::AT{FLT, 1},
    plg_nnb          ::AT{INT, 1},
    plg_ed           ::AT{INT, 2},
    plg_nr           ::AT{INT, 2},
    edp_leng         ::AT{FLT, 1},
    plg_area         ::AT{FLT, 1},
    r_earth          ::FLT,
) where {INT, FLT}
    acc = zero(FLT)
    for k in 1:plg_nnb[iv]
        ed = plg_ed[iv, k]      # index into the per-edge arrays
        coef = plg_nr[iv, k]    # integer factor applied to this edge's contribution
        len = edp_leng[ed]
        acc += coef*𝜘_at_edge[ed]*normal_u_at_edge[ed]*len
    end
    return div[iv] = acc / (-r_earth*plg_area[iv])
end
"""
    calc_div_at_prime_cell_cuda!(div, normal_u_at_edge, 𝜘_at_edge, plg_nnb, plg_ed,
                                 plg_nr, edp_leng, plg_area, r_earth, nv_cmpt)

CUDA kernel body: each thread derives a 1-based global index from its block and thread
indices and, when that index does not exceed `nv_cmpt`, calls `div_at_prime_cell!` for
it. Threads with a larger index do nothing. Always returns `nothing`.
"""
function calc_div_at_prime_cell_cuda!(
    div              ::AT{FLT, 1},
    normal_u_at_edge ::AT{FLT, 1},
    𝜘_at_edge        ::AT{FLT, 1},
    plg_nnb          ::AT{INT, 1},
    plg_ed           ::AT{INT, 2},
    plg_nr           ::AT{INT, 2},
    edp_leng         ::AT{FLT, 1},
    plg_area         ::AT{FLT, 1},
    r_earth          ::FLT,
    nv_cmpt          ::Integer
) where {INT, FLT}
    # Offset of this block's first thread within the whole launch.
    block_offset = (blockIdx().x - 1) * blockDim().x
    cell = block_offset + threadIdx().x

    # The launch may contain more threads than entries; surplus threads fall through.
    if cell <= nv_cmpt
        @inbounds div_at_prime_cell!(
            cell, div, normal_u_at_edge, 𝜘_at_edge,
            plg_nnb, plg_ed, plg_nr,
            edp_leng, plg_area, r_earth
        )
    end
    return nothing
end
"""
    ∂h∂t_kernel(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area,
                r_earth, chunksizeval, nv_cmpt)

Run the KernelAbstractions kernel `calc_div_at_prime_cell!` (defined outside this
excerpt) on the backend of `h_e`, with `tend_h_pc` as its first argument, wait for it to
finish, and return `sum(tend_h_pc) / length(tend_h_pc)`.

The kernel is instantiated with workgroup size `(64, 16)` and launched over
`ndrange = (cld(nv_cmpt, chunksizeval), chunksizeval)`.
"""
function ∂h∂t_kernel(
    h_e,
    un_e,
    tend_h_pc,
    plg_nnb,
    plg_ed,
    plg_nr,
    edp_leng,
    plg_area,
    r_earth,
    chunksizeval,
    nv_cmpt
)
    backend = KA.get_backend(h_e)
    div_kernel! = calc_div_at_prime_cell!(backend, (64, 16))

    launch_range = (cld(nv_cmpt, chunksizeval), chunksizeval)
    div_kernel!(
        tend_h_pc, un_e, h_e,
        plg_nnb, plg_ed, plg_nr, edp_leng, plg_area,
        r_earth,
        chunksizeval,
        nv_cmpt;
        ndrange=launch_range
    )

    # Block until the kernel has finished before reducing over its output.
    KA.synchronize(backend)
    return sum(tend_h_pc)/length(tend_h_pc)
end
"""
    ∂h∂t_cuda(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area,
              r_earth, nv_cmpt)

Launch `calc_div_at_prime_cell_cuda!` with `@cuda` and return
`sum(tend_h_pc) / length(tend_h_pc)`.

`tend_h_pc` is passed as the kernel's output (`div`), `un_e` as `normal_u_at_edge` and
`h_e` as `𝜘_at_edge`.

The launch uses 16 blocks of `cld(length(h_e), 16)` threads. Rounding up keeps the
thread count an integer when `length(h_e)` is not a multiple of 16; an exact
`Int64(length(h_e) / 16)` conversion throws `InexactError` in that case. For lengths that
are multiples of 16 the launch configuration is unchanged. Any surplus threads rely on
the kernel checking its own index against `nv_cmpt`.
"""
function ∂h∂t_cuda(
    h_e,
    un_e,
    tend_h_pc,
    plg_nnb,
    plg_ed,
    plg_nr,
    edp_leng,
    plg_area,
    r_earth,
    nv_cmpt
)
    # NOTE(review): threads-per-block grows with length(h_e); presumably large inputs
    # need more blocks rather than more threads per block — confirm device limits.
    @cuda blocks = 16 threads = cld(length(h_e), 16) calc_div_at_prime_cell_cuda!(
        tend_h_pc, un_e, h_e,
        plg_nnb, plg_ed, plg_nr, edp_leng, plg_area,
        r_earth,
        nv_cmpt
    )
    return sum(tend_h_pc)/length(tend_h_pc)
end
# Reverse-mode Enzyme through the KernelAbstractions path, compiled with Reactant.
@testset "Enzyme + KernelAbstractions" begin
    # Shadow (derivative) buffers. All start at zero except the shadow of `tend_h_pc`,
    # which is filled with random values.
    ∂z_∂u = Enzyme.make_zero(un_e)
    ∂z_∂he = Enzyme.make_zero(h_e)
    ∂z_∂tend = Reactant.to_rarray(rand(Base.eltype(tend_h_pc), size(tend_h_pc)...))
    ∂z_∂plg_ed = Enzyme.make_zero(plg_ed)
    ∂z_∂plg_nr = Enzyme.make_zero(plg_nr)
    ∂z_∂edp_leng = Enzyme.make_zero(edp_leng)
    ∂z_∂plg_area = Enzyme.make_zero(plg_area)

    # Forward run first, so a failure here is not confused with an autodiff failure.
    @jit ∂h∂t_kernel(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, r_earth, chunksizeval, nv_cmpt)

    # NOTE(review): the reported `AssertionError: Invalid option raise` originates in the
    # `@jit` macro itself (`var"@jit"` -> `compile_call_expr` in the stack trace), so it
    # fires before this call runs. Presumably the installed Reactant release does not
    # accept the `raise` option — confirm against the installed version.
    @jit raise=true raise_first=true Enzyme.autodiff(Reverse, ∂h∂t_kernel, Const,
        Duplicated(h_e, ∂z_∂he),
        Duplicated(un_e, ∂z_∂u),
        Duplicated(tend_h_pc, ∂z_∂tend),
        Const(plg_nnb),
        Duplicated(plg_ed, ∂z_∂plg_ed),
        Duplicated(plg_nr, ∂z_∂plg_nr),
        Duplicated(edp_leng, ∂z_∂edp_leng),
        Duplicated(plg_area, ∂z_∂plg_area),
        Const(r_earth),
        Const(chunksizeval),
        Const(nv_cmpt)
    )

    # The gradients must not be identically zero. The previous form,
    # `@test all(iszero, x) && error(...)`, could never pass: it threw when the gradient
    # was all zero and evaluated to `false` (a failed test) otherwise.
    @test !all(iszero, Array(∂z_∂u))
    @test !all(iszero, Array(∂z_∂he))
    @show ∂z_∂u
end
# Reverse-mode Enzyme through the raw `@cuda` path, compiled with Reactant.
@testset "Enzyme + CUDA" begin
    # Shadow (derivative) buffers. All start at zero except the shadow of `tend_h_pc`,
    # which is filled with random values.
    ∂z_∂u = Enzyme.make_zero(un_e)
    ∂z_∂he = Enzyme.make_zero(h_e)
    ∂z_∂tend = Reactant.to_rarray(rand(Base.eltype(tend_h_pc), size(tend_h_pc)...))
    ∂z_∂plg_ed = Enzyme.make_zero(plg_ed)
    ∂z_∂plg_nr = Enzyme.make_zero(plg_nr)
    ∂z_∂edp_leng = Enzyme.make_zero(edp_leng)
    ∂z_∂plg_area = Enzyme.make_zero(plg_area)

    # Forward run first, so a failure here is not confused with an autodiff failure.
    @jit ∂h∂t_cuda(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, r_earth, nv_cmpt)

    # NOTE(review): the reported `AssertionError: Invalid option raise` originates in the
    # `@jit` macro itself (`var"@jit"` -> `compile_call_expr` in the stack trace), so it
    # fires before this call runs. Presumably the installed Reactant release does not
    # accept the `raise` option — confirm against the installed version.
    @jit raise=true raise_first=true Enzyme.autodiff(Reverse, ∂h∂t_cuda, Const,
        Duplicated(h_e, ∂z_∂he),
        Duplicated(un_e, ∂z_∂u),
        Duplicated(tend_h_pc, ∂z_∂tend),
        Const(plg_nnb),
        Duplicated(plg_ed, ∂z_∂plg_ed),
        Duplicated(plg_nr, ∂z_∂plg_nr),
        Duplicated(edp_leng, ∂z_∂edp_leng),
        Duplicated(plg_area, ∂z_∂plg_area),
        Const(r_earth),
        Const(nv_cmpt),
    )

    # The gradients must not be identically zero. The previous form,
    # `@test all(iszero, x) && error(...)`, could never pass: it threw when the gradient
    # was all zero and evaluated to `false` (a failed test) otherwise.
    # The shadows are converted with `Array(...)` before inspection, matching the
    # KernelAbstractions testset.
    @test !all(iszero, Array(∂z_∂u))
    @test !all(iszero, Array(∂z_∂he))
    @show ∂z_∂u
end