KernelAbstractions + CUDA + Reactant: how to get a minimal working example

Hello, I tried to get a minimal working example of a Reactant-powered KernelAbstractions kernel differentiated with Enzyme:

using Pkg
Pkg.activate(".")
using Reactant
using KernelAbstractions
using CUDA
using Enzyme
using Test

# Simple square kernel
@kernel function square_kernel!(y, @Const(x))
    i = @index(Global)
    @inbounds y[i] = x[i] * x[i]
end

function square(x)
    y = similar(x)
    backend = KernelAbstractions.get_backend(x)
    kernel! = square_kernel!(backend)
    kernel!(y, x; ndrange=length(x))
    return y
end

function test_reactant_ka_enzyme()
    Reactant.set_default_backend("gpu")
    
    x = Reactant.to_rarray(collect(Float32, 1:10))
    
    # Forward pass test
    println("Testing forward pass...")
    y_compiled = @compile square(x)
    println("Forward pass result: ", y_compiled)
    
    # Gradient test
    println("Testing gradient...")
    dz = Reactant.to_rarray(ones(Float32, 10))
    
    # Define a scalar loss function for gradient testing
    function loss_fn(x)
        y = square(x)
        return sum(y)
    end
    
    # Compile gradient
    println("Compiling gradient...")
    # @compile expects a function call and returns a compiled function that we then call.
    # Enzyme.gradient returns a tuple, so define a wrapper that extracts the gradient.
    function compute_grad(x)
        # Use Enzyme.Reverse explicitly; loss_fn is a closure, but the standard API should handle it.
        return Enzyme.gradient(Enzyme.Reverse, loss_fn, x)[1]
    end
    
    compiled_grad_fn = @compile compute_grad(x)
    dx = compiled_grad_fn(x)
    
    println("Gradient result: ", dx)
    
    # Verification
    expected_dx = 2 .* collect(Float32, 1:10)
    println("Expected gradient: ", expected_dx)
    
    @test all(Array(dx) .≈ expected_dx)
    println("Test Passed!")
end

test_reactant_ka_enzyme()

It gives this error:

****************************************************************************
* hwloc 2.0.3 received invalid information from the operating system.
*
* Failed with: intersection without inclusion
* while inserting Group0 (P#16 cpuset 0x00000f00) at L1d (P#16 cpuset 0x00000104)
* coming from: linux:sysfs:cluster
*
* The following FAQ entry in the hwloc documentation may help:
*   What should I do when hwloc reports "operating system" warnings?
* Otherwise please report this error message to the hwloc user's mailing list,
* along with the files generated by the hwloc-gather-topology script.
* 
* hwloc will now ignore this invalid topology information and continue.
****************************************************************************
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1764257009.342895 1121843 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0000 00:00:1764257009.345656 1121842 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
I0000 00:00:1764257009.345859 1121616 service.cc:158] XLA service 0x3785a420 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1764257009.345874 1121616 service.cc:166]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I0000 00:00:1764257009.345877 1121616 service.cc:166]   StreamExecutor device (1): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I0000 00:00:1764257009.346437 1121616 se_gpu_pjrt_client.cc:1039] Using BFC allocator.
I0000 00:00:1764257009.346460 1121616 gpu_helpers.cc:136] XLA backend allocating 18965004288 bytes on device 0 for BFCAllocator.
I0000 00:00:1764257009.346488 1121616 gpu_helpers.cc:136] XLA backend allocating 18972033024 bytes on device 1 for BFCAllocator.
I0000 00:00:1764257009.346494 1121616 gpu_helpers.cc:177] XLA backend will use up to 6321668096 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1764257009.346500 1121616 gpu_helpers.cc:177] XLA backend will use up to 6324011008 bytes on device 1 for CollectiveBFCAllocator.
W0000 00:00:1764257009.347720 1121616 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0000 00:00:1764257009.351167 1121616 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
I0000 00:00:1764257009.355886 1121616 cuda_dnn.cc:463] Loaded cuDNN version 91400
Testing forward pass...
'86' is not a recognized feature for this target (ignoring feature)
'86' is not a recognized feature for this target (ignoring feature)
'86' is not a recognized feature for this target (ignoring feature)
Forward pass result: Reactant compiled function square (with tag ##square_reactant#237)
Testing gradient...
Compiling gradient...
error: could not compute the adjoint for this operation %5 = "enzymexla.kernel_call"(%3, %3, %3, %4, %3, %3, %1, %2, %arg0) <{backend_config = "", fn = @"##call__Z18gpu_square_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float32Li1ELi1E5_10__ESE_#242", operandSegmentSizes = array<i32: 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 2>, output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>], xla_side_effect_free}> : (tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_WJNUGo/module_000_5tpL_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/zlIsO/src/mlir/IR/Pass.jl:119
ERROR: LoadError: "failed to run pass manager on module"
Stacktrace:
  [1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/zlIsO/src/mlir/IR/Pass.jl:163
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1319
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1314 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::var"#compute_grad#2"{var"#loss_fn#1"}, args::Tuple{ConcretePJRTArray{Float32, 1, 1}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, Tuple{Vararg{Symbol, var"#s1785"}} where var"#s1785", Tuple{Vararg{Int64, N}} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}, sdygroupidcache::Tuple{Reactant.Compiler.SdyGroupIDCounter{Int64}, IdDict{Union{Reactant.TracedRArray, Reactant.TracedRNumber}, Int64}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1767
  [5] compile_mlir!
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:1576 [inlined]
  [6] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3524
  [7] compile_xla
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3496 [inlined]
  [8] compile(f::Function, args::Tuple{ConcretePJRTArray{Float32, 1, 1}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3600
  [9] compile
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:3597 [inlined]
 [10] macro expansion
    @ ~/.julia/packages/Reactant/zlIsO/src/Compiler.jl:2669 [inlined]
 [11] test_reactant_ka_enzyme()
    @ Main /media/jm/hddData/superVoxelJuliaCode_lin_sampl/julia_impl/test_reactant_ka_enzyme.jl:58
 [12] top-level scope
    @ /media/jm/hddData/superVoxelJuliaCode_lin_sampl/julia_impl/test_reactant_ka_enzyme.jl:71
in expression starting at /media/jm/hddData/superVoxelJuliaCode_lin_sampl/julia_impl/test_reactant_ka_enzyme.jl:71

CUDA v5.9.5
Enzyme v0.13.108
KernelAbstractions v0.9.39
Reactant v0.2.180
julia 1.10.10

@wsmoses, @avikpal, it would be great if you had time to look into this.

For now EnzymeXLA won't try to generate the adjoint rule for the kernel call, which means you must raise the kernel, i.e. tell Reactant to search for a combination of StableHLO operations that ends up doing the same thing. For a single run that means @compile raise=true fun(x), and for a gradient you also need to raise the primal runs, i.e. @compile raise=true raise_first=true gradfun(x).
PS: this also means it's not really your kernel being run anymore, but specialized code that ends up doing the same thing; thanks to fusion and automatic threading parameters, it may be as good as yours.
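
Concretely, applied to your script above, a minimal sketch (assuming the same square and compute_grad definitions, and using the raise / raise_first keywords mentioned above) would look something like:

# Sketch: raise the KA kernel to StableHLO so EnzymeXLA can differentiate it
compiled_square = @compile raise=true square(x)
y = compiled_square(x)

# For the gradient, also raise the primal before differentiating
compiled_grad = @compile raise=true raise_first=true compute_grad(x)
dx = compiled_grad(x)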

It feels weird that this is still not documented; the only place mentioning it is the compilation docs in Core Reactant API | Reactant.jl.
