I’m new to Reactant.jl and Enzyme.jl, and I want to use them to compute the gradient of a CUDA kernel. Here is my script:
using Random, Reactant, CUDA, Test, Enzyme
const ReactantCUDAExt = Base.get_extension(Reactant, :ReactantCUDAExt)
rng = Random.default_rng()
Random.seed!(rng, 0)
Reactant.set_default_backend("gpu")
@testset "Promote CuTraced" begin
TFT = ReactantCUDAExt.CuTracedRNumber{Float64,1}
FT = Float64
@test Reactant.promote_traced_type(TFT, FT) == TFT
@test Base.promote_type(TFT, FT) == FT
end
function square_kernel!(x, y)
    i = threadIdx().x
    x[i] *= y[i]
    # We don't yet auto lower this via polygeist
    # sync_threads()
    return nothing
end
# in-place elementwise squaring on the GPU
function square_2kernel!(x)
    i = threadIdx().x
    x[i] = x[i]^2
    return nothing
end
# basic squaring on GPU
function square!(x, y)
    @cuda blocks = 1 threads = length(x) square_kernel!(x, y)
    return nothing
end
# sum of squares without a custom kernel (reference version)
function sum_square_1(x)
    sum(x .^ 2)
end
# sum of squares using the hand-written CUDA kernel
function sum_square_2(x)
    @cuda blocks = 1 threads = length(x) square_2kernel!(x)
    sum(x)
end
# @testset "Square Kernel" begin
# oA = collect(1:1:64)
# A = Reactant.to_rarray(oA)
# B = Reactant.to_rarray(100 .* oA)
# @jit square!(A, B)
# @test all(Array(A) .≈ (oA .* oA .* 100))
# @test all(Array(B) .≈ (oA .* 100))
# end
@testset "Sum square kernel" begin
oA = collect(Float32, 1:1:64)
A = Reactant.to_rarray(oA)
∂f_∂A = Enzyme.make_zero(A)
out = @jit sum_square_1(A)
@test out ≈ sum(oA .* oA)
out = @jit sum_square_2(A)
@test out ≈ sum(oA .* oA)
out1 = @compile Enzyme.gradient(Reverse, sum_square_1, A)
out2 = @compile Enzyme.gradient(Reverse, sum_square_2, A)
@test out1 ≈ out2
end
It fails while compiling the gradient of sum_square_2. Here is the error message:
error: could not compute the adjoint for this operation %4 = "enzymexla.kernel_call"(%3, %3, %3, %2, %3, %3, %1, %arg0) <{backend_config = "", fn = @"##call__Z15square_2kernel_13CuTracedArrayI7Float32Li1ELi1E5_64__E#384", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>], xla_side_effect_free}> : (tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<64xf32>) -> tensor<64xf32>
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_gHZpOE/module_004_qIaR_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/B1BXA/src/mlir/IR/Pass.jl:119
Sum square kernel: Error During Test at /home/ubuntu/project/julia/demo/test1.jl:57
Got exception outside of a @test
"failed to run pass manager on module"
Stacktrace:
[1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/B1BXA/src/mlir/IR/Pass.jl:163
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1281
[3] run_pass_pipeline!
@ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1276 [inlined]
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, Tuple{Vararg{Symbol, var"#s1732"}} where var"#s1732", Tuple{Vararg{Int64, N}} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1721
[5] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:1536 [inlined]
[6] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:3447
[7] compile_xla
@ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:3420 [inlined]
[8] compile(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
@ Reactant.Compiler ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:3519
[9] macro expansion
@ ~/.julia/packages/Reactant/B1BXA/src/Compiler.jl:2600 [inlined]
[10] macro expansion
@ ~/project/julia/demo/test1.jl:69 [inlined]
[11] macro expansion
@ ~/software/julia/1.10.8/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
[12] top-level scope
@ ~/project/julia/demo/test1.jl:58
[13] include(fname::String)
@ Base.MainInclude ./client.jl:494
[14] top-level scope
@ REPL[17]:1
[15] eval
@ ./boot.jl:385 [inlined]
[16] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
@ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
[17] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
@ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
[18] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
@ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
[19] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
@ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
[20] run_repl(repl::REPL.AbstractREPL, consumer::Any)
@ REPL ~/software/julia/1.10.8/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
[21] (::Base.var"#1014#1016"{Bool, Bool, Bool})(REPL::Module)
@ Base ./client.jl:437
[22] #invokelatest#2
@ ./essentials.jl:892 [inlined]
[23] invokelatest
@ ./essentials.jl:889 [inlined]
[24] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
@ Base ./client.jl:421
[25] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:338
[26] _start()
@ Base ./client.jl:557
Test Summary: | Pass Error Total Time
Sum square kernel | 2 1 3 3.7s
ERROR: LoadError: Some tests did not pass: 2 passed, 0 failed, 1 errored, 0 broken.
in expression starting at /home/ubuntu/project/julia/demo/test1.jl:57
I couldn’t find an introduction to differentiating a CUDA/KernelAbstractions kernel with Reactant.jl and Enzyme.jl. Any advice or tips would be greatly appreciated.
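For reference, the gradient I expect in both cases is d/dA sum(A.^2) = 2A. Since I mention KernelAbstractions above, here is roughly the KernelAbstractions version of the same computation that I would eventually like to differentiate the same way. This is only a sketch of what I mean: the names square_2kernel_ka! and sum_square_ka are my own, the backend would be e.g. CUDABackend() from CUDA.jl, and I have not tried pushing this variant through Reactant yet.

using KernelAbstractions

@kernel function square_2kernel_ka!(x)
    i = @index(Global)   # linear global index
    x[i] = x[i]^2
end

function sum_square_ka(x, backend)
    # instantiate the kernel for the given backend and launch it over length(x) elements
    square_2kernel_ka!(backend)(x; ndrange = length(x))
    KernelAbstractions.synchronize(backend)
    sum(x)
end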