Hi,
I’m trying to train a simple language model using Lux with Enzyme, but I get a big scary error. Does anyone have insight into how to get to the bottom of this? Is there an error in the MWE?
I tried Zygote and it runs fine. Mooncake unfortunately errors in precompilation.
The error trace goes rather deep into the Enzyme stack and I have no clue how to approach this.
Error Message:
❯ julia --project=. src/training.jl
'apple-m4' is not a recognized processor for this target (ignoring processor)
'apple-m4' is not a recognized processor for this target (ignoring processor)
ERROR: LoadError: TypeError: non-boolean (Compiler.IRCode) used in boolean context
Stacktrace:
[1] semiconcrete_result_item(result::Compiler.SemiConcreteResult, info::Compiler.CallInfo, flag::UInt32, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
@ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1449
[2] handle_semi_concrete_result!(cases::Vector{Compiler.InliningCase}, result::Compiler.SemiConcreteResult, match::Core.MethodMatch, info::Compiler.CallInfo, flag::UInt32, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
@ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1462
[3] handle_any_const_result!(cases::Vector{Compiler.InliningCase}, result::Any, match::Core.MethodMatch, argtypes::Vector{Any}, info::Compiler.CallInfo, flag::UInt32, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}}; allow_typevars::Bool)
@ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1304
[4] handle_any_const_result!
@ ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1297 [inlined]
[5] compute_inlining_cases(info::Compiler.CallInfo, flag::UInt32, sig::Compiler.Signature, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
@ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1369
[6] handle_call!
@ ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1401 [inlined]
[7] assemble_inline_todo!(ir::Compiler.IRCode, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
@ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1652
[8] ssa_inlining_pass!
@ ./../usr/share/julia/Compiler/src/ssair/inlining.jl:76 [inlined]
[9] run_passes_ipo_safe(ci::Core.CodeInfo, sv::Compiler.OptimizationState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}}, optimize_until::Nothing)
@ Compiler ./../usr/share/julia/Compiler/src/optimize.jl:1013
[10] run_passes_ipo_safe
@ ./../usr/share/julia/Compiler/src/optimize.jl:1027 [inlined]
[11] optimize(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, opt::Compiler.OptimizationState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}}, caller::Compiler.InferenceResult)
@ Compiler ./../usr/share/julia/Compiler/src/optimize.jl:1002
[12] finish_nocycle(::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState, time_before::UInt64)
@ Compiler ./../usr/share/julia/Compiler/src/typeinfer.jl:202
[13] typeinf(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:4507
[14] const_prop_call(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, mi::Core.MethodInstance, result::Compiler.MethodCallResult, arginfo::Compiler.ArgInfo, sv::Compiler.InferenceState, concrete_eval_result::Nothing)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:1348
[15] abstract_call_method_with_const_args(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, result::Compiler.MethodCallResult, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, match::Core.MethodMatch, sv::Compiler.InferenceState, invokecall::Nothing)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:898
[16] abstract_call_method_with_const_args
@ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:868 [inlined]
[17] (::Compiler.var"#handle1#abstract_call_gf_by_type##1"{Int64, Compiler.Future{Compiler.MethodCallResult}, Int64, Vector{Union{Nothing, Core.CodeInstance}}, Core.MethodMatch, Compiler.ArgInfo, Compiler.StmtInfo, Compiler.CallInferenceState, Vector{Any}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#⊑##0#⊑##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}, Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}})(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, sv::Compiler.InferenceState)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:178
[18] (::Compiler.var"#infercalls#abstract_call_gf_by_type##0"{Compiler.ArgInfo, Compiler.StmtInfo, Compiler.CallInferenceState, Compiler.Future{Compiler.CallMeta}, Vector{Compiler.MethodMatchTarget}, Vector{Any}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#⊑##0#⊑##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}, Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}})(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, sv::Compiler.InferenceState)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:252
[19] abstract_call_gf_by_type(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, func::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, atype::Any, sv::Compiler.InferenceState, max_methods::Int64)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:338
[20] abstract_call_gf_by_type(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, atype::Any, sv::Compiler.InferenceState, max_methods::Int64)
@ Enzyme.Compiler.Interpreter ~/.julia/packages/Enzyme/eJcor/src/compiler/interpreter.jl:364
[21] abstract_call_known(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int64)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:2782
[22] abstract_call_known(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int64)
@ Enzyme.Compiler.Interpreter ~/.julia/packages/Enzyme/eJcor/src/compiler/interpreter.jl:1186
[23] abstract_call(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int64)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:2889
[24] abstract_call
@ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:2882 [inlined]
[25] abstract_call(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, arginfo::Compiler.ArgInfo, sstate::Compiler.StatementState, sv::Compiler.InferenceState)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3042
[26] abstract_eval_call
@ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3060 [inlined]
[27] abstract_eval_statement_expr(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, e::Expr, sstate::Compiler.StatementState, sv::Compiler.InferenceState)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3389
[28] abstract_eval_basic_statement
@ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3835 [inlined]
[29] abstract_eval_basic_statement
@ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3792 [inlined]
[30] typeinf_local(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState, nextresult::Compiler.CurrentState)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:4342
[31] typeinf(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState)
@ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:4500
[32] typeinf_ext(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, mi::Core.MethodInstance, source_mode::UInt8)
@ Compiler ./../usr/share/julia/Compiler/src/typeinfer.jl:1259
[33] typeinf_type
@ ./../usr/share/julia/Compiler/src/typeinfer.jl:1281 [inlined]
[34] return_type(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, mi::Core.MethodInstance)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:12
[35] primal_return_type_world(mode::Mode, world::UInt64, mi::Core.MethodInstance)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:82
[36] primal_return_type_generator(world::UInt64, source::Any, self::Any, mode::Type, ft::Type, tt::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:120
[37] autodiff
@ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:387 [inlined]
[38] compute_gradients_impl(ad::AutoEnzyme{Nothing, Nothing}, obj_fn::typeof(loss_fun), data::Tuple{Matrix{Int64}, Matrix{Int64}}, ts::Lux.Training.TrainState{Nothing, Nothing, Embedding{Int64, Int64, typeof(rand32)}, @NamedTuple{weight::Matrix{Float32}}, @NamedTuple{}, Adam{Float32, Tuple{Float64, Float64}, Float64}, @NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}})
@ LuxEnzymeExt ~/.julia/packages/Lux/GOH1t/ext/LuxEnzymeExt/training.jl:17
[39] #compute_gradients#1
@ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:277 [inlined]
[40] compute_gradients
@ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:275 [inlined]
[41] single_train_step_impl!
@ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:425 [inlined]
[42] #single_train_step!#6
@ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:384 [inlined]
[43] single_train_step!
@ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:373 [inlined]
[44] train(tstate::Lux.Training.TrainState{Nothing, Nothing, Embedding{Int64, Int64, typeof(rand32)}, @NamedTuple{weight::Matrix{Float32}}, @NamedTuple{}, Adam{Float32, Tuple{Float64, Float64}, Float64}, @NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}, vjp::AutoEnzyme{Nothing, Nothing}, data_set::Vector{Int64}, num_epochs::Int64)
@ Main ~/julia_envs/nanolux/src/training.jl:91
[45] top-level scope
@ ~/julia_envs/nanolux/src/training.jl:98
[46] include(mod::Module, _path::String)
@ Base ./Base.jl:306
[47] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:317
[48] _start()
@ Base ./client.jl:550
in expression starting at /Users/ralph/julia_envs/nanolux/src/training.jl:98
MWE
using ADTypes
using Enzyme
using Distributions
using Lux
using OneHotArrays
using Optimisers
using Random
"""
    load_dataset(filename)

Read the text file `filename` and build a character-level tokenisation of it.

The vocabulary consists of the sorted distinct characters found in the file's lines,
followed by the newline character as the last token.

# Arguments
- `filename` - Path of the text file to read

# Returns
A tuple `(data, encode, decode, vocab_size)`:
- `data` - The file's text as a vector of integer token ids, lines separated by the
  newline token
- `encode` - Function mapping a string to a vector of token ids
- `decode` - Function mapping an iterable of token ids to a vector of characters
- `vocab_size` - Number of tokens in the vocabulary
"""
function load_dataset(filename)
    lines = readlines(filename)
    # `readlines` strips the line terminators, so the newline character is never part of
    # `lines` and is appended to the vocabulary explicitly. The typed comprehension keeps
    # `chars` a `Vector{Char}` even when the file is empty.
    chars = sort!(unique!(Char[c for line in lines for c in line]))
    push!(chars, '\n')
    vocab_size = length(chars)
    int_to_ch = Dict{Int,Char}(ix => ch for (ix, ch) in enumerate(chars))
    ch_to_int = Dict{Char,Int}(ch => ix for (ix, ch) in enumerate(chars))
    encode(str) = [ch_to_int[s] for s in str]
    decode(tok) = [int_to_ch[t] for t in tok]
    # Re-insert the newlines that `readlines` removed; joining with "" would fuse the last
    # word of each line with the first word of the next and leave the newline token unused.
    data = encode(join(lines, "\n"))
    return data, encode, decode, vocab_size
end
# Tokenise the corpus once at load time. `dataset` holds the integer token ids,
# `encode_fun`/`decode_fun` convert between text and token ids, and `vocab_size` is the
# number of distinct tokens.
dataset, encode_fun, decode_fun, vocab_size = load_dataset("data/input.txt")
"""
    get_batch(rng, dataset, batch_size, block_size)

Draw a random batch of input/target sequences from `dataset`.

# Arguments
- `rng` - Random number generator used to draw the sequence start offsets
- `dataset` - The dataset, either data_train or data_test
- `batch_size` - Number of independent sequences
- `block_size` - Length of sequences

# Returns
- `x` - For a vector `dataset`, a `block_size × batch_size` matrix whose columns are the
  sampled sequences
- `y` - Same shape as `x`; each column is the corresponding sequence shifted one position
  further into `dataset`
"""
function get_batch(rng, dataset, batch_size, block_size)
    # The last admissible start still leaves room for the one-step-shifted target window.
    last_start = length(dataset) - block_size
    starts = rand(rng, 1:last_start, batch_size)
    window(from) = dataset[from:(from + block_size - 1)]
    x = stack(map(window, starts); dims=2)
    y = stack(map(window, starts .+ 1); dims=2)
    return x, y
end
"""
    loss_fun(model, ps, st, (xb, yb))

Cross-entropy loss of `model` on one batch of token ids.

# Arguments
- `model` - Callable as `model(xb, ps, st)`, returning a tuple whose first element is a
  3-dimensional array of logits
- `ps` - Model parameters
- `st` - Model state
- `(xb, yb)` - Input token ids and target token ids of the batch

# Returns
- The tuple `(loss, st, NamedTuple())`: the mean cross-entropy, the state that was passed
  in, and an empty statistics tuple
"""
function loss_fun(model, ps, st, (xb, yb))
    # Outputs are interpreted as the logits.
    logits, _ = model(xb, ps, st)
    (C, T, B) = size(logits)
    # Flatten the two trailing dimensions so that every column is one separate sample.
    logits_t = reshape(logits, C, T * B)
    yb_t = reshape(yb, T * B)
    # CrossEntropyLoss is given one-hot targets here. The label range is taken from the
    # class dimension of the logits instead of a hard-coded vocabulary size.
    oh = onehotbatch(yb_t, 1:C)
    # Use a negative log-likelihood loss function.
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(logits_t, oh)
    return loss, st, NamedTuple()
end
# Train/test split: the first 90% of the tokens are used for training, the rest for testing.
N = length(dataset)
N_train = round(Int, 0.9 * N)
data_train = dataset[1:N_train]
data_test = dataset[(N_train + 1):end]

# Seed the default RNG; it is used for parameter initialisation below and for drawing
# batches during training.
rng = Random.default_rng()
Random.seed!(rng, 1337)

# The model is a single embedding layer with `vocab_size` inputs and `vocab_size` outputs.
model = Embedding(vocab_size => vocab_size)
ps, st = Lux.setup(rng, model)
opt = Adam(0.03f0)

# TrainState is a struct provided by Lux that bundles the model, its parameters, its state
# and the optimizer. It is the only thing that has to be passed to the training function.
tstate = Training.TrainState(model, ps, st, opt)
vjp_rule = AutoEnzyme()
"""
    train(tstate, vjp, data_set, num_epochs; rng=Random.default_rng())

Run `num_epochs` optimisation steps on `tstate`, one freshly sampled batch per step, and
print the loss after every step.

# Arguments
- `tstate` - The `Training.TrainState` to optimise
- `vjp` - Automatic-differentiation backend passed to `Training.single_train_step!`
- `data_set` - Token ids from which batches are drawn with `get_batch`
- `num_epochs` - Number of training steps to run

# Keywords
- `rng` - Random number generator used for drawing batches

# Returns
- The `TrainState` returned by the last training step (the input `tstate` when
  `num_epochs` is less than 1)
"""
function train(
    tstate::Training.TrainState, vjp, data_set, num_epochs; rng=Random.default_rng()
)
    for epoch in 1:num_epochs
        # Each batch holds 4 sequences of 8 tokens.
        xb, yb = get_batch(rng, data_set, 4, 8)
        # Rebind `tstate` to the state returned by the step so the next iteration and the
        # caller see the updated state rather than the one that was passed in.
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_fun, (xb, yb), tstate)
        println("Epoch: $(epoch) Loss: $(loss)")
    end
    return tstate
end
# Run 10 training steps on the training split using the Enzyme backend.
train(tstate, AutoEnzyme(), data_train, 10)
Versioninfo():
julia> versioninfo()
Julia Version 1.12.1
Commit ba1e628ee49 (2025-10-17 13:02 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: macOS (arm64-apple-darwin24.0.0)
CPU: 14 × Apple M4 Pro
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, apple-m4)
GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 10 virtual cores)
Environment:
JULIA_ENVS = /Users/ralph/julia_envs
JULIA_PKG_USE_CLI_GIT = true