Error when compiling gather and scatter operations under Reactant

Hi All,

I have started to experiment with Lux, Enzyme, and Reactant. In my applications, I frequently need an operation usually described as segmented_sum or segmented_mean, which can be computed as a composition of scatter and gather. The current gather does not compile under Reactant, though.
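
For context, this is roughly what I mean by those operations (a minimal sketch; the helper names segmented_sum/segmented_mean are my own, built on NNlib.scatter and NNlib.gather):

using NNlib
using Statistics: mean

# Columns of `x` that share an index in `ii` are reduced into one output column.
segmented_sum(x, ii)  = NNlib.scatter(+, x, ii)
segmented_mean(x, ii) = NNlib.scatter(mean, x, ii)

x  = randn(Float32, 4, 5)
ii = [1, 2, 1, 2, 3]
s  = segmented_sum(x, ii)   # 4×3; column j is the sum of the columns with ii .== j
NNlib.gather(s, ii)         # broadcasts the segment sums back to a 4×5 array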

A minimal working example of the failure follows. I have kept the checks showing that Zygote and Enzyme work; only the Reactant-compiled version fails.


using NNlib
using Zygote
using Enzyme
using Reactant
using Lux   # reactant_device comes via Lux / MLDataDevices

xx = randn(Float32, 4, 3)
ii = [1, 2, 1, 2, 3]

# Both of these work:
Zygote.gradient(xx -> sum(NNlib.gather(xx, ii)), xx)
Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii)), xx)

# This fails to compile:
const xdev = reactant_device()
xx_ra = xx |> xdev
ii_ra = ii |> xdev
enzyme_gradient_compiled = @compile Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)

The error is as follows:
error: could not compute the adjoint for this operation %5 = "stablehlo.dynamic_gather"(%3, %4, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<4x3xf32>, tensor<344044xi64>, tensor<2xi64>) -> tensor<4x344044xf32>
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:264
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:259 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:309
  [5] compile_mlir!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:290 [inlined]
  [6] (::Reactant.Compiler.var"#34#36"{Bool, typeof(Enzyme.gradient), Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:698
  [7] context!(f::Reactant.Compiler.var"#34#36"{Bool, typeof(Enzyme.gradient), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Context.jl:76
  [8] compile_xla(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:695
  [9] compile_xla
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:690 [inlined]
 [10] compile(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:722
 [11] top-level scope
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:475
Some type information was truncated. Use `show(err)` to see complete types.

My working environment is:
Status `~/Work/Julia/Avast/SmellyAV/scripts/lux/Project.toml`
  [d360d2e6] ChainRulesCore v1.25.0
  [7da242da] Enzyme v0.13.22
  [b2108857] Lux v1.4.2
  [872c559c] NNlib v0.9.26
  [0b1bfda6] OneHotArrays v0.2.6
  [3c362404] Reactant v0.2.10
  [e88e6eb3] Zygote v0.6.73

and

julia> versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 11 × Apple M3 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 1 default, 0 interactive, 1 GC (on 5 virtual cores)

I have to admit that I would like to understand the toolchain well enough to fix these things myself.

Thanks in advance for any answers.

cc @avikpal

So that basically says that we haven’t implemented the derivative of the dynamic_gather operation yet (which we clearly should), and also that we weren’t able to optimize it away (perhaps we should be able to, depending on the code).
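
For reference, the adjoint of a gather with respect to its source is just a scatter-add of the cotangent back into the gathered slices. A rough NNlib-level sketch of what the missing rule would compute (not how it would be implemented in MLIR):

using NNlib

x  = randn(Float32, 4, 3)
ii = [1, 2, 1, 2, 3]
y  = NNlib.gather(x, ii)                          # forward: 4×5, copies columns of x

dy = ones(Float32, size(y))                       # cotangent of sum(y)
dx = NNlib.scatter(+, dy, ii; dstsize = size(x))  # adjoint: accumulate into source columns
# dx matches the Zygote/Enzyme gradients from the MWE above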

Can you show the output of `@code_hlo optimize=false Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)` and `@code_hlo optimize=:before_enzyme Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)`? This will print the IR so we can see what’s happening.

Thanks for the swift reply.

The outputs look like this:

@code_hlo optimize=false

julia> @code_hlo optimize=false Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)
module {
  func.func private @"-_broadcast_scalar"(%arg0: tensor<i64>, %arg1: tensor<i64>) -> (tensor<i64>, tensor<i64>, tensor<i64>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<i64>) -> tensor<i64>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<i64>) -> tensor<i64>
    %2 = stablehlo.subtract %0, %1 : tensor<i64>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<i64>) -> tensor<i64>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<i64>) -> tensor<i64>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<i64>) -> tensor<i64>
    return %3, %4, %5 : tensor<i64>, tensor<i64>, tensor<i64>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %1 : tensor<f32>
  }
  func.func private @"Const{Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}}(Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}(sum \E2\88\98 NNlib.gather, TracedRArray{Int64,1N}(((:args, 2, 1, 2),), size=(5,))))_autodiff"(%arg0: tensor<5xi64>, %arg1: tensor<3x4xf32>) -> (tensor<f32>, tensor<5xi64>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
    %2 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %c = stablehlo.constant dense<1> : tensor<5xi64>
    %3:3 = enzyme.batch @"-_broadcast_scalar"(%2, %c) {batch_shape = array<i64: 5>} : (tensor<5xi64>, tensor<5xi64>) -> (tensor<5xi64>, tensor<5xi64>, tensor<5xi64>)
    %c_0 = stablehlo.constant dense<[4, 1]> : tensor<2xi64>
    %4 = "stablehlo.dynamic_gather"(%1, %3#0, %c_0) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<4x3xf32>, tensor<5xi64>, tensor<2xi64>) -> tensor<4x5xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %4, dims = [0, 1] : (tensor<4x5xf32>) -> tensor<4x5xf32>
    %6 = enzyme.batch @identity_broadcast_scalar(%5) {batch_shape = array<i64: 4, 5>} : (tensor<4x5xf32>) -> tensor<4x5xf32>
    %7 = stablehlo.reduce(%6 init: %cst_1) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x5xf32>, tensor<f32>) -> tensor<f32>
    %8 = stablehlo.transpose %7, dims = [] : (tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %10 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %8, %9, %10 : tensor<f32>, tensor<5xi64>, tensor<3x4xf32>
  }
  func.func @main(%arg0: tensor<5xi64>, %arg1: tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<5xi64>, tensor<3x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x3xf32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.transpose %0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %3 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %4 = stablehlo.transpose %cst_0, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %cst, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %6:3 = enzyme.autodiff @"Const{Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}}(Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}(sum \E2\88\98 NNlib.gather, TracedRArray{Int64,1N}(((:args, 2, 1, 2),), size=(5,))))_autodiff"(%2, %3, %4, %5) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>]} : (tensor<5xi64>, tensor<3x4xf32>, tensor<f32>, tensor<3x4xf32>) -> (tensor<5xi64>, tensor<3x4xf32>, tensor<3x4xf32>)
    %7 = stablehlo.transpose %6#0, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %8 = stablehlo.transpose %6#1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %9 = stablehlo.transpose %6#2, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    %11 = stablehlo.transpose %7, dims = [0] : (tensor<5xi64>) -> tensor<5xi64>
    %12 = stablehlo.transpose %8, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
    return %10, %11, %12 : tensor<3x4xf32>, tensor<5xi64>, tensor<3x4xf32>
  }
}

@code_hlo optimize=:before_enzyme

julia> @code_hlo optimize=:before_enzyme Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)
error: could not compute the adjoint for this operation %5 = "stablehlo.dynamic_gather"(%3, %4, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<4x3xf32>, tensor<5xi64>, tensor<2xi64>) -> tensor<4x5xf32>
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:264
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:259 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Symbol)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:349
  [5] compile_mlir!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:290 [inlined]
  [6] #6
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:285 [inlined]
  [7] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{…}, typeof(Enzyme.gradient), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Context.jl:76
  [8] compile_mlir(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; kwargs::@Kwargs{optimize::Symbol})
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:283
  [9] top-level scope
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:475
Some type information was truncated. Use `show(err)` to see complete types.

I would like to understand this more. Is there any material you would recommend for study?

With the new JLL, if we don’t run the Enzyme pass we get:

module {
  func.func private @"Const{Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}}(Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}(sum \E2\88\98 NNlib.gather, TracedRArray{Int64,1N}(((:args, 2, 1, 2),), size=(5,))))_autodiff"(%arg0: tensor<5xi64>, %arg1: tensor<3x4xf32>) -> (tensor<f32>, tensor<5xi64>, tensor<3x4xf32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %c = stablehlo.constant dense<1> : tensor<5xi64>
    %0 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
    %1 = stablehlo.subtract %arg0, %c : tensor<5xi64>
    %2 = "stablehlo.gather"(%0, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>, slice_sizes = array<i64: 4, 1>}> : (tensor<4x3xf32>, tensor<5xi64>) -> tensor<4x5xf32>
    %3 = stablehlo.reduce(%2 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<4x5xf32>, tensor<f32>) -> tensor<f32>
    return %3, %arg0, %arg1 : tensor<f32>, tensor<5xi64>, tensor<3x4xf32>
  }
  func.func @main(%arg0: tensor<5xi64>, %arg1: tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<5xi64>, tensor<3x4xf32>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<3x4xf32>
    %0:3 = enzyme.autodiff @"Const{Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}}(Base.Fix2{ComposedFunction{typeof(sum), typeof(NNlib.gather)}, Reactant.TracedRArray{Int64, 1}}(sum \E2\88\98 NNlib.gather, TracedRArray{Int64,1N}(((:args, 2, 1, 2),), size=(5,))))_autodiff"(%arg0, %arg1, %cst, %cst_0) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>]} : (tensor<5xi64>, tensor<3x4xf32>, tensor<f32>, tensor<3x4xf32>) -> (tensor<5xi64>, tensor<3x4xf32>, tensor<3x4xf32>)
    return %0#2, %0#0, %0#1 : tensor<3x4xf32>, tensor<5xi64>, tensor<3x4xf32>
  }
}

I realized we don’t expose an option to skip just the Enzyme passes, which is honestly one of the most useful things to disable when generating minimal failure cases.

It’s a lil bit hard to explain everything because there are currently a lot of changes and refactors ongoing, and XLA/MLIR are a whole topic on their own, but in summary Reactant.jl:

  1. Translates ConcreteRArrays to TracedRArrays
  2. Traces your code by running it with the TracedRArrays, for which we have specialized the typical methods so that they emit MLIR operations
    • This operator tracing has some niceties included thanks to a custom abstract interpreter, which lets us perform method replacement (e.g. we do this to replace Enzyme.autodiff so that it calls Enzyme on MLIR instead of LLVM IR) and also trace over control flow
  3. Runs a bunch of high-level optimizations like batching and simplifications, and also the autodiff from Enzyme
  4. Compiles it with XLA

That @code_hlo just performs steps 1 and 2 (and step 3 is controlled with the optimize kwarg).
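
A minimal end-to-end sketch of that workflow, with a stand-in function f rather than your gather example:

using Reactant

f(x) = sum(abs2.(x))                        # ordinary Julia code to be traced

x = Reactant.to_rarray(randn(Float32, 8))   # step 1: data lives in a ConcreteRArray

@code_hlo optimize=false f(x)               # step 2: trace to raw StableHLO, no passes
@code_hlo optimize=:before_enzyme f(x)      # stop right before the Enzyme/autodiff pass
f_compiled = @compile f(x)                  # steps 3 and 4: optimize and compile with XLA
f_compiled(x)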

Reactant is not yet ready for production; it’s still veeery experimental. But development is going surprisingly fast, and we plan to publish a paper in a matter of months that will explain all this better.

Meanwhile, I would recommend starting with MLIR and with how JAX performs operator tracing.


Thanks a lot for the intuition. Hopefully Christmas will give me some time to learn about these things. Your answer gives me a good hint, but I do not yet understand the nuts and bolts of the implementation.

Yeah, we have prioritized functionality for now, and from time to time I take all the code and clean and refactor it to keep it a lil bit readable and coherent. But development goes fast and we break a lot of things… I promise we will clean it up.