I’m encountering issues when implementing a custom Lux layer that performs masking (as commonly used in RealNVP-style coupling flows).
Example in 1D with checkerboard masking:
`[x1, x2, x3, x4] → ([x1, 0, x3, 0], [0, x2, 0, x4])`
Goal
I want to implement this by defining a Boolean mask and multiplying elementwise:
```julia
x_in = randn(2, 10)          # batch dimension = 10
mask = [true, false]
x_even = mask .* x_in
x_odd  = .!mask .* x_in      # or (1 .- mask) .* x_in
```
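On the 1D checkerboard example from above, the broadcasted mask gives exactly the split I’m after (plain Julia, no Lux involved):

```julia
x = [1.0, 2.0, 3.0, 4.0]
mask = [true, false, true, false]

mask .* x    # → [1.0, 0.0, 3.0, 0.0]
.!mask .* x  # → [0.0, 2.0, 0.0, 4.0]
```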
Custom layer
```julia
struct EvenOddMaskingLayer{M<:AbstractArray} <: AbstractLuxLayer
    mask::M
end

function (m::EvenOddMaskingLayer)(x, ps, st)
    x_even = ifelse.(m.mask, x, zero(x))
    x_odd  = ifelse.(m.mask, zero(x), x)
    return (x_even, x_odd), st
end
```
Chained with other layers:
```julia
model = Chain(
    EvenOddMaskingLayer([true, false]),
    x -> vcat(x...),   # stack even/odd parts
    Dense(4 => 2, gelu),
)
```
With `AutoZygote()` everything works fine, but with Reactant + `AutoEnzyme()`, `single_train_step!` fails with a type-instability error.
Minimal setup:
```julia
ps, st = Lux.setup(Xoshiro(1994), model)

const xdev = reactant_device()
ϕ = randn(Float32, 2, 10) |> xdev

function _loss(m, p, s, x)
    y, st = m(x, p, s)
    return sum(y .- ϕ), st, (;)
end

tstate = Training.TrainState(model, ps, st, Optimisers.Adam(0.004f0))
Training.single_train_step!(AutoEnzyme(), _loss, ϕ, tstate)
```
Observations
- Works with `AutoZygote()`.
- Fails with `AutoEnzyme()` / Reactant.
- DispatchDoctor warns of a `TypeInstabilityError` involving `ifelse` and the `mask` argument.
- Using `EvenOddMaskingLayer(xdev([true, false]))` doesn’t help.
- Other implementations of the apply, such as `x_even = maskl.mask .* x_in`, don’t work either.
Question:
Is there a type-stable way to implement Boolean masking in a Lux layer that’s compatible with Enzyme/Reactant?
Any suggested workaround for making this `ifelse` broadcast Enzyme-friendly would be appreciated.
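For concreteness, this is the kind of reformulation I’d hope works: store the mask as a numeric array matching the data’s element type, so no `Bool` broadcast or `ifelse` appears in the differentiated code. This is only a sketch — `EvenOddMaskingLayerF` is a hypothetical name, and I haven’t verified it under Reactant:

```julia
using Lux

# Hypothetical variant: mask stored as Float32 zeros/ones instead of Bool.
struct EvenOddMaskingLayerF{M<:AbstractArray} <: AbstractLuxLayer
    mask::M
end

# Convert a Bool mask to Float32 at construction time.
EvenOddMaskingLayerF(mask::AbstractVector{Bool}) =
    EvenOddMaskingLayerF(Float32.(mask))

function (m::EvenOddMaskingLayerF)(x, ps, st)
    x_even = m.mask .* x           # zeros out the "odd" entries
    x_odd  = (1 .- m.mask) .* x    # zeros out the "even" entries
    return (x_even, x_odd), st
end
```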
Full error
```
┌ Warning: DispatchDoctor.TypeInstabilityWarning: Instability detected in `apply` defined at /Users/pietro/.julia/packages/LuxCore/VYmys/src/LuxCore.jl:154 with arguments `(Dense{typeof(NNlib.gelu_tanh), Int64, Int64, Nothing, Nothing, Static.True}, ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, @NamedTuple{})`. Inferred to be `Tuple{AbstractMatrix, @NamedTuple{}}`, which is not a concrete type.
└ @ DispatchDoctor._Stabilization ~/.julia/packages/DispatchDoctor/qgDil/src/stabilization.jl:185
ERROR: Enzyme compilation failed due to illegal type analysis.
This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].
Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Failure within method: MethodInstance for copy(::Base.Broadcast.Broadcasted{Base.Broadcast.ArrayStyle{…}, Tuple{…}, typeof(ifelse), Tuple{…}})
Hint: catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.
If you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.
Caused by:
Stacktrace:
 [1] getindex
   @ ./tuple.jl:31
 [2] iterate (repeats 2 times)
   @ ./tuple.jl:71
 [3] copy
   @ ~/.julia/packages/Reactant/lu2GU/src/ConcreteRArray.jl:462
Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/errors.jl:474
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/errors.jl:348
  [3] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, runtimeActivity::Bool, strongZero::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/nV4l9/src/api.jl:418
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:2423
  [5] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:4949
  [6] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/bTNLD/src/driver.jl:67
  [7] compile
    @ ~/.julia/packages/GPUCompiler/bTNLD/src/driver.jl:55 [inlined]
  [8] _thunk(job::GPUCompiler.CompilerJob{…}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5803
  [9] _thunk
    @ ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5801 [inlined]
 [10] cached_compilation
    @ ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5858 [inlined]
 [11] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:5974
 [12] thunk_generator(world::UInt64, source::Union{…}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nV4l9/src/compiler.jl:6169
 [13] autodiff_thunk
    @ ~/.julia/packages/Enzyme/nV4l9/src/Enzyme.jl:997 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/nV4l9/src/Enzyme.jl:391 [inlined]
 [15] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::typeof(_loss), data::ConcretePJRTArray{…}, ts::Lux.Training.TrainState{…})
    @ LuxEnzymeExt ~/.julia/packages/Lux/wjsb3/ext/LuxEnzymeExt/training.jl:17
 [16] compute_gradients
    @ ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:205 [inlined]
 [17] single_train_step_impl!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::ConcretePJRTArray{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:331
 [18] #single_train_step!#6
    @ ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:297 [inlined]
 [19] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::ConcretePJRTArray{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/wjsb3/src/helpers/training.jl:291
 [20] macro expansion
    @ ./timing.jl:581 [inlined]
 [21] top-level scope
    @ ~/code/software/Flows/enzyme_masking_layer.jl:38
Some type information was truncated. Use `show(err)` to see complete types.
```