Error when compiling gather and scatter operations under Reactant

Hi All,

I have started experimenting with Lux, Enzyme, and Reactant. In my applications, I frequently need an operation known as segmented_sum or segmented_mean, which can be computed by composing scatter and gather. However, gather currently fails to compile under Reactant when it is differentiated.
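For context, here is a minimal sketch of how `segmented_sum` and `segmented_mean` can be built from `NNlib.scatter` and `NNlib.gather`. The function names and the convention that `segments[j]` holds the 1-based segment id of column `j` are assumptions for illustration, not an existing NNlib API. The mean divides by per-segment counts, which are themselves a scatter-add of ones, and `gather` broadcasts a per-segment result back to every member.

```julia
using NNlib

# Hypothetical helpers (not part of NNlib).
# `x` holds one item per column; `segments[j]` is the 1-based segment id of column j.
segmented_sum(x::AbstractMatrix, segments::AbstractVector{<:Integer}) =
    NNlib.scatter(+, x, segments)              # size(x, 1) × maximum(segments)

function segmented_mean(x::AbstractMatrix, segments::AbstractVector{<:Integer})
    sums = segmented_sum(x, segments)
    # Count the members of each segment by scattering a row of ones.
    counts = NNlib.scatter(+, ones(eltype(x), 1, length(segments)), segments)
    # max(count, 1) avoids dividing by zero for segment ids that never occur.
    return sums ./ max.(counts, one(eltype(x)))
end

x = Float32[1 2 3 4; 10 20 30 40]
s = [1, 2, 1, 2]
m = segmented_mean(x, s)
# gather copies each segment's mean back to its members, e.g. for centering:
centered = x .- NNlib.gather(m, s)
```

Differentiating this composition is exactly what needs a rule for the gather (and scatter) adjoints.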

Below is a minimal example of the failure. I left in checks showing that Zygote and Enzyme compute the gradient fine, while compiling the same gradient with Reactant fails.

using Lux        # provides reactant_device()
using NNlib
using Zygote
using Enzyme
using Reactant

xx = randn(Float32, 4, 3)
ii = [1, 2, 1, 2, 3]

# Both of these work on the CPU:
Zygote.gradient(xx -> sum(NNlib.gather(xx, ii)), xx)
Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii)), xx)

# Compiling the same Enzyme gradient with Reactant fails:
const xdev = reactant_device()
xx_ra = xx |> xdev
ii_ra = ii |> xdev
enzyme_gradient_compiled = @compile Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)

The error is as follows
error: could not compute the adjoint for this operation %5 = "stablehlo.dynamic_gather"(%3, %4, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<4x3xf32>, tensor<344044xi64>, tensor<2xi64>) -> tensor<4x344044xf32>
ERROR: "failed to run pass manager on module"
  [1] run!
    @ ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:264
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:259 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:309
  [5] compile_mlir!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:290 [inlined]
  [6] (::Reactant.Compiler.var"#34#36"{Bool, typeof(Enzyme.gradient), Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:698
  [7] context!(f::Reactant.Compiler.var"#34#36"{Bool, typeof(Enzyme.gradient), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Context.jl:76
  [8] compile_xla(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:695
  [9] compile_xla
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:690 [inlined]
 [10] compile(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:722
 [11] top-level scope
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:475
Some type information was truncated. Use `show(err)` to see complete types.
My working environment is
Status `~/Work/Julia/Avast/SmellyAV/scripts/lux/Project.toml`
  [d360d2e6] ChainRulesCore v1.25.0
  [7da242da] Enzyme v0.13.22
  [b2108857] Lux v1.4.2
  [872c559c] NNlib v0.9.26
  [0b1bfda6] OneHotArrays v0.2.6
  [3c362404] Reactant v0.2.10
  [e88e6eb3] Zygote v0.6.73


julia> versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 11 × Apple M3 Pro
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 1 default, 0 interactive, 1 GC (on 5 virtual cores)

I have to admit that I would like to understand the toolchain well enough to fix these things myself.

Thanks in advance for any answers.

cc @avikpal

So that basically says that we haven’t implemented the derivative of the `dynamic_gather` operation yet (which we clearly should), and that we also weren’t able to optimize it away (perhaps we should be able to, but that depends on the code).
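To make "the derivative of `dynamic_gather`" concrete: in reverse mode, the adjoint of a gather is a scatter-add of the incoming cotangent back into the source positions. NNlib's own CPU pullback for `gather` works along these lines; the sketch below just reproduces the math in plain Julia. The helper name `gather_pullback` is hypothetical, this is not how Reactant or Enzyme-JAX implement the rule, and it assumes `NNlib.scatter` accepts the `dstsize` keyword, as in recent NNlib releases.

```julia
using NNlib

# Hypothetical helper: the reverse-mode rule of y = NNlib.gather(x, idx).
# Column j of the cotangent dy is accumulated (+) into column idx[j] of dx.
# `dstsize` keeps dx the size of x even if the last columns are never gathered.
gather_pullback(dy::AbstractMatrix, idx::AbstractVector{<:Integer}, xsize::Tuple) =
    NNlib.scatter(+, dy, idx; dstsize = xsize)

x = randn(Float32, 4, 3)
ii = [1, 2, 1, 2, 3]
y = NNlib.gather(x, ii)          # 4×5; column j is x[:, ii[j]]
dy = ones(Float32, size(y))      # cotangent of `sum` is all ones
dx = gather_pullback(dy, ii, size(x))
# Columns 1 and 2 are gathered twice and column 3 once,
# so every row of dx is [2, 2, 1], matching the Zygote/Enzyme gradient.
```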

Can you show the output of `@code_hlo optimize=false Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)` and `@code_hlo optimize=:before_enzyme Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)`? This will print the IR so we can see what’s happening.

Thanks for the swift reply.

The outputs look like this:

@code_hlo optimize=false

julia> @code_hlo optimize=false Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)
module {
  func.func private @"-_broadcast_scalar"(%arg0: tensor<i64>, %arg1: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<i64>) -> tensor<i64>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<i64>) -> tensor<i64>
    %2 = stablehlo.subtract %0, %1 : tensor<i64>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<i64>) -> tensor<i64>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<i64>) -> tensor<i64>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<i64>) -> tensor<i64>
    return %3, %4, %5 : tensor<i64>, tensor<i64>, tensor<i64>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %1 : tensor<f32>
  }
  func.func private @"Const{Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}}(Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}(sum \E2\88\98 NNlib.gather, TracedRArray{Int64,1N}(((:args, 2, 1, 2),), size=(5,))))_autodiff"(%arg0: tensor<5xi64>, %arg1: tensor<3x4xf32>) -> (tensor<f32>, tensor<5xi64>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
    %2 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %c = stablehlo.constant dense<1> : tensor<5xi64>
    %3:3 = enzyme.batch @"-_broadcast_scalar"(%2, %c) {batch_shape = array<i64: 5>} : (tensor<5xi64>, tensor<5xi64>) -> (tensor<5xi64>, tensor<5xi64>, tensor<5xi64>)
    %c_0 = stablehlo.constant dense<[4, 1]> : tensor<2xi64>
    %4 = "stablehlo.dynamic_gather"(%1, %3#0, %c_0) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<4x3xf32>, tensor<5xi64>, tensor<2xi64>) -> tensor<4x5xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<4x5xf32>) -> tensor<4x5xf32>
    %6 = enzyme.batch @identity_broadcast_scalar(%5) {batch_shape = array<i64: 4, 5>} : (tensor<4x5xf32>) -> tensor<4x5xf32>
    %7 = stablehlo.reduce(%6 init: %cst_1) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x5xf32>, tensor<f32>) -> tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %10 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %8, %9, %10 : tensor<f32>, tensor<5xi64>, tensor<3x4xf32>
  }
  func.func @main(%arg0: tensor<5xi64>, %arg1: tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<5xi64>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x3xf32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.transpose %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %3 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %4 = stablehlo.transpose %cst_0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %6:3 = enzyme.autodiff @"Const{Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}}(Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}(sum \E2\88\98 NNlib.gather, TracedRArray{Int64,1N}(((:args, 2, 1, 2),), size=(5,))))_autodiff"(%2, %3, %4, %5) {activity = [enzyme, enzyme], ret_activity = [enzyme, enzyme, enzyme]} : (tensor<5xi64>, tensor<3x4xf32>, tensor<f32>, tensor<3x4xf32>) -> (tensor<5xi64>, tensor<3x4xf32>, tensor<3x4xf32>)
    %7 = stablehlo.transpose %6#0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %8 = stablehlo.transpose %6#1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %9 = stablehlo.transpose %6#2, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %11 = stablehlo.transpose %7, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %12 = stablehlo.transpose %8, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %10, %11, %12 : tensor<3x4xf32>, tensor<5xi64>, tensor<3x4xf32>
  }
}
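The `dimension_numbers` on the `dynamic_gather` in this IR are fairly dense, so here is a small plain-Julia sketch of what this particular op computes. It assumes the operand is already in the 4×3 (Julia-order) layout produced by the `[1, 0]` transpose. The preceding `enzyme.batch @"-_broadcast_scalar"` call subtracts the constant 1 from every index, turning Julia's 1-based indices into the 0-based start indices StableHLO expects. `gather_ref` is a hypothetical name, it covers only this rank-2/rank-1 case, and StableHLO's clamping of out-of-range start indices is not modeled.

```julia
# Hypothetical reference for this specific dynamic_gather (not a general StableHLO gather).
# index_vector_dim = 1 equals the rank of the 1-D index tensor, so each index is a scalar.
function gather_ref(operand::AbstractMatrix, start_indices0::AbstractVector{<:Integer})
    nrows = size(operand, 1)   # slice_sizes = [4, 1] (%c_0): a full column, one column wide
    out = similar(operand, nrows, length(start_indices0))
    for (j, i0) in enumerate(start_indices0)
        # start_index_map = [1]: each index addresses operand dimension 1 (0-based), the columns.
        # collapsed_slice_dims = [1]: the size-1 column dimension of each slice is dropped.
        # offset_dims = [0]: the remaining row dimension becomes output dimension 0,
        # and the batch position j becomes output dimension 1.
        out[:, j] = operand[:, i0 + 1]   # +1 maps the 0-based index back to Julia indexing
    end
    return out
end

x = reshape(Float32.(1:12), 4, 3)
ii = [1, 2, 1, 2, 3]
y = gather_ref(x, ii .- 1)   # `ii .- 1` mirrors what %3#0 computes in the IR
```

The result has the `tensor<4x5xf32>` shape seen in the IR and equals `x[:, ii]`, i.e. the forward pass of `NNlib.gather(x, ii)`.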

@code_hlo optimize=:before_enzyme

julia> @code_hlo optimize=:before_enzyme Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)
error: could not compute the adjoint for this operation %5 = "stablehlo.dynamic_gather"(%3, %4, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<4x3xf32>, tensor<5xi64>, tensor<2xi64>) -> tensor<4x5xf32>
ERROR: "failed to run pass manager on module"
  [1] run!
    @ ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:264
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:259 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Symbol)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:349
  [5] compile_mlir!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:290 [inlined]
  [6] #6
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:285 [inlined]
  [7] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{…}, typeof(Enzyme.gradient), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Context.jl:76
  [8] compile_mlir(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; kwargs::@Kwargs{optimize::Symbol})
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:283
  [9] top-level scope
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:475
Some type information was truncated. Use `show(err)` to see complete types.

I would like to understand this better. Is there any material you would recommend studying?