Hi All,
I have started to experiment with Lux, Enzyme, and Reactant. In my applications, I frequently need operations such as `segmented_sum` or `segmented_mean`, which can be computed as compositions of `scatter` and `gather`. However, the gradient of `gather` does not currently compile with Reactant.
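To make the operations concrete, here is a plain-Julia sketch of what I mean by them, with no AD and no Reactant involved. The names `segmented_sum`, `segmented_mean` and `gather_cols` are hypothetical helpers, and the layout (observations as columns, one integer segment id per column) is an assumption chosen to match how `NNlib.gather` indexes a matrix with an integer vector. In NNlib terms, the sum should correspond to `NNlib.scatter(+, x, seg)` and `gather_cols` to `NNlib.gather(y, seg)`.

```julia
# Plain-Julia reference semantics for segmented sums/means (hypothetical helper names).
# `x` is a d×n matrix of column observations, `seg[j] ∈ 1:nseg` is the segment of column j.

function segmented_sum(x::AbstractMatrix, seg::AbstractVector{<:Integer}, nseg::Integer = maximum(seg))
    out = zeros(eltype(x), size(x, 1), nseg)
    for (j, k) in enumerate(seg)
        out[:, k] .+= view(x, :, j)      # scatter-add: column j is accumulated into column k
    end
    return out
end

function segmented_mean(x::AbstractMatrix, seg::AbstractVector{<:Integer}, nseg::Integer = maximum(seg))
    counts = zeros(eltype(x), 1, nseg)   # number of columns in each segment
    for k in seg
        counts[1, k] += 1
    end
    return segmented_sum(x, seg, nseg) ./ max.(counts, 1)   # max guards against empty segments
end

# gather: column j of the result is column idx[j] of y
gather_cols(y::AbstractMatrix, idx::AbstractVector{<:Integer}) = y[:, idx]

x = Float32[1 2 3 4 5;
            6 7 8 9 10]
seg = [1, 2, 1, 2, 3]

s = segmented_sum(x, seg)            # Float32[4 6 5; 14 16 10]
m = segmented_mean(x, seg)           # Float32[2 3 5; 7 8 10]
centered = x .- gather_cols(m, seg)  # subtract from each column the mean of its segment
```

The last line is the typical composition: a scatter (the per-segment mean) followed by a gather (broadcasting each mean back to the members of its segment).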
Below is a minimal working example of the failure. I have kept the calls showing that Zygote and Enzyme (on plain arrays) both compute the gradient, while Reactant fails to compile it.
```julia
using NNlib
using Zygote
using Enzyme
using Reactant
using Lux   # provides reactant_device()

xx = randn(Float32, 4, 3)
ii = [1, 2, 1, 2, 3]

# Both of these work on plain CPU arrays
Zygote.gradient(xx -> sum(NNlib.gather(xx, ii)), xx)
Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii)), xx)

# Moving the data to Reactant and compiling the same Enzyme gradient fails
const xdev = reactant_device()
xx_ra = xx |> xdev
ii_ra = ii |> xdev
enzyme_gradient_compiled = @compile Enzyme.gradient(Enzyme.Reverse, Const(Base.Fix2(sum ∘ NNlib.gather, ii_ra)), xx_ra)
```
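For reference, the gradient that all three should return can be worked out by hand: the adjoint of a gather is a scatter-add of the incoming cotangent back into the source columns. For `sum`, the cotangent is all ones, so column `k` of the gradient is filled with the number of times `k` occurs in `ii`. The sketch below uses plain loops and assumes `xx` is 4×3, as in the example.

```julia
# Gradient of xx -> sum(NNlib.gather(xx, ii)) computed by hand (no AD).
# Assumes xx is 4×3 and ii = [1, 2, 1, 2, 3] as in the example.
ii = [1, 2, 1, 2, 3]
d, n = 4, 3                          # size(xx)

dy = ones(Float32, d, length(ii))    # cotangent of sum(...) w.r.t. the 4×5 gathered array
dx = zeros(Float32, d, n)
for (j, k) in enumerate(ii)
    dx[:, k] .+= view(dy, :, j)      # reverse of gather: column j flows back into column ii[j]
end
dx                                   # every row equals Float32[2 2 1]
```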
The error is as follows:
```
error: could not compute the adjoint for this operation %5 = "stablehlo.dynamic_gather"(%3, %4, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1], index_vector_dim = 1>}> : (tensor<4x3xf32>, tensor<344044xi64>, tensor<2xi64>) -> tensor<4x344044xf32>
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:264
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:259 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:309
  [5] compile_mlir!
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:290 [inlined]
  [6] (::Reactant.Compiler.var"#34#36"{Bool, typeof(Enzyme.gradient), Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:698
  [7] context!(f::Reactant.Compiler.var"#34#36"{Bool, typeof(Enzyme.gradient), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/sIJRJ/src/mlir/IR/Context.jl:76
  [8] compile_xla(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:695
  [9] compile_xla
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:690 [inlined]
 [10] compile(f::Function, args::Tuple{ReverseMode{…}, Const{…}, ConcreteRArray{…}}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:722
 [11] top-level scope
    @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:475
Some type information was truncated. Use `show(err)` to see complete types.
```
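One possible stopgap, which I have not verified against Reactant v0.2.10 (treat it as an assumption), is to avoid `gather`/`scatter` altogether and express both as multiplication by a dense one-hot matrix, since a plain matrix product is an operation whose adjoint AD tools generally support. The helper `onehot_dense` below is hypothetical; `OneHotArrays.onehotbatch(ii, 1:3)` builds the same pattern lazily, but a dense `Float32` matrix keeps everything as ordinary matmuls.

```julia
# Hypothetical helper: dense one-hot matrix M (nseg × n) with M[k, j] = 1 when seg[j] == k.
onehot_dense(seg::AbstractVector{<:Integer}, nseg::Integer = maximum(seg)) =
    Float32.((1:nseg) .== permutedims(seg))

xx = randn(Float32, 4, 3)
ii = [1, 2, 1, 2, 3]
M = onehot_dense(ii, size(xx, 2))   # 3×5

gathered = xx * M          # column j is xx[:, ii[j]], i.e. the same as NNlib.gather(xx, ii)  (4×5)
summed   = gathered * M'   # column k sums the gathered columns j with ii[j] == k               (4×3)
```

The trade-off is cost: `M` has `nseg × n` entries, so this is only attractive when the number of segments is modest, and it sidesteps the missing adjoint rather than fixing it.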
My working environment is:
```
Status `~/Work/Julia/Avast/SmellyAV/scripts/lux/Project.toml`
  [d360d2e6] ChainRulesCore v1.25.0
  [7da242da] Enzyme v0.13.22
  [b2108857] Lux v1.4.2
  [872c559c] NNlib v0.9.26
  [0b1bfda6] OneHotArrays v0.2.6
  [3c362404] Reactant v0.2.10
  [e88e6eb3] Zygote v0.6.73
```
and
```
julia> versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 11 × Apple M3 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 1 default, 0 interactive, 1 GC (on 5 virtual cores)
```
I would also like to understand the toolchain well enough to fix problems like this myself.
Thanks in advance for any answers.