Type instability in custom Lux masking layer when using Enzyme/Reactant

I’m encountering issues when implementing a custom Lux layer that performs masking (as commonly used in RealNVP-style coupling layers).

Example in 1D with checkerboard masking:

[x1, x2, x3, x4] → ([x1,0,x3,0], [0,x2,0,x4])

Goal

I want to implement this by defining a Boolean mask and multiplying elementwise:

x_in = randn(2,10) # batch dimension = 10
mask = [true, false]

x_even = mask .* x_in
x_odd  = .!mask .* x_in  # or (1 .- mask) .* x_in
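Concretely, with a tiny 2×2 example the broadcast does the following (values chosen just for illustration):

x_in = Float32[1 3; 2 4]   # 2 features × 2 samples
mask = [true, false]

mask .* x_in    # Float32[1.0 3.0; 0.0 0.0]  (first feature kept, second zeroed)
.!mask .* x_in  # Float32[0.0 0.0; 2.0 4.0]  (second feature kept, first zeroed)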

Custom layer

struct EvenOddMaskingLayer{M<:AbstractArray} <: AbstractLuxLayer
    mask::M
end

function (m::EvenOddMaskingLayer)(x, ps, st)
    x_even = ifelse.(m.mask, x, zero(x))
    x_odd  = ifelse.(m.mask, zero(x), x)
    return (x_even, x_odd), st
end
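For reference, a quick sanity check of the layer on its own (a minimal sketch; the layer has no trainable parameters, so Lux.setup should return empty NamedTuples for it):

layer = EvenOddMaskingLayer([true, false])
ps_m, st_m = Lux.setup(Xoshiro(0), layer)

x = randn(Float32, 2, 10)
(x_even, x_odd), _ = layer(x, ps_m, st_m)

all(iszero, x_even[2, :])  # true: second feature zeroed in the "even" part
all(iszero, x_odd[1, :])   # true: first feature zeroed in the "odd" part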

Chained with other layers:

model = Chain(
    EvenOddMaskingLayer([true,false]),
    x -> vcat(x...),      # stack even/odd parts
    Dense(4 => 2, gelu)
)

With AutoZygote() everything works fine, but when using Reactant + AutoEnzyme, single_train_step! fails with a type instability error.

Minimal setup:

ps, st = Lux.setup(Xoshiro(1994), model)
const xdev = reactant_device()
ϕ = randn(Float32, 2, 10) |> xdev

function _loss(m, p, s, x)
    y, st = m(x, p, s)
    return sum(y .- ϕ), st, (;)  # the global ϕ doubles as input and target here
end

tstate = Training.TrainState(model, ps, st, Optimisers.Adam(0.004f0))
Training.single_train_step!(AutoEnzyme(), _loss, ϕ, tstate)
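For comparison, here is a rough sketch of the AutoZygote path that works (plain CPU arrays; Zygote has to be loaded for AutoZygote() to be usable, and the loss below is just an illustrative stand-in):

using Zygote  # AD backend behind AutoZygote()

ϕ_cpu = randn(Float32, 2, 10)
# objective returns (loss, state, stats), as required by the Training API
loss_cpu(m, p, s, x) = (sum(abs2, first(m(x, p, s))), s, (;))

tstate_cpu = Training.TrainState(model, ps, st, Optimisers.Adam(0.004f0))
Training.single_train_step!(AutoZygote(), loss_cpu, ϕ_cpu, tstate_cpu)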

Observations

  • Works with AutoZygote()
  • Fails with AutoEnzyme() / Reactant
  • DispatchDoctor warns of a TypeInstabilityError involving ifelse and the mask argument.
  • Using EvenOddMaskingLayer(xdev([true,false])) doesn’t help.
  • Other implementations of the layer’s apply, e.g. x_even = maskl.mask .* x_in, don’t work either.

Question:
Is there a type-stable way to implement Boolean masking in a Lux layer that’s compatible with Enzyme/Reactant?
Any suggested workaround for making this ifelse broadcast Enzyme-friendly would be appreciated.

Full error

┌ Warning: DispatchDoctor.TypeInstabilityWarning: Instability detected in `apply` defined at /Users/pietro/.julia/packages/LuxCore/VYmys/src/LuxCore.jl:154 with arguments `(Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{})`. Inferred to be `Tuple{AbstractMatrix, @NamedTuple{}}`, which is not a concrete type.
└ @ DispatchDoctor._Stabilization ~/.julia/packages/DispatchDoctor/qgDil/src/stabilization.jl:185
ERROR: Enzyme compilation failed due to illegal type analysis.
 This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].
 Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
 Failure within method: MethodInstance for copy(::Base.Broadcast.Broadcasted{Base.Broadcast.ArrayStyle{…}, Tuple{…}, typeof(ifelse), Tuple{…}})
Hint: catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.
If you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.
Caused by:
Stacktrace:
 [1] getindex
   @ ./tuple.jl:31
 [2] iterate (repeats 2 times)
   @ ./tuple.jl:71
 [3] copy
   @ ~/.julia/packages/Reactant/lu2GU/src/ConcreteRArray.jl:462

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/errors.jl:474
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/errors.jl:348
  [3] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, runtimeActivity::Bool, strongZero::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/nV4l9/src/api.jl:418
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:2423
  [5] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:4949
  [6] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/bTNLD/src/driver.jl:67
  [7] compile
    @ ~/.julia/packages/GPUCompiler/bTNLD/src/driver.jl:55 [inlined]
  [8] _thunk(job::GPUCompiler.CompilerJob{…}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5803
  [9] _thunk
    @ ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5801 [inlined]
 [10] cached_compilation
    @ ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5858 [inlined]
 [11] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5974
 [12] thunk_generator(world::UInt64, source::Union{…}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:6169
 [13] autodiff_thunk
    @ ~/.julia/packages/Enzyme/nV4l9/src/Enzyme.jl:997 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/nV4l9/src/Enzyme.jl:391 [inlined]
 [15] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::typeof(_loss), data::ConcretePJRTArray{…}, ts::Lux.Training.TrainState{…})
    @ LuxEnzymeExt ~/.julia/packages/Lux/wjsb3/ext/LuxEnzymeExt/training.jl:17
 [16] compute_gradients
    @ ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:205 [inlined]
 [17] single_train_step_impl!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::ConcretePJRTArray{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:331
 [18] #single_train_step!#6
    @ ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:297 [inlined]
 [19] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::ConcretePJRTArray{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:291
 [20] macro expansion
    @ ./timing.jl:581 [inlined]
 [21] top-level scope
    @ ~/code/software/Flows/enzyme_masking_layer.jl:38
Some type information was truncated. Use `show(err)` to see complete types.

Minimal code to reproduce the error

using Lux, Optimisers, Random
using Enzyme, Reactant
const xdev = reactant_device()


struct EvenOddMaskingLayer{M<:AbstractArray} <: AbstractLuxLayer
    mask::M
end

function (maskl::EvenOddMaskingLayer)(x_in, ps, st)
    x_even = ifelse.(maskl.mask, x_in, zero(x_in))
    x_odd  = ifelse.(maskl.mask, zero(x_in), x_in)
    return (x_even, x_odd), st
end


rng = Xoshiro(1994)
mod = Chain(
    EvenOddMaskingLayer([true, false]),
    x -> vcat(x...),  # stack even and odd parts
    Dense(4 => 2, gelu),
)

ps, st = Lux.setup(rng, mod)

ϕ = randn(Float32, 2, 10) |> xdev
y, _ = mod(ϕ, ps, st)


function _loss(m, p, s, x)
    y, st = m(x, p, s)
    return sum(y .- x), st, (;)
end

tstate = Training.TrainState(mod, ps, st, Optimisers.Adam(0.004f0))
Training.single_train_step!(AutoEnzyme(), _loss, ϕ, tstate)

Long story short, Reactant doesn’t care about type-stability (I updated the performance docs for Lux earlier today).

Also note, you are using Reactant incorrectly. You need to @compile the single_train_step! call, and then call the compiled function.

Thanks for the quick reply!

I wasn’t aware that Reactant doesn’t enforce type stability, or that @compile is required before calling single_train_step!. Also, from the documentation (the “Using TrainState API” section of Compiling Lux Models using Reactant.jl | Lux.jl Docs, and Normalizing Flows for Density Estimation | Lux.jl Docs), it isn’t entirely clear that @compile needs to be called explicitly in this context.

That said, even after calling @compile Training.single_train_step!(...), I still encounter the same error. I managed to get a fully compilable version by passing ps and st through xdev and compiling the model directly:

ps, st = Lux.setup(rng, mod) |> xdev

ϕ = randn(Float32, 2, 10) |> xdev
y, _ = mod(ϕ, ps, st)

mod_compiled = @compile mod(ϕ, ps, Lux.testmode(st))
mod_compiled(ϕ, ps, Lux.testmode(st))

Is this the correct approach?

Would it be possible to outline the minimal requirements or best practices when using AutoEnzyme() for those who are still getting familiar with Reactant?

Thanks again for your help and for maintaining such a great library!


Oops, yeah, my bad (I shouldn’t respond to questions in the morning without coffee :sweat_smile:). I did not look at the code properly: @compile is not needed for single_train_step!, only for the model forward pass.

Your original code should just work with the following changes:

--- ../Lux.jl/envs/reactant/test.jl	2025-10-10 13:06:37.612543325 -0400
+++ ../Lux.jl/envs/reactant/test2.jl	2025-10-10 13:08:49.101224655 -0400
@@ -19,10 +19,11 @@
     Dense(4 => 2, gelu),
 )
 
-ps, st = Lux.setup(rng, mod)
+ps, st = Lux.setup(rng, mod) |> xdev
 
 ϕ = randn(Float32, 2, 10) |> xdev
-y, _ = mod(ϕ, ps, st)
+compiled_mod = @compile mod(ϕ, ps, st)
+y, _ = compiled_mod(ϕ, ps, st)
 
 function _loss(m, p, s, x)
     y, st = m(x, p, s)
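
For anyone landing here later, the reproducer with that diff applied looks roughly like this (a sketch; everything other than the two marked changes is the same script as above):

using Lux, Optimisers, Random
using Enzyme, Reactant
const xdev = reactant_device()

struct EvenOddMaskingLayer{M<:AbstractArray} <: AbstractLuxLayer
    mask::M
end

function (maskl::EvenOddMaskingLayer)(x_in, ps, st)
    x_even = ifelse.(maskl.mask, x_in, zero(x_in))
    x_odd  = ifelse.(maskl.mask, zero(x_in), x_in)
    return (x_even, x_odd), st
end

rng = Xoshiro(1994)
mod = Chain(
    EvenOddMaskingLayer([true, false]),
    x -> vcat(x...),  # stack even and odd parts
    Dense(4 => 2, gelu),
)

# Change 1: move parameters and states to the Reactant device.
ps, st = Lux.setup(rng, mod) |> xdev

ϕ = randn(Float32, 2, 10) |> xdev

# Change 2: compile the forward pass instead of calling the model eagerly.
compiled_mod = @compile mod(ϕ, ps, st)
y, _ = compiled_mod(ϕ, ps, st)

function _loss(m, p, s, x)
    y, st = m(x, p, s)
    return sum(y .- x), st, (;)
end

# single_train_step! itself does not need to be wrapped in @compile.
tstate = Training.TrainState(mod, ps, st, Optimisers.Adam(0.004f0))
Training.single_train_step!(AutoEnzyme(), _loss, ϕ, tstate)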

Would it be possible to outline the minimal requirements or best practices when using AutoEnzyme() for those who are still getting familiar with Reactant?

I have updated the performance section (it will take some time for the docs build to go through) and marked the sections that no longer apply to Reactant. Happy to make additional changes.


Thank you very much, this is very helpful! I look forward to reading the new documentation.