I’m trying to get the following code to work:
using MLDatasets, Lux, Reactant, Enzyme, Random, OneHotArrays, MLUtils, Optimisers, Statistics
Reactant.set_default_backend("gpu")
const rng = Random.MersenneTwister(123)
const dev = xla_device()
# Load the MNIST dataset
mnist_train = MLDatasets.MNIST(:train)
mnist_test = MLDatasets.MNIST(:test)
# Build an MLP model
function make_mlp(num_layer::Int, input_dim::Int, hidden_dim::Int, output_dim::Int)
    model = Chain(
        Dense(input_dim, hidden_dim, relu),
        [Dense(hidden_dim, hidden_dim, relu) for i in 1:(num_layer - 2)]...,
        Dense(hidden_dim, output_dim)
    )
    ps, st = Lux.setup(rng, model) |> dev
    return model, ps, st
end
function loss_fn(model, ps, st, d)
    x, y = d
    ŷ, _ = model(x, ps, st)
    return Lux.MSELoss()(ŷ, y)
end
function main()
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 32
    learning_rate = 1e-3

    imgs_train = reshape(mnist_train.features,
        (size(mnist_train.features, 1) * size(mnist_train.features, 2),
         size(mnist_train.features, 3)))
    imgs_test = reshape(mnist_test.features,
        (size(mnist_test.features, 1) * size(mnist_test.features, 2),
         size(mnist_test.features, 3)))

    class = 0:9
    targets_train = onehotbatch(mnist_train.targets, class)
    model, ps, st = make_mlp(num_layers, 28 * 28, hidden_dim, num_classes)
    dl_train = DataLoader((imgs_train, targets_train); batchsize=batch_size, shuffle=true, partial=false) |> dev

    opt = Training.TrainState(model, ps, st, Adam(learning_rate))
    for epoch in 1:num_epochs
        for (i, (x, y)) in enumerate(dl_train)
            _, loss, _, opt = Training.single_train_step!(
                AutoEnzyme(),
                loss_fn,
                (x, y),
                opt
            )
            if i == 1
                @info "Epoch: $epoch, Loss: $loss"
            end
        end
    end

    ps = opt.parameters
    targets = mnist_test.targets
    modelcomp = @compile model(dev(imgs_test), ps, st)
    ŷ, _ = modelcomp(dev(imgs_test), ps, st)
    ŷ = argmax.(eachcol(Array(ŷ))) .- 1
    acc = mean(ŷ .== targets) * 100
    @info "Accuracy: $acc"
    return opt
end

opt = main()
It crashes at this line:
_, loss, _, opt = Training.single_train_step!(
    AutoEnzyme(),
    loss_fn,
    (x, y),
    opt
)
but it works if I simply do
_, loss, _, opt = Training.single_train_step!(
    AutoEnzyme(),
    MSELoss(),
    (x, y),
    opt
)
I followed this tutorial: https://lux.csail.mit.edu/stable/manual/compiling_lux_models
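Reading the stack trace, the failure happens in wrapped_objective_function while it tries to destructure the value returned by loss_fn, so I suspect a custom objective passed to Training.single_train_step! has to return a (loss, updated_states, stats) tuple rather than just the scalar loss. Here is an untested sketch of what I think it should look like (the name st_new and the empty stats tuple are my own guesses):

function loss_fn(model, ps, st, d)
    x, y = d
    ŷ, st_new = model(x, ps, st)
    # return (loss, updated states, stats) instead of only the loss value
    return Lux.MSELoss()(ŷ, y), st_new, (;)
end

Is that the contract the Reactant backend expects, or is something else going on?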
Truncated error:
BoundsError: attempt to access Reactant.TracedRNumber{Float32} at index [2]
Stacktrace:
[1] indexed_iterate
@ ./tuple.jl:101 [inlined]
[2] indexed_iterate(none::Reactant.TracedRNumber{Float32}, none::Int64, none::Nothing)
@ Reactant ./<missing>:0
[3] indexed_iterate
@ ./tuple.jl:101 [inlined]
[4] call_with_reactant(::typeof(Base.indexed_iterate), ::Reactant.TracedRNumber{Float32}, ::Int64, ::Nothing)
@ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
[5] wrapped_objective_function
@ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:9 [inlined]
[6] wrapped_objective_function(none::typeof(loss_fn), none::Chain{…}, none::@NamedTuple{…}, none::@NamedTuple{…}, none::Tuple{…}, none::LuxReactantExt.StatsAndNewStateWrapper)
@ Reactant ./<missing>:0
[7] wrapped_objective_function
@ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:9 [inlined]
[8] call_with_reactant(::typeof(LuxReactantExt.wrapped_objective_function), ::typeof(loss_fn), ::Chain{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::Tuple{…}, ::LuxReactantExt.StatsAndNewStateWrapper)
@ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
[9] (::Reactant.TracedUtils.var"#8#18"{…})()
@ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:182
[10] block!(f::Reactant.TracedUtils.var"#8#18"{…}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
[11] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:169
[12] make_mlir_fn
@ ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:86 [inlined]
[13] overload_autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Duplicated{…}, ::Const{…}, ::Const{…}, ::Const{…})
@ Reactant ~/.julia/packages/Reactant/WudhJ/src/Interpreter.jl:238
[14] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Duplicated{…}, ::Const{…}, ::Const{…}, ::Const{…})
@ Reactant ~/.julia/packages/Reactant/WudhJ/src/Overlay.jl:32
[15] macro expansion
@ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:275 [inlined]
[16] gradient
@ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:263 [inlined]
[17] compute_gradients_internal
@ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:17 [inlined]
[18] compute_gradients_internal_and_step!
@ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:115 [inlined]
[19] compute_gradients_internal_and_step!(none::typeof(loss_fn), none::Chain{…}, none::Tuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…})
@ Reactant ./<missing>:0
[20] compute_gradients_internal
@ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:16 [inlined]
[21] compute_gradients_internal_and_step!
@ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:115 [inlined]
[22] call_with_reactant(::typeof(LuxReactantExt.compute_gradients_internal_and_step!), ::typeof(loss_fn), ::Chain{…}, ::Tuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
@ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
[23] (::Reactant.TracedUtils.var"#8#18"{…})()
@ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:182
[24] block!(f::Reactant.TracedUtils.var"#8#18"{…}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
[25] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:169
[26] make_mlir_fn
@ ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:86 [inlined]
[27] #10
@ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:319 [inlined]
[28] block!(f::Reactant.Compiler.var"#10#15"{…}, blk::Reactant.MLIR.IR.Block)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
[29] #9
@ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:318 [inlined]
[30] mmodule!(f::Reactant.Compiler.var"#9#14"{…}, blk::Reactant.MLIR.IR.Module)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Module.jl:92
[31] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:315
[32] compile_mlir!
@ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:314 [inlined]
[33] (::Reactant.Compiler.var"#32#34"{Bool, typeof(LuxReactantExt.compute_gradients_internal_and_step!), Tuple{…}})()
@ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:799
[34] context!(f::Reactant.Compiler.var"#32#34"{…}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Context.jl:76
[35] compile_xla(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:796
[36] compile_xla
@ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:791 [inlined]
[37] compile(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool, sync::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:823
[38] compile
@ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:822 [inlined]
[39] macro expansion
@ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:536 [inlined]
[40] single_train_step_impl!(backend::Lux.Training.ReactantBackend, objective_function::typeof(loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ LuxReactantExt ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:79
[41] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training ~/.julia/packages/Lux/fMnM0/src/helpers/training.jl:276
[42] main()
@ Main ./REPL[9]:29
[43] top-level scope
@ REPL[10]:1