What is the correct way to autodiff a simple CUDA kernel using Reactant.jl and Enzyme.jl

I’m new to Reactant.jl and Enzyme.jl, and I want to use them to compute the gradient of a CUDA kernel. Here is my script:

using Random, Reactant, CUDA, Test, Enzyme

const ReactantCUDAExt = Base.get_extension(Reactant, :ReactantCUDAExt)

rng = Random.default_rng()
Random.seed!(rng, 0)

Reactant.set_default_backend("gpu")

@testset "Promote CuTraced" begin
  TFT = ReactantCUDAExt.CuTracedRNumber{Float64,1}
  FT = Float64
  @test Reactant.promote_traced_type(TFT, FT) == TFT
  @test Base.promote_type(TFT, FT) == FT
end

function square_kernel!(x, y)
  i = threadIdx().x
  x[i] *= y[i]
  # We don't yet auto lower this via polygeist
  # sync_threads()
  return nothing
end

function square_2kernel!(x)
  i = threadIdx().x
  x[i] = x[i]^2
  return nothing
end

# basic squaring on GPU
function square!(x, y)
  @cuda blocks = 1 threads = length(x) square_kernel!(x, y)
  return nothing
end

function sum_square_1(x)
  sum(x.^2)
end

function sum_square_2(x)
  @cuda blocks = 1 threads = length(x) square_2kernel!(x)
  sum(x)
end

# @testset "Square Kernel" begin
#   oA = collect(1:1:64)
#   A = Reactant.to_rarray(oA)
#   B = Reactant.to_rarray(100 .* oA)
#   @jit square!(A, B)
#   @test all(Array(A) .≈ (oA .* oA .* 100))
#   @test all(Array(B) .≈ (oA .* 100))
# end

@testset "Sum square kernel" begin
  oA = collect(Float32, 1:1:64)
  A = Reactant.to_rarray(oA)
  ∂f_∂A = Enzyme.make_zero(A)

  out = @jit sum_square_1(A)
  @test out ≈ sum(oA .* oA)

  out = @jit sum_square_2(A)
  @test out ≈ sum(oA .* oA)

  out1 = @compile Enzyme.gradient(Reverse, sum_square_1, A)
  out2 = @compile Enzyme.gradient(Reverse, sum_square_2, A)

  @test out1 ≈ out2
end

It fails when compiling the gradient of sum_square_2. Here is the error message:

error: could not compute the adjoint for this operation %4 = "enzymexla.kernel_call"(%3, %3, %3, %2, %3, %3, %1, %arg0) <{backend_config = "", fn = @"##call__Z15square_2kernel_13CuTracedArrayI7Float32Li1ELi1E5_64__E#384", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>], xla_side_effect_free}> : (tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<64xf32>) -> tensor<64xf32>
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_gHZpOE/module_004_qIaR_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/B1BXA/src/mlir/IR/Pass.jl:119
Sum square kernel: Error During Test at /home/ubuntu/project/julia/demo/test1.jl:57
  Got exception outside of a @test
  "failed to run pass manager on module"
  Stacktrace:
    [1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/B1BXA/src/mlir/IR/Pass.jl:163
    [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
      @ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1281
    [3] run_pass_pipeline!
      @ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1276 [inlined]
    [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, Tuple{Vararg{Symbol, var"#s1732"}} where var"#s1732", Tuple{Vararg{Int64, N}} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
      @ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1721
    [5] compile_mlir! (repeats 2 times)
      @ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1536 [inlined]
    [6] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
      @ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:3447
    [7] compile_xla
      @ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:3420 [inlined]
    [8] compile(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
      @ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:3519
    [9] macro expansion
      @ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:2600 [inlined]
   [10] macro expansion
      @ ~/project/julia/demo/test1.jl:69 [inlined]
   [11] macro expansion
      @ ~/software/julia/1.10.8/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [12] top-level scope
      @ ~/project/julia/demo/test1.jl:58
   [13] include(fname::String)
      @ Base.MainInclude ./client.jl:494
   [14] top-level scope
      @ REPL[17]:1
   [15] eval
      @ ./boot.jl:385 [inlined]
   [16] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
   [17] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
   [18] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
   [19] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
   [20] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
   [21] (::Base.var"#1014#1016"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:437
   [22] #invokelatest#2
      @ ./essentials.jl:892 [inlined]
   [23] invokelatest
      @ ./essentials.jl:889 [inlined]
   [24] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:421
   [25] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:338
   [26] _start()
      @ Base ./client.jl:557
Test Summary:     | Pass  Error  Total  Time
Sum square kernel |    2      1      3  3.7s
ERROR: LoadError: Some tests did not pass: 2 passed, 0 failed, 1 errored, 0 broken.
in expression starting at /home/ubuntu/project/julia/demo/test1.jl:57

I couldn’t find an introduction to autodiffing a CUDA/KernelAbstractions kernel with Reactant.jl and Enzyme.jl. Any advice or tips would be much appreciated.


Try using @compile raise=true to raise the kernel to StableHLO ops, but even then I’m not sure it will work, and for now there is no way to add our own rules. I hope that will come someday; the issue is that it will never work on TPUs, for instance.
However, I want to add something: writing really naive array ops is fine, since the MLIR passes will refactor everything and produce nicely fused ops. So if you can write your kernel in terms of slices, sums, map, reduce, or similar, go for it; it will work really well (which is not the case for Base Julia).
I would take a look at Oceananigans.jl/ext/OceananigansReactantExt at main · CliMA/Oceananigans.jl · GitHub and/or Oceananigans.jl/ext/OceananigansEnzymeExt.jl at main · CliMA/Oceananigans.jl · GitHub, since they converted a lot of kernels and may have used some nice tricks, though I don’t see anything kernel-related there.
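For example, a minimal sketch reusing the arrays from the original post (sum_square_fused is a hypothetical name): the square_2kernel! launch followed by the sum can be written as a single reduction that Reactant fuses and differentiates without any raising.

using Reactant, Enzyme

# Same computation as launching square_2kernel! and then summing, expressed as
# one reduction that Reactant can trace, fuse, and differentiate directly.
sum_square_fused(x) = sum(abs2, x)

A = Reactant.to_rarray(collect(Float32, 1:64))
@jit Enzyme.gradient(Reverse, sum_square_fused, A)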

Update: even raise doesn’t help.

Support for differentiating un-raised kernels is currently in progress. To differentiate the raised code, you need the following (both enabling raising and having raising run before autodiff):

using Random, Reactant, CUDA, Test, Enzyme

const ReactantCUDAExt = Base.get_extension(Reactant, :ReactantCUDAExt)

rng = Random.default_rng()
Random.seed!(rng, 0)

Reactant.set_default_backend("gpu")

@testset "Promote CuTraced" begin
  TFT = ReactantCUDAExt.CuTracedRNumber{Float64,1}
  FT = Float64
  @test Reactant.promote_traced_type(TFT, FT) == TFT
  @test Base.promote_type(TFT, FT) == FT
end

function square_kernel!(x, y)
  i = threadIdx().x
  x[i] *= y[i]
  # We don't yet auto lower this via polygeist
  # sync_threads()
  return nothing
end

function square_2kernel!(x)
  i = threadIdx().x
  x[i] = x[i]^2
  return nothing
end

# basic squaring on GPU
function square!(x, y)
  @cuda blocks = 1 threads = length(x) square_kernel!(x, y)
  return nothing
end

function sum_square_1(x)
  sum(x.^2)
end

function sum_square_2(x)
  @cuda blocks = 1 threads = length(x) square_2kernel!(x)
  sum(x)
end

# @testset "Square Kernel" begin
#   oA = collect(1:1:64)
#   A = Reactant.to_rarray(oA)
#   B = Reactant.to_rarray(100 .* oA)
#   @jit square!(A, B)
#   @test all(Array(A) .≈ (oA .* oA .* 100))
#   @test all(Array(B) .≈ (oA .* 100))
# end

@testset "Sum square kernel" begin
  oA = collect(Float32, 1:1:64)
  A = Reactant.to_rarray(oA)
  ∂f_∂A = Enzyme.make_zero(A)

  out = @jit sum_square_1(A)
  @test out ≈ sum(oA .* oA)

  out = @jit sum_square_2(A)
  @test out ≈ sum(oA .* oA)

  out1 = @jit Enzyme.gradient(Reverse, sum_square_1, A)
  out2 = @jit raise=true raise_first=true Enzyme.gradient(Reverse, sum_square_2, A)

  @test out1[1] ≈ out2[1]
end

Thanks. Is a solution for writing our own reverse/forward pass in the specific case of an un-raised kernel also on the way?

yeah, we just need to define the derivative rule for the enzymexla.kernel_call and enzymexla.jit_call ops [if you’re interested in helping, let me know!]


Of course, as long as I don’t need to touch C++ :stuck_out_tongue:

Would it be possible, if the kernel isn’t raised and the rule is implemented for Enzyme.jl, for it to just compile the reverse/forward pass as a Julia function?

eventually yes, right now, no

Thanks! That works for me.

But I’m not sure whether there are limitations or instabilities when autodiffing raised code. I am developing a new atmosphere model in Julia, and when I try to autodiff a raised divergence-operator kernel in my model, it fails with:

ERROR: LoadError: AssertionError: Invalid option raise
Stacktrace:
 [1] compile_call_expr(::Module, ::Function, ::Dict{Symbol, Any}, ::Expr, ::Vararg{Expr})
   @ Reactant.Compiler ~/.julia/packages/Reactant/7y9bj/src/Compiler.jl:673
 [2] var"@jit"(__source__::LineNumberNode, __module__::Module, args::Vararg{Any})
   @ Reactant.Compiler ~/.julia/packages/Reactant/7y9bj/src/Compiler.jl:653

This is part of my script:

Reactant.set_default_backend("gpu")

h_e = Reactant.to_rarray(Array(block.diag.h_e.d))
h_pc = Reactant.to_rarray(Array(block.dstate[1].h_pc.d))
un_e = Reactant.to_rarray(Array(block.dstate[1].un_e.d))
tend_h_pc = Reactant.to_rarray(Array(block.dtend.h_pc.d))
plg_nnb = Reactant.to_rarray(Array(block.domain.full.plg_nnb))
plg_ed = Reactant.to_rarray(Array(block.domain.full.plg_ed))
plg_nr = Reactant.to_rarray(Array(block.domain.full.plg_nr))
edp_leng = Reactant.to_rarray(Array(block.domain.full.edp_leng))
plg_area = Reactant.to_rarray(Array(block.domain.full.plg_area))
r_earth = Reactant.ConcreteRNumber(block.consts.r_earth)
chunksizeval = Reactant.ConcreteRNumber(block.graph.ctx.chunksize)
nv_cmpt = Reactant.ConcreteRNumber(block.domain.full.nv_cmpt)

@inline function div_at_prime_cell!(
  iv               ::INT,
  div              ::AT{FLT, 1},
  normal_u_at_edge ::AT{FLT, 1},
  𝜘_at_edge        ::AT{FLT, 1},
  plg_nnb          ::AT{INT, 1},
  plg_ed           ::AT{INT, 2},
  plg_nr           ::AT{INT, 2},
  edp_leng         ::AT{FLT, 1},
  plg_area         ::AT{FLT, 1},
  r_earth          ::FLT,
) where {INT, FLT}
  local_sum = FLT(0.0)
  for inb = 1:plg_nnb[iv]
    ed_idx = plg_ed[iv, inb]
    nei = plg_nr[iv, inb]
    le  = edp_leng[ed_idx]
    local_sum += nei*𝜘_at_edge[ed_idx]*normal_u_at_edge[ed_idx]*le
  end
  div[iv] = local_sum / (-r_earth*plg_area[iv])
end

function calc_div_at_prime_cell_cuda!(
  div              ::AT{FLT, 1},
  normal_u_at_edge ::AT{FLT, 1},
  𝜘_at_edge        ::AT{FLT, 1},
  plg_nnb          ::AT{INT, 1},
  plg_ed           ::AT{INT, 2},
  plg_nr           ::AT{INT, 2},
  edp_leng         ::AT{FLT, 1},
  plg_area         ::AT{FLT, 1},
  r_earth          ::FLT,
  nv_cmpt          ::Integer
) where {INT, FLT}
  iv = (blockIdx().x - 1) * blockDim().x + threadIdx().x
  @inbounds if iv <= nv_cmpt
    div_at_prime_cell!(
      iv, div, normal_u_at_edge, 𝜘_at_edge,
      plg_nnb, plg_ed, plg_nr,
      edp_leng, plg_area, r_earth
    )
  end
  return nothing
end
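
# NOTE: only part of the script is shown here. The KernelAbstractions @kernel
# version (calc_div_at_prime_cell!) called in ∂h∂t_kernel below, the AT
# array-type alias, and the KA (= KernelAbstractions) module alias are all
# defined elsewhere in the full script.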

function ∂h∂t_kernel(
  h_e,
  un_e,
  tend_h_pc, 
  plg_nnb, 
  plg_ed, 
  plg_nr, 
  edp_leng, 
  plg_area, 
  r_earth, 
  chunksizeval,
  nv_cmpt
)
  dev = KA.get_backend(h_e)
  kernel! = calc_div_at_prime_cell!(dev, (64, 16))
  kernel!(
    tend_h_pc, un_e, h_e, 
    plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, 
    r_earth, 
    chunksizeval,
    nv_cmpt,
    ndrange=(cld(nv_cmpt, chunksizeval), chunksizeval)
  )
  KA.synchronize(dev)
  return sum(tend_h_pc)/length(tend_h_pc)
end

function ∂h∂t_cuda(
  h_e,
  un_e,
  tend_h_pc, 
  plg_nnb, 
  plg_ed, 
  plg_nr, 
  edp_leng, 
  plg_area, 
  r_earth, 
  nv_cmpt
)
  @cuda blocks = 16 threads = Int64(length(h_e)/16) calc_div_at_prime_cell_cuda!(
    tend_h_pc, un_e, h_e, 
    plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, 
    r_earth, 
    nv_cmpt
  )
  return sum(tend_h_pc)/length(tend_h_pc)
end

@testset "Enzyme + KernelAbstractions" begin
  ∂z_∂u = Enzyme.make_zero(un_e)
  ∂z_∂he = Enzyme.make_zero(h_e)
  ∂z_∂tend = Reactant.to_rarray(rand(Base.eltype(tend_h_pc), size(tend_h_pc)...))
  ∂z_∂plg_ed = Enzyme.make_zero(plg_ed)
  ∂z_∂plg_nr = Enzyme.make_zero(plg_nr)
  ∂z_∂edp_leng = Enzyme.make_zero(edp_leng)
  ∂z_∂plg_area = Enzyme.make_zero(plg_area)

  @jit ∂h∂t_kernel(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, r_earth, chunksizeval, nv_cmpt)

  @jit raise=true raise_first=true Enzyme.autodiff(Reverse, ∂h∂t_kernel, Const,
    Duplicated(h_e, ∂z_∂he),
    Duplicated(un_e, ∂z_∂u),
    Duplicated(tend_h_pc, ∂z_∂tend),
    Const(plg_nnb),
    Duplicated(plg_ed, ∂z_∂plg_ed),
    Duplicated(plg_nr, ∂z_∂plg_nr),
    Duplicated(edp_leng, ∂z_∂edp_leng),
    Duplicated(plg_area, ∂z_∂plg_area),
    Const(r_earth),
    Const(chunksizeval),
    Const(nv_cmpt)
  )
  @test all(iszero, Array(∂z_∂u)) && error("∂z_∂u is zero")
  @test all(iszero, Array(∂z_∂he)) && error("∂z_∂he is zero")
  @show ∂z_∂u
end

@testset "Enzyme + CUDA" begin
  ∂z_∂u = Enzyme.make_zero(un_e)
  ∂z_∂he = Enzyme.make_zero(h_e)
  ∂z_∂tend = Reactant.to_rarray(rand(Base.eltype(tend_h_pc), size(tend_h_pc)...))
  ∂z_∂plg_ed = Enzyme.make_zero(plg_ed)
  ∂z_∂plg_nr = Enzyme.make_zero(plg_nr)
  ∂z_∂edp_leng = Enzyme.make_zero(edp_leng)
  ∂z_∂plg_area = Enzyme.make_zero(plg_area)
  
  @jit ∂h∂t_cuda(h_e, un_e, tend_h_pc, plg_nnb, plg_ed, plg_nr, edp_leng, plg_area, r_earth, nv_cmpt)

  @jit raise=true raise_first=true Enzyme.autodiff(Reverse, ∂h∂t_cuda, Const,
    Duplicated(h_e, ∂z_∂he),
    Duplicated(un_e, ∂z_∂u),
    Duplicated(tend_h_pc, ∂z_∂tend),
    Const(plg_nnb),
    Duplicated(plg_ed, ∂z_∂plg_ed),
    Duplicated(plg_nr, ∂z_∂plg_nr),
    Duplicated(edp_leng, ∂z_∂edp_leng),
    Duplicated(plg_area, ∂z_∂plg_area),
    Const(r_earth),
    Const(nv_cmpt),
  )
  @test all(iszero, ∂z_∂u) && error("∂z_∂u is zero")
  @test all(iszero, ∂z_∂he) && error("∂z_∂he is zero")
  @show ∂z_∂u
end

Yeah, not all kernels are raisable to StableHLO ops (or they are, but it’s really difficult).

Thanks. I’m torn between using Julia and JAX for this project. Julia’s flexibility and high performance are really attractive. I hope this will be supported in the future.

I also wonder whether there is another way to perform complex array operations with Reactant.jl and Enzyme.jl when the operations themselves are difficult to vectorize. In JAX we can use vmap; is there a function to do that in Reactant.jl?

Doesn’t it work with Base.map?
Update: it works but fails to differentiate.

Also, GPU arrays don’t support scalar indexing, so the code can’t run on the GPU the way it does on the CPU. I want to accelerate it on the GPU; that’s why I tried writing kernels in the first place.

I mean, your code was already really well written, so I don’t think Reactant will help much beyond choosing the number of threads better. If you plan to distribute it, it will help a lot, but otherwise I don’t think it will do much better for a single kernel launch; the biggest issue is the derivative.

You’re right. The derivative is a big deal.

This is only a test case; other kernels are launched during the full computation. I have already distributed it with MPI and compared the performance between one CPU node and one GPU. The final goal is to execute the whole computation, including the derivatives, on GPUs and distributed CPUs, and then do something like a PINN with Lux.jl. Also, Enzyme.jl does not support reduce operations on CuArrays, but it does support them on ConcreteRArray in Reactant.jl; that’s why I wanted to give it a try.

If it raises successfully, there should be no problems with the derivative.

The error message you have there (“Invalid option raise”), though, doesn’t imply that raising failed, but rather that the raise option doesn’t exist?

What version of Reactant are you on?

@yolhan_mannes re the Base.map failure, do you have an example – that should work?

cc @avikpal @Pangoraw

You can use Base.mapslices (some examples in Reactant.jl/test/batching.jl at bd1d2d99af51e1fdbe8a341c7d344a0b951f47c3 · EnzymeAD/Reactant.jl · GitHub).

That said, we have some ongoing work that can take loops/map/mapreduce, etc. code and automatically vectorize it for you without any special functions like vmap. If you have an example that you expect to be auto-vectorized and Reactant doesn’t yet vectorize it, post it as an issue and we can look into it.
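For instance, a minimal sketch of the mapslices route mentioned above (hypothetical function and shapes; this assumes mapslices is traced the same way as in those tests):

using Reactant

x = Reactant.to_rarray(rand(Float32, 4, 8))

# apply `sum` to each column: a vmap-like batched operation along dims=1
col_sums(A) = mapslices(sum, A; dims=1)

@jit col_sums(x)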

Never mind, it’s not about map: gradient works fine, but jacobian (both Forward and Reverse) doesn’t.

julia> function foo(x)
           y = map(x) do xi
               xi
           end
           return y
       end
foo (generic function with 1 method)

julia> @jit Enzyme.jacobian(Forward,foo,x)
ERROR: AssertionError: Invalid start indices: [0, -1]
Stacktrace:
  [1] slice(x::Reactant.TracedRArray{…}, start_indices::Vector{…}, limit_indices::Vector{…}; strides::Nothing, location::Reactant.MLIR.IR.Location)
    @ Reactant.Ops ~/.julia/packages/Reactant/uU0IF/src/Ops.jl:632
  [2] overload_autodiff(::ForwardMode{…}, f::Const{…}, ::Type{…}, args::BatchDuplicated{…})
    @ Reactant ~/.julia/packages/Reactant/uU0IF/src/Enzyme.jl:481
  [3] autodiff(rmode::ForwardMode{…}, f::Const{…}, rt::Type{…}, args::BatchDuplicated{…})
    @ Reactant ~/.julia/packages/Reactant/uU0IF/src/Overlay.jl:21
  [4] autodiff
    @ ~/.julia/packages/Enzyme/3D7Zv/src/Enzyme.jl:538 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:682 [inlined]
  [6] #gradient#114
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:584 [inlined]
  [7] (::Nothing)(none::Enzyme.var"##gradient#114", none::Nothing, none::Tuple{…}, none::typeof(gradient), none::ForwardMode{…}, none::typeof(foo), none::Reactant.TracedRArray{…}, none::Tuple{})
    @ Reactant ./<missing>:0
  [8] getindex
    @ ./tuple.jl:31 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:682 [inlined]
 [10] #gradient#114
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:584 [inlined]
 [11] call_with_reactant(::Enzyme.var"##gradient#114", ::Nothing, ::Tuple{…}, ::typeof(gradient), ::ForwardMode{…}, ::typeof(foo), ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/uU0IF/src/utils.jl:0
 [12] gradient
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:584 [inlined]
 [13] (::Nothing)(none::typeof(gradient), none::ForwardMode{…}, none::typeof(foo), none::Reactant.TracedRArray{…}, none::Tuple{})
    @ Reactant ./<missing>:0
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:456 [inlined]
 [15] create_shadows
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:434 [inlined]
 [16] gradient
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:584 [inlined]
 [17] call_with_reactant(::typeof(gradient), ::ForwardMode{…}, ::typeof(foo), ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/uU0IF/src/utils.jl:0
 [18] #jacobian#116
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:793 [inlined]
 [19] jacobian
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:792 [inlined]
 [20] (::Nothing)(none::typeof(jacobian), none::ForwardMode{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [21] jacobian
    @ ~/.julia/packages/Enzyme/3D7Zv/src/sugar.jl:792 [inlined]
 [22] call_with_reactant(::typeof(jacobian), ::ForwardMode{…}, ::typeof(foo), ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/uU0IF/src/utils.jl:0
 [23] make_mlir_fn(f::typeof(jacobian), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/uU0IF/src/TracedUtils.jl:348
 [24] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/uU0IF/src/Compiler.jl:1575
 [25] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/uU0IF/src/Compiler.jl:1542 [inlined]
 [26] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/uU0IF/src/Compiler.jl:3464
 [27] compile_xla
    @ ~/.julia/packages/Reactant/uU0IF/src/Compiler.jl:3437 [inlined]
 [28] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/uU0IF/src/Compiler.jl:3536
 [29] top-level scope
    @ ~/.julia/packages/Reactant/uU0IF/src/Compiler.jl:2614
Some type information was truncated. Use `show(err)` to see complete types.
julia> @jit Enzyme.jacobian(Reverse,foo,x)
ERROR: type Nothing has no field stmts
Stacktrace:
  [1] getproperty
    @ ./Base.jl:49 [inlined]
  [2] rewrite_insts!(ir::Nothing, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…}, guaranteed_error::Bool)
    @ Reactant ~/.julia/packages/Reactant/pQXes/src/utils.jl:555
  [3] call_with_reactant_generator(world::UInt64, source::LineNumberNode, self::Any, redub_arguments::Any)
    @ Reactant ~/.julia/packages/Reactant/pQXes/src/utils.jl:716
  [4] autodiff_thunk
    @ ~/.julia/packages/Enzyme/nV4l9/src/Enzyme.jl:997 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/nV4l9/src/sugar.jl:1153 [inlined]
  [6] jacobian_helper
    @ ~/.julia/packages/Enzyme/nV4l9/src/sugar.jl:797 [inlined]
  [7] (::Nothing)(none::typeof(Enzyme.jacobian_helper), none::ReverseMode{…}, none::Type{…}, none::Val{…}, none::Nothing, none::typeof(foo), none::Tuple{…})
    @ Reactant ./<missing>:0
  [8] call_with_reactant(::typeof(Enzyme.jacobian_helper), ::ReverseMode{…}, ::Type{…}, ::Val{…}, ::Nothing, ::typeof(foo), ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/pQXes/src/utils.jl:519
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/nV4l9/src/sugar.jl:864 [inlined]
 [10] jacobian_helper
    @ ~/.julia/packages/Enzyme/nV4l9/src/sugar.jl:797 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/nV4l9/src/sugar.jl:1251 [inlined]
 [12] jacobian
    @ ~/.julia/packages/Enzyme/nV4l9/src/sugar.jl:1225 [inlined]
 [13] (::Nothing)(none::typeof(jacobian), none::ReverseMode{…}, none::typeof(foo), none::Tuple{…})
    @ Reactant ./<missing>:0
 [14] jacobian
    @ ~/.julia/packages/Enzyme/nV4l9/src/sugar.jl:1225 [inlined]
 [15] call_with_reactant(::typeof(jacobian), ::ReverseMode{…}, ::typeof(foo), ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/pQXes/src/utils.jl:0
 [16] make_mlir_fn(f::typeof(jacobian), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/pQXes/src/TracedUtils.jl:348
 [17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:1575
 [18] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:1542 [inlined]
 [19] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3464
 [20] compile_xla
    @ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3437 [inlined]
 [21] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:3536
 [22] top-level scope
    @ ~/.julia/packages/Reactant/pQXes/src/Compiler.jl:2614
Some type information was truncated. Use `show(err)` to see complete types.

But it has nothing to do with map.
PS: something I wish would work is

function foo(x)
           y = map(enumerate(x)) do t
               i,xi = t
               xi
           end
           return y
       end

but it triggers scalar indexing
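For reference, a minimal workaround sketch (hypothetical names; it assumes the index is only needed as a value): pass an explicit index array and broadcast over it instead of enumerating.

using Reactant

x = Reactant.to_rarray(rand(Float32, 8))
idx = Reactant.to_rarray(collect(Float32, 1:8))

# weight each element by its 1-based index, without enumerate/scalar indexing
weighted(x, idx) = idx .* x

@jit weighted(x, idx)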


We should be able to make these work. map is not overloaded for Iterators yet (see Reactant.jl/src/TracedRArray.jl at main · EnzymeAD/Reactant.jl · GitHub). mapreduce, on the other hand, is, so we just need to extend the overloaded_map function and the dispatch in Reactant.jl/src/Overlay.jl at e4bb34f6a34189d503cb11804a0b930e15adabfa · EnzymeAD/Reactant.jl · GitHub.
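For example, a minimal sketch of the mapreduce path that is already overloaded (hypothetical names):

using Reactant

x = Reactant.to_rarray(rand(Float32, 16))

# mapreduce is overloaded for traced arrays, so this traces and compiles
sumsq(x) = mapreduce(abs2, +, x)

@jit sumsq(x)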