Lux/Enzyme error when Training my model?

Hi,
I’m trying to get a simple language model to train using Lux/Enzyme and get a big scary error. Does anyone have insight into how to get to the bottom of this? Is there an error in the MWE?
I tried Zygote and it runs fine. Mooncake unfortunately errors in precompilation.
The error trace goes rather deep into the Enzyme stack and I have no clue how to approach this.

Error Message:

❯ julia --project=. src/training.jl
'apple-m4' is not a recognized processor for this target (ignoring processor)
'apple-m4' is not a recognized processor for this target (ignoring processor)
ERROR: LoadError: TypeError: non-boolean (Compiler.IRCode) used in boolean context
Stacktrace:
  [1] semiconcrete_result_item(result::Compiler.SemiConcreteResult, info::Compiler.CallInfo, flag::UInt32, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
    @ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1449
  [2] handle_semi_concrete_result!(cases::Vector{Compiler.InliningCase}, result::Compiler.SemiConcreteResult, match::Core.MethodMatch, info::Compiler.CallInfo, flag::UInt32, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
    @ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1462
  [3] handle_any_const_result!(cases::Vector{Compiler.InliningCase}, result::Any, match::Core.MethodMatch, argtypes::Vector{Any}, info::Compiler.CallInfo, flag::UInt32, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}}; allow_typevars::Bool)
    @ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1304
  [4] handle_any_const_result!
    @ ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1297 [inlined]
  [5] compute_inlining_cases(info::Compiler.CallInfo, flag::UInt32, sig::Compiler.Signature, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
    @ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1369
  [6] handle_call!
    @ ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1401 [inlined]
  [7] assemble_inline_todo!(ir::Compiler.IRCode, state::Compiler.InliningState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}})
    @ Compiler ./../usr/share/julia/Compiler/src/ssair/inlining.jl:1652
  [8] ssa_inlining_pass!
    @ ./../usr/share/julia/Compiler/src/ssair/inlining.jl:76 [inlined]
  [9] run_passes_ipo_safe(ci::Core.CodeInfo, sv::Compiler.OptimizationState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}}, optimize_until::Nothing)
    @ Compiler ./../usr/share/julia/Compiler/src/optimize.jl:1013
 [10] run_passes_ipo_safe
    @ ./../usr/share/julia/Compiler/src/optimize.jl:1027 [inlined]
 [11] optimize(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, opt::Compiler.OptimizationState{Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}}, caller::Compiler.InferenceResult)
    @ Compiler ./../usr/share/julia/Compiler/src/optimize.jl:1002
 [12] finish_nocycle(::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState, time_before::UInt64)
    @ Compiler ./../usr/share/julia/Compiler/src/typeinfer.jl:202
 [13] typeinf(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:4507
 [14] const_prop_call(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, mi::Core.MethodInstance, result::Compiler.MethodCallResult, arginfo::Compiler.ArgInfo, sv::Compiler.InferenceState, concrete_eval_result::Nothing)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:1348
 [15] abstract_call_method_with_const_args(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, result::Compiler.MethodCallResult, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, match::Core.MethodMatch, sv::Compiler.InferenceState, invokecall::Nothing)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:898
 [16] abstract_call_method_with_const_args
    @ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:868 [inlined]
 [17] (::Compiler.var"#handle1#abstract_call_gf_by_type##1"{Int64, Compiler.Future{Compiler.MethodCallResult}, Int64, Vector{Union{Nothing, Core.CodeInstance}}, Core.MethodMatch, Compiler.ArgInfo, Compiler.StmtInfo, Compiler.CallInferenceState, Vector{Any}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#⊑##0#⊑##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}, Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}})(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, sv::Compiler.InferenceState)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:178
 [18] (::Compiler.var"#infercalls#abstract_call_gf_by_type##0"{Compiler.ArgInfo, Compiler.StmtInfo, Compiler.CallInferenceState, Compiler.Future{Compiler.CallMeta}, Vector{Compiler.MethodMatchTarget}, Vector{Any}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#tmerge##0#tmerge##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.var"#⊑##0#⊑##1"{Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}}, Compiler.InferenceLattice{Compiler.ConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}, Compiler.InferenceLattice{Compiler.InterConditionalsLattice{Compiler.PartialsLattice{Compiler.ConstsLattice}}}})(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, sv::Compiler.InferenceState)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:252
 [19] abstract_call_gf_by_type(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, func::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, atype::Any, sv::Compiler.InferenceState, max_methods::Int64)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:338
 [20] abstract_call_gf_by_type(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, atype::Any, sv::Compiler.InferenceState, max_methods::Int64)
    @ Enzyme.Compiler.Interpreter ~/.julia/packages/Enzyme/eJcor/src/compiler/interpreter.jl:364
 [21] abstract_call_known(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int64)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:2782
 [22] abstract_call_known(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, f::Any, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int64)
    @ Enzyme.Compiler.Interpreter ~/.julia/packages/Enzyme/eJcor/src/compiler/interpreter.jl:1186
 [23] abstract_call(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, arginfo::Compiler.ArgInfo, si::Compiler.StmtInfo, sv::Compiler.InferenceState, max_methods::Int64)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:2889
 [24] abstract_call
    @ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:2882 [inlined]
 [25] abstract_call(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, arginfo::Compiler.ArgInfo, sstate::Compiler.StatementState, sv::Compiler.InferenceState)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3042
 [26] abstract_eval_call
    @ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3060 [inlined]
 [27] abstract_eval_statement_expr(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, e::Expr, sstate::Compiler.StatementState, sv::Compiler.InferenceState)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3389
 [28] abstract_eval_basic_statement
    @ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3835 [inlined]
 [29] abstract_eval_basic_statement
    @ ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:3792 [inlined]
 [30] typeinf_local(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState, nextresult::Compiler.CurrentState)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:4342
 [31] typeinf(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, frame::Compiler.InferenceState)
    @ Compiler ./../usr/share/julia/Compiler/src/abstractinterpretation.jl:4500
 [32] typeinf_ext(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, mi::Core.MethodInstance, source_mode::UInt8)
    @ Compiler ./../usr/share/julia/Compiler/src/typeinfer.jl:1259
 [33] typeinf_type
    @ ./../usr/share/julia/Compiler/src/typeinfer.jl:1281 [inlined]
 [34] return_type(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, mi::Core.MethodInstance)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:12
 [35] primal_return_type_world(mode::Mode, world::UInt64, mi::Core.MethodInstance)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:82
 [36] primal_return_type_generator(world::UInt64, source::Any, self::Any, mode::Type, ft::Type, tt::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/eJcor/src/typeutils/inference.jl:120
 [37] autodiff
    @ ~/.julia/packages/Enzyme/eJcor/src/Enzyme.jl:387 [inlined]
 [38] compute_gradients_impl(ad::AutoEnzyme{Nothing, Nothing}, obj_fn::typeof(loss_fun), data::Tuple{Matrix{Int64}, Matrix{Int64}}, ts::Lux.Training.TrainState{Nothing, Nothing, Embedding{Int64, Int64, typeof(rand32)}, @NamedTuple{weight::Matrix{Float32}}, @NamedTuple{}, Adam{Float32, Tuple{Float64, Float64}, Float64}, @NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}})
    @ LuxEnzymeExt ~/.julia/packages/Lux/GOH1t/ext/LuxEnzymeExt/training.jl:17
 [39] #compute_gradients#1
    @ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:277 [inlined]
 [40] compute_gradients
    @ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:275 [inlined]
 [41] single_train_step_impl!
    @ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:425 [inlined]
 [42] #single_train_step!#6
    @ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:384 [inlined]
 [43] single_train_step!
    @ ~/.julia/packages/Lux/GOH1t/src/helpers/training.jl:373 [inlined]
 [44] train(tstate::Lux.Training.TrainState{Nothing, Nothing, Embedding{Int64, Int64, typeof(rand32)}, @NamedTuple{weight::Matrix{Float32}}, @NamedTuple{}, Adam{Float32, Tuple{Float64, Float64}, Float64}, @NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}}}, vjp::AutoEnzyme{Nothing, Nothing}, data_set::Vector{Int64}, num_epochs::Int64)
    @ Main ~/julia_envs/nanolux/src/training.jl:91
 [45] top-level scope
    @ ~/julia_envs/nanolux/src/training.jl:98
 [46] include(mod::Module, _path::String)
    @ Base ./Base.jl:306
 [47] exec_options(opts::Base.JLOptions)
    @ Base ./client.jl:317
 [48] _start()
    @ Base ./client.jl:550
in expression starting at /Users/ralph/julia_envs/nanolux/src/training.jl:98

MWE


using ADTypes
using Enzyme
using Distributions
using Lux
using OneHotArrays
using Optimisers
using Random




"""
    load_dataset(filename)

Read the text file `filename` and build a character-level tokenizer for it.

# Returns
A tuple `(data, encode, decode, vocab_size)`:
- `data` - The whole text as a vector of integer token ids; consecutive lines are
  separated by the newline token
- `encode` - Function mapping a string (or any iterable of characters) to token ids
- `decode` - Function mapping an iterable of token ids to a vector of characters
- `vocab_size` - Number of distinct tokens: every character in the file plus newline
"""
function load_dataset(filename)
    lines = readlines(filename)

    # `readlines` strips the line terminators, so the newline character never
    # appears among the per-line characters. It is appended after sorting and
    # therefore always receives the last token id.
    chars = sort!(unique([c for l in lines for c in l]))
    push!(chars, '\n')
    vocab_size = length(chars)
    int_to_ch = Dict(ix => val for (ix, val) in enumerate(chars))
    ch_to_int = Dict(val => ix for (ix, val) in enumerate(chars))
    encode(str) = [ch_to_int[s] for s in str]
    decode(tok) = [int_to_ch[t] for t in tok]

    # Re-insert the newlines between lines; joining with "" would fuse the end of
    # one line with the start of the next and leave the newline token unused.
    data = encode(join(lines, '\n'))
    return data, encode, decode, vocab_size
end

# Load the corpus once at top level: `dataset` holds the encoded tokens and
# `vocab_size` is used further down to size the embedding model.
dataset, encode_fun, decode_fun, vocab_size = load_dataset("data/input.txt")



"""
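    get_batch(rng, dataset, batch_size, block_size)

Return a random batch of input and target sequences from a dataset.

# Arguments
- `rng` - Random number generator used to draw the sequence offsets
- `dataset` - The dataset, either data_train or data_test
- `batch_size` - Number of independent sequences
- `block_size` - Length of sequences
# Returns:
- `x` - Matrix of size `block_size × batch_size`; each column is one sequence from the dataset
- `y` - Matrix of the same size; each column is the corresponding `x` sequence shifted by one position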
"""
function get_batch(rng, dataset, batch_size, block_size)
    # Largest admissible start index: the target window reaches one element past
    # the input window, so `block_size` elements must remain after the start.
    last_start = length(dataset) - block_size
    starts = rand(rng, 1:last_start, batch_size)

    # One column per sampled offset; the targets are the inputs shifted by one.
    inputs = stack(s -> dataset[s:(s + block_size - 1)], starts; dims=2)
    targets = stack(s -> dataset[(s + 1):(s + block_size)], starts; dims=2)

    return inputs, targets
end

"""
    loss_fun(model, ps, st, (xb, yb))

Cross-entropy loss of `model` on one batch, in the four-argument form used with
`Training.single_train_step!`.

# Arguments
- `model`, `ps`, `st` - Model, its parameters and its state
- `(xb, yb)` - Tuple of input tokens and target tokens

# Returns
A tuple `(loss, st, stats)` holding the scalar loss, the state returned by the
model call, and an empty `NamedTuple` of statistics.
"""
function loss_fun(model, ps, st, (xb, yb))
    # Outputs are interpreted as the logits.
    logits, st_new = model(xb, ps, st)
    C, T, B = size(logits)
    logits_t = reshape(logits, C, T * B)  # Reshape, so that dim2 is along separate samples
    yb_t = reshape(yb, T * B)
    # Here we have to resort to OneHotArrays to get the CrossEntropy to work.
    # The label range follows the class dimension of the logits instead of a
    # hard-coded vocabulary size, so any vocabulary works.
    oh = onehotbatch(yb_t, 1:C)
    # Use a negative log-likelihood loss function
    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(logits_t, oh)
    return loss, st_new, NamedTuple()
end



# Split the token stream 90/10 into a training part and a held-out part.
N = length(dataset)
N_train = round(Int, 0.9 * N)
data_train = dataset[begin:N_train]
data_test = dataset[(N_train + 1):end]

# Seed the default RNG, which is used for parameter initialisation below.
rng = Random.default_rng()
Random.seed!(rng, 1337)

# A single embedding table that maps each token id to `vocab_size` output values.
model = Embedding(vocab_size => vocab_size)
ps, st = Lux.setup(rng, model)
opt = Adam(0.03f0)

# TrainState is a useful struct defined by Lux. It is essentially a wrapper over parameters,
# state, optimizer state, and the model. We only need to pass this into our training function.
tstate = Training.TrainState(model, ps, st, opt)
vjp_rule = AutoEnzyme()

"""
    train(tstate, vjp, data_set, num_epochs)

Run `num_epochs` single-batch training steps and return the final train state.

# Arguments
- `tstate` - Initial `Training.TrainState`
- `vjp` - AD backend passed to `Training.single_train_step!`, e.g. `AutoEnzyme()`
- `data_set` - Token sequence that batches are sampled from
- `num_epochs` - Number of training steps; each step draws one fresh batch
"""
function train(tstate::Training.TrainState, vjp, data_set, num_epochs)
    for epoch in 1:num_epochs
        xb, yb = get_batch(rng, data_set, 4, 8)
        # Rebind `tstate` to the state returned by the step. Discarding it would
        # feed the stale initial state into every later step and return that
        # stale state to the caller.
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_fun, (xb, yb), tstate)
        println("Epoch: $(epoch)    Loss: $(loss)")
    end
    return tstate
end


# Run 10 training steps on the training split with Enzyme as the AD backend.
train(tstate, AutoEnzyme(), data_train, 10)

Versioninfo():

julia> versioninfo()
Julia Version 1.12.1
Commit ba1e628ee49 (2025-10-17 13:02 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 14 × Apple M4 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m4)
  GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 10 virtual cores)
Environment:
  JULIA_ENVS = /Users/ralph/julia_envs
  JULIA_PKG_USE_CLI_GIT = true

1.12 isn’t supported yet, try 1.11 or 1.10?

Thanks.
It compiles and runs with 1.11.7 and spits out a GC error after some 100 epochs:

julia> include("src/training.jl")
Epoch: 100    Loss: 3.4096978
Epoch: 200    Loss: 2.5341659
GC error (probable corruption)
Allocations: 101018814 (Pool: 101011456; Big: 7358); GC: 83
<?#0xc3482b1c0::<circular reference @-1>>

thread 0 ptr queue:
~~~~~~~~~~ ptr queue top ~~~~~~~~~~
Memory{Float32}(2080, 0xc30d6f000)[0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0,
[...]
==========
Memory{Float32}(2080, 0xc30d71800)[-2.51238298f0, -6.76320791f0, -7.07042551f0, -6.63637829f0, -4.83658981f0, -4.41381359f0, -6.26985502f0, -4.76020432f0, -6.45876646f0, -5.40435886f0, -5.90047693f0, -6.24415064f0, -6.46926832f0, -6.75609636f0, -6.75700712f0, -6.16888428f0, -6.51122808f0, -6.69230604f0, -6.98521471f0, -6.16071367f0, -7.10040236f0, -6.47086763f0, -7.00420952f0, -6.47163057f0, -6.84609175f0, -6.54321384f0, -6.74749613f0, -7.0368371f0, -6.36307669f0, -6.79174519f0, -6.98961115f0, -6.16283941f0, -6.8906436f0, -6.32030058f0, -6.56946182f0, -6.24793053f0, -6.52841425f0, -6.93595123f0, -1.67428231f0, -6.427526f0, -6.55315447f0, -5.82726955f0, -1.05676889f0, -6.22503138

[...]
2574848f0, -0.34315449f0, -0.137290001f0, -0.578015149f0, -2.97973323f0, -2.48788214f0, -2.99111533f0, -0.0648599342f0, -2.88224602f0, -2.84659648f0]
==========
Memory{Float32}(32, 0x142eef360)[2.83892941f0, 1.60529912f0, 2.11458898f0, 2.10521436f0, 2.40639472f0, 2.11458898f0, 2.10521436f0, 2.60500717f0, 1.70545506f0, 2.81477737f0, 2.13002634f0, 2.13002634f0, 2.11458898f0, 2.17010045f0, 2.81477737f0, 2.60500717f0, 2.22852421f0, 1.82820141f0, 2.04361725f0, 2.60500717f0, 2.81477737f0, 2.78560662f0, 2.11458898f0, 2.83892941f0, 1.88166904f0, 1.70545506f0, 2.34638739f0, 1.81718957f0, 2.81477737f0, 3.01865745f0, 2.11458898f0, 2.60500717f0]
==========
Memory{Float32}(32, 0x142eef400)[0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0, 0f0]
==========
~~~~~~~~~~ ptr queue bottom ~~~~~~~~~~

[95919] signal 6: Abort trap: 6
in expression starting at /Users/ralph/julia_envs/nanolux/src/training.jl:102
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 101018814 (Pool: 101011456; Big: 7358); GC: 83
[1]    95919 abort      julia --project=.

In a second run I got a segfault:

❯ julia --project=. src/training.jl
Epoch: 100    Loss: 3.4096978

[96267] signal 11 (2): Segmentation fault: 11
in expression starting at /Users/ralph/julia_envs/nanolux/src/training.jl:102
gc_mark_obj8 at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:0
gc_mark_outrefs at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:2882 [inlined]
gc_mark_loop_serial_ at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:2938
gc_mark_loop_serial at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:2961
_jl_gc_collect at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:3532
ijl_gc_collect at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:3893
maybe_collect at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:926 [inlined]
jl_gc_pool_alloc_inner at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:1319 [inlined]
jl_gc_pool_alloc_noinline at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/gc.c:1386 [inlined]
jl_gc_alloc_ at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia_internal.h:523
_new_genericmemory_ at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/genericmemory.c:56
GenericMemory at ./boot.jl:516 [inlined]
new_as_memoryref at ./boot.jl:535 [inlined]
Array at ./boot.jl:582 [inlined]
Array at ./boot.jl:592 [inlined]
similar at ./array.jl:372 [inlined]
similar at ./abstractarray.jl:830 [inlined]
reducedim_initarray at ./reducedim.jl:53
reducedim_initarray at ./reducedim.jl:54 [inlined]
reducedim_init at ./reducedim.jl:189 [inlined]
_mapreduce_dim at ./reducedim.jl:343 [inlined]
#mapreduce#928 at ./reducedim.jl:329 [inlined]
mapreduce at ./reducedim.jl:329 [inlined]
#_sum#962 at ./reducedim.jl:1011 [inlined]
_sum at ./reducedim.jl:1011 [inlined]
#_sum#961 at ./reducedim.jl:1010 [inlined]
_sum at ./reducedim.jl:1010 [inlined]
#sum#935 at ./reducedim.jl:982 [inlined]
sum at ./reducedim.jl:982 [inlined]
compute_cross_entropy at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/losses.jl:412 [inlined]
compute_cross_entropy at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/losses.jl:0 [inlined]
augmented_julia_compute_cross_entropy_15327_inner_20wrap at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/losses.jl:0
macro expansion at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/compiler.jl:5883 [inlined]
enzyme_call at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/compiler.jl:5417 [inlined]
AugmentedForwardThunk at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/compiler.jl:5367 [inlined]
macro expansion at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/rules/jitrules.jl:447 [inlined]
runtime_generic_augfwd at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/rules/jitrules.jl:574
unknown function (ip: 0x154180eb3)
unsafe_apply_loss at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/losses.jl:408 [inlined]
AbstractLossFunction at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/losses.jl:196 [inlined]
loss_fun at /Users/ralph/julia_envs/nanolux/src/training.jl:65
#5 at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:338 [inlined]
#5 at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:0 [inlined]
diffejulia__5_2848_inner_13wrap at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:0
macro expansion at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/compiler.jl:5883 [inlined]
enzyme_call at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/compiler.jl:5417 [inlined]
CombinedAdjointThunk at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/compiler.jl:5303 [inlined]
autodiff at /Users/ralph/.julia/packages/Enzyme/Ah2fT/src/Enzyme.jl:521 [inlined]
compute_gradients_impl at /Users/ralph/.julia/packages/Lux/GOH1t/ext/LuxEnzymeExt/training.jl:17 [inlined]
#compute_gradients#1 at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:277 [inlined]
compute_gradients at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:275 [inlined]
single_train_step_impl! at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:425 [inlined]
#single_train_step!#7 at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:384 [inlined]
single_train_step! at /Users/ralph/.julia/packages/Lux/GOH1t/src/helpers/training.jl:373 [inlined]
train at /Users/ralph/julia_envs/nanolux/src/training.jl:92
unknown function (ip: 0x1533144cf)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
do_call at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_stmt_value at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:174
eval_body at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:666
jl_interpret_toplevel_thunk at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/interpreter.c:824
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:952 [inlined]
ijl_toplevel_eval_in at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
include_string at ./loading.jl:2734
_include at ./loading.jl:2794
include at ./Base.jl:562
jfptr_include_47091.1 at /Users/ralph/.julia/juliaup/julia-1.11.7+0.aarch64.apple.darwin14/lib/julia/sys.dylib (unknown line)
exec_options at ./client.jl:323
_start at ./client.jl:531
jfptr__start_73909.1 at /Users/ralph/.julia/juliaup/julia-1.11.7+0.aarch64.apple.darwin14/lib/julia/sys.dylib (unknown line)
jl_apply at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/./julia.h:2157 [inlined]
true_main at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XG3Q6T6R70.0/build/default-honeycrisp-XG3Q6T6R70-0/julialang/julia-release-1-dot-11/src/jlapi.c:1059
Allocations: 57090319 (Pool: 57087890; Big: 2429); GC: 31
[1]    96267 segmentation fault  julia --project=. src/training.jl

I’m writing the next part down because I had to figure it out. To switch environments from 1.12 to 1.11 do

  1. Install julia-1.11: juliaup add 1.11.7
  2. Make sure you have a Manifest-1.12.toml so that 1.11 will write to its own separate one.
  3. Set an override for the current directory: juliaup override set 1.11.7
  4. Resolve and instantiate the environment for 1.11:
    julia --project=. -e 'using Pkg; Pkg.resolve(); Pkg.instantiate()'
  5. Run your program julia --project=. src/training.jl

You should try using reactant for best performance (that also bypasses these GC errors). See Compiling Lux Models using Reactant.jl | Lux.jl Docs.

1 Like

Thanks,
One question about the training API:
This fragment states that for the training API the loss function takes 4 arguments:
Fitting a Polynomial using MLP | Lux.jl Docs
The fragment above uses a loss function with 5 parameters (one more, because the data x and y are passed individually rather than as a tuple). Is there a strict definition of the loss function's parameters in the API?

I’ve tried this

# Device object used below (via `|> xdev`) to move arrays, parameters and state.
const xdev = reactant_device()

# Same data loading, 90/10 split, seeding and model setup as in the MWE.
dataset, encode_fun, decode_fun, vocab_size = load_dataset("data/input.txt")
N = length(dataset)
N_train = Int(round(0.9 * N))
data_train = dataset[1:N_train]
data_test = dataset[N_train+1:end]

 
rng = Random.default_rng()
Random.seed!(rng, 1337)

model = Embedding(vocab_size => vocab_size)
ps, st = Lux.setup(rng, model)
opt = Adam(0.03f0)


# Test the model on reactant arrays
xb, yb = get_batch(rng, dataset, 4, 8)

# Move one batch plus the parameters and state to the Reactant device.
xb_ra = xb |> xdev
yb_ra = yb |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev

# Compile the forward pass and run it once on the device arrays.
# NOTE(review): both calls pass the host-side `st` (through `Lux.testmode`), not
# `st_ra` — confirm this is intended.
model_compiled = @compile model(xb_ra, ps_ra, Lux.testmode(st))
pred_compiled, _ = model_compiled(xb_ra, ps_ra, Lux.testmode(st))


# Test regular loss function
# NOTE(review): `xb` and `yb` are passed as separate arguments here, whereas the
# `loss_fun` in the MWE takes them as a single `(xb, yb)` tuple — presumably a
# 5-argument method was defined in this session; verify.
loss_fun(model, ps, st, xb, yb)

"""
    enzyme_gradient(model, ps, st, xb, yb)

Differentiate `loss_fun` in reverse mode with Enzyme and return the entry of the
result that belongs to `ps`, the only argument that is not wrapped in `Const`.
"""
function enzyme_gradient(model, ps, st, xb, yb)
    derivs = Enzyme.gradient(
        Enzyme.Reverse,
        Const(loss_fun),
        Const(model),
        ps,
        Const(st),
        Const(xb),
        Const(yb),
    )
    # Entries follow the argument order after the function: 1 is `model`, 2 is `ps`.
    return derivs[2]
end

# Compile the gradient function with Reactant for the device-side parameters, state and batch.
enzyme_gradient_compiled = @compile enzyme_gradient(model, ps_ra, st_ra, xb_ra, yb_ra)

The last line errors from somewhere in onehotbatch with

julia> enzyme_gradient_compiled = @compile enzyme_gradient(model, ps_ra, st_ra, xb_ra, yb_ra)
ERROR: MethodError: getindex(::Base.ReshapedArray{Reactant.TracedRNumber{Int64}, 1, Reactant.TracedRArray{Int64, 2}, Tuple{}}, ::Int64) is ambiguous.

Candidates:
  getindex(a::Base.ReshapedArray{Reactant.TracedRNumber{T}} where T, indices::Union{Int64, Reactant.TracedRNumber{Int64}}...)
    @ Reactant.TracedIndexing ~/.julia/packages/Reactant/IgTfV/src/Indexing.jl:113
  getindex(a::Base.ReshapedArray{Reactant.TracedRNumber{T}, N, P, Tuple{}} where {T, N, P<:AbstractArray}, indices::Int64)
    @ Reactant.TracedIndexing ~/.julia/packages/Reactant/IgTfV/src/Indexing.jl:125
  getindex(A::Base.ReshapedArray{T, N}, indices::Vararg{Int64, N}) where {T, N}
    @ Base reshapedarray.jl:259
  getindex(a::Base.ReshapedArray{Reactant.TracedRNumber{T}} where T, indices...)
    @ Reactant.TracedIndexing ~/.julia/packages/Reactant/IgTfV/src/Indexing.jl:117
  getindex(A::Base.ReshapedArray{T, N, P, Tuple{}} where {T, N, P<:AbstractArray}, index::Int64)
    @ Base reshapedarray.jl:253
  getindex(a::AbstractArray{Reactant.TracedRNumber{T}, N}, index::Vararg{Union{Int64, Reactant.TracedRNumber{Int64}}, N}) where {T, N}
    @ Reactant.TracedIndexing ~/.julia/packages/Reactant/IgTfV/src/Indexing.jl:83

Possible fix, define
  getindex(::Base.ReshapedArray{Reactant.TracedRNumber{T}, 1, P, Tuple{}} where P<:AbstractArray, ::Int64) where T

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/IgTfV/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::typeof(getindex), ::Base.ReshapedArray{…}, ::Int64)
    @ Reactant ~/.julia/packages/Reactant/IgTfV/src/utils.jl:893
  [3] iterate
    @ ./abstractarray.jl:1209 [inlined]
  [4] iterate
    @ ./abstractarray.jl:1207 [inlined]
  [5] _onehotbatch
    @ ~/.julia/packages/OneHotArrays/RAmnA/src/onehot.jl:87 [inlined]
  [6] onehotbatch
    @ ~/.julia/packages/OneHotArrays/RAmnA/src/onehot.jl:84 [inlined]
  [7] (::Nothing)(none::typeof(onehotbatch), none::Base.ReshapedArray{…}, none::UnitRange{…}, none::Tuple{})
    @ Reactant ./<missing>:0
  [8] getproperty
    @ ./Base.jl:49 [inlined]
  [9] first
    @ ./range.jl:841 [inlined]
 [10] length
    @ ./range.jl:765 [inlined]
 [11] onehotbatch
    @ ~/.julia/packages/OneHotArrays/RAmnA/src/onehot.jl:84 [inlined]
 [12] call_with_reactant(::typeof(onehotbatch), ::Base.ReshapedArray{…}, ::UnitRange{…})
    @ Reactant ~/.julia/packages/Reactant/IgTfV/src/utils.jl:0
 [13] loss_fun
    @ ./REPL[11]:8 [inlined]
 [14] (::Nothing)(none::typeof(loss_fun), none::Embedding{…}, none::@NamedTuple{…}, none::@NamedTuple{}, none::Reactant.TracedRArray{…}, none::Reactant.TracedRArray{…})
    @ Reactant ./<missing>:0
 [15] GenericMemory
    @ ./boot.jl:514 [inlined]
 [16] IdSet
    @ ./idset.jl:31 [inlined]
 [17] IdSet
    @ ./idset.jl:35 [inlined]
 [18] is_traced
    @ ~/.julia/packages/ReactantCore/SEqVX/src/ReactantCore.jl:10 [inlined]
 [19] _any_tuple
    @ ./reduce.jl:1259 [inlined]
 [20] any
    @ ./reduce.jl:1255 [inlined]
 [21] Embedding
    @ ./none:0 [inlined]
 [22] loss_fun
    @ ./REPL[11]:3 [inlined]
 [23] call_with_reactant(::typeof(loss_fun), ::Embedding{…}, ::@NamedTuple{…}, ::@NamedTuple{}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/IgTfV/src/utils.jl:0
 [24] make_mlir_fn(f::typeof(loss_fun), args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Nothing, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/IgTfV/src/TracedUtils.jl:345
 [25] make_mlir_fn
    @ ~/.julia/packages/Reactant/IgTfV/src/TracedUtils.jl:275 [inlined]
 [26] overload_autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Duplicated{…}, ::Const{…}, ::Const{…}, ::Const{…})
    @ Reactant ~/.julia/packages/Reactant/IgTfV/src/Enzyme.jl:315
 [27] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Duplicated{…}, ::Const{…}, ::Const{…}, ::Const{…})
    @ Reactant ~/.julia/packages/Reactant/IgTfV/src/Overlay.jl:21
 [28] macro expansion
    @ ~/.julia/packages/Enzyme/Ah2fT/src/sugar.jl:286 [inlined]
 [29] gradient
    @ ~/.julia/packages/Enzyme/Ah2fT/src/sugar.jl:273 [inlined]
 [30] enzyme_gradient
    @ ./REPL[36]:2 [inlined]
 [31] (::Nothing)(none::typeof(enzyme_gradient), none::Embedding{…}, none::@NamedTuple{…}, none::@NamedTuple{}, none::Reactant.TracedRArray{…}, none::Reactant.TracedRArray{…})
    @ Reactant ./<missing>:0
 [32] Const
    @ ~/.julia/packages/EnzymeCore/Tb1OZ/src/EnzymeCore.jl:30 [inlined]
 [33] enzyme_gradient
    @ ./REPL[36]:2 [inlined]
 [34] call_with_reactant(::typeof(enzyme_gradient), ::Embedding{…}, ::@NamedTuple{…}, ::@NamedTuple{}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/IgTfV/src/utils.jl:0
 [35] make_mlir_fn(f::typeof(enzyme_gradient), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/IgTfV/src/TracedUtils.jl:345
 [36] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1605
 [37] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1572 [inlined]
 [38] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3494
 [39] compile_xla
    @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3467 [inlined]
 [40] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3569
 [41] top-level scope
    @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:2644
Some type information was truncated. Use `show(err)` to see complete types.

The 4 arg version is the correct one. This error should be fixable on Reactant end, I will take a look