Hello, I tried to put together a minimal working example of a Reactant-compiled KernelAbstractions kernel differentiated with Enzyme:
using Pkg
Pkg.activate(".")
using Reactant
using KernelAbstractions
using CUDA
using Enzyme
using Test
# Simple square kernel
@kernel function square_kernel!(y, @Const(x))
    i = @index(Global)
    @inbounds y[i] = x[i] * x[i]
end
function square(x)
    y = similar(x)
    backend = KernelAbstractions.get_backend(x)
    kernel! = square_kernel!(backend)
    kernel!(y, x; ndrange=length(x))
    return y
end
function test_reactant_ka_enzyme()
    Reactant.set_default_backend("gpu")
    x = Reactant.to_rarray(collect(Float32, 1:10))

    # Forward pass test
    println("Testing forward pass...")
    compiled_square = @compile square(x)
    println("Forward pass result: ", compiled_square(x))

    # Gradient test
    println("Testing gradient...")
    dz = Reactant.to_rarray(ones(Float32, 10))

    # Scalar loss function for gradient testing
    function loss_fn(x)
        y = square(x)
        return sum(y)
    end

    # Compile gradient
    println("Compiling gradient...")
    # @compile expects a function call and returns a compiled function, and
    # Enzyme.gradient returns a tuple, so wrap the gradient computation in a
    # helper that we compile and then call.
    function compute_grad(x)
        # Reverse mode explicitly; loss_fn is a closure, which may need extra
        # care, but try the standard API first.
        return Enzyme.gradient(Enzyme.Reverse, loss_fn, x)[1]
    end

    compiled_grad_fn = @compile compute_grad(x)
    dx = compiled_grad_fn(x)
    println("Gradient result: ", dx)

    # Verification
    expected_dx = 2 .* collect(Float32, 1:10)
    println("Expected gradient: ", expected_dx)
    @test all(Array(dx) .≈ expected_dx)
    println("Test Passed!")
end
test_reactant_ka_enzyme()
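As a side note, the kernel logic itself looks fine in isolation. Here is a minimal CPU-only sanity check (plain KernelAbstractions on the CPU backend, no Reactant) that I would use to verify the kernel, reusing square_kernel! from the script above; it is just a sketch, not output from the failing run:

# CPU-only sanity check for square_kernel! (no Reactant involved)
x_cpu = collect(Float32, 1:10)
y_cpu = similar(x_cpu)
cpu_kernel! = square_kernel!(CPU())              # instantiate on the CPU backend
cpu_kernel!(y_cpu, x_cpu; ndrange=length(x_cpu))
KernelAbstractions.synchronize(CPU())
@assert y_cpu == x_cpu .^ 2                      # expect 1, 4, 9, ..., 100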
Running the script above gives the following error:
****************************************************************************
* hwloc 2.0.3 received invalid information from the operating system.
*
* Failed with: intersection without inclusion
* while inserting Group0 (P#16 cpuset 0x00000f00) at L1d (P#16 cpuset 0x00000104)
* coming from: linux:sysfs:cluster
*
* The following FAQ entry in the hwloc documentation may help:
* What should I do when hwloc reports "operating system" warnings?
* Otherwise please report this error message to the hwloc user's mailing list,
* along with the files generated by the hwloc-gather-topology script.
*
* hwloc will now ignore this invalid topology information and continue.
****************************************************************************
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1764257009.342895 1121843 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0000 00:00:1764257009.345656 1121842 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
I0000 00:00:1764257009.345859 1121616 service.cc:158] XLA service 0x3785a420 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1764257009.345874 1121616 service.cc:166] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I0000 00:00:1764257009.345877 1121616 service.cc:166] StreamExecutor device (1): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I0000 00:00:1764257009.346437 1121616 se_gpu_pjrt_client.cc:1039] Using BFC allocator.
I0000 00:00:1764257009.346460 1121616 gpu_helpers.cc:136] XLA backend allocating 18965004288 bytes on device 0 for BFCAllocator.
I0000 00:00:1764257009.346488 1121616 gpu_helpers.cc:136] XLA backend allocating 18972033024 bytes on device 1 for BFCAllocator.
I0000 00:00:1764257009.346494 1121616 gpu_helpers.cc:177] XLA backend will use up to 6321668096 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1764257009.346500 1121616 gpu_helpers.cc:177] XLA backend will use up to 6324011008 bytes on device 1 for CollectiveBFCAllocator.
W0000 00:00:1764257009.347720 1121616 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0000 00:00:1764257009.351167 1121616 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
I0000 00:00:1764257009.355886 1121616 cuda_dnn.cc:463] Loaded cuDNN version 91400
Testing forward pass...
'86' is not a recognized feature for this target (ignoring feature)
'86' is not a recognized feature for this target (ignoring feature)
'86' is not a recognized feature for this target (ignoring feature)
Forward pass result: Reactant compiled function square (with tag ##square_reactant#237)
Testing gradient...
Compiling gradient...
error: could not compute the adjoint for this operation %5 = "enzymexla.kernel_call"(%3, %3, %3, %4, %3, %3, %1, %2, %arg0) <{backend_config = "", fn = @"##call__Z18gpu_square_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float32Li1ELi1E5_10__ESE_#242", operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 2>, output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>], xla_side_effect_free}> : (tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_WJNUGo/module_000_5tpL_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/zlIsO/src/mlir/IR/Pass.jl:119
ERROR: LoadError: "failed to run pass manager on module"
Stacktrace:
[1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/zlIsO/src/mlir/IR/Pass.jl:163
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1319
[3] run_pass_pipeline!
@ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1314 [inlined]
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::var"#compute_grad#2"{var"#loss_fn#1"}, args::Tuple{ConcretePJRTArray{Float32, 1, 1}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, Tuple{Vararg{Symbol, var"#s1785"}} where var"#s1785", Tuple{Vararg{Int64, N}} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}, sdygroupidcache::Tuple{Reactant.Compiler.SdyGroupIDCounter{Int64}, IdDict{Union{Reactant.TracedRArray, Reactant.TracedRNumber}, Int64}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1767
[5] compile_mlir!
@ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1576 [inlined]
[6] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3524
[7] compile_xla
@ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3496 [inlined]
[8] compile(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
@ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3600
[9] compile
@ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3597 [inlined]
[10] macro expansion
@ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:2669 [inlined]
[11] test_reactant_ka_enzyme()
@ Main /media/jm/hddData/superVoxelJuliaCode_lin_sampl/julia_impl/test_reactant_ka_enzyme.jl:58
[12] top-level scope
@ /media/jm/hddData/superVoxelJuliaCode_lin_sampl/julia_impl/test_reactant_ka_enzyme.jl:71
in expression starting at /media/jm/hddData/superVoxelJuliaCode_lin_sampl/julia_impl/test_reactant_ka_enzyme.jl:71
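The failure is in computing the adjoint of the enzymexla.kernel_call op, so the Enzyme.gradient path itself seems to be reached fine. As a point of comparison, I would expect the same gradient to compile when the kernel launch is replaced by plain broadcasting (this is an assumption on my side, not something I have verified; the names below are only for illustration, and the same using statements as above are assumed):

# Kernel-free comparison (assumed, not verified): same loss via broadcasting,
# which Reactant traces to stablehlo ops without an enzymexla.kernel_call.
square_bcast(x) = x .* x
loss_bcast(x) = sum(square_bcast(x))
compute_grad_bcast(x) = Enzyme.gradient(Enzyme.Reverse, loss_bcast, x)[1]

x = Reactant.to_rarray(collect(Float32, 1:10))
compiled_grad_bcast = @compile compute_grad_bcast(x)
dx = compiled_grad_bcast(x)                      # expected: 2 .* (1:10)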
CUDA v5.9.5
Enzyme v0.13.108
KernelAbstractions v0.9.39
Reactant v0.2.180
Julia 1.10.10
@wsmoses, @avikpal, it would be great if you could find the time to look into this.