I am implementing a neural network using `Lux` and `Reactant` + `Enzyme`. At each epoch, I need to resample the data, divide it into batches, and take a training step. For now, I am working on a CPU. What is the best way to do this efficiently?

An MWE of my current implementation is the following. I first define a model…
```julia
using Lux, Optimisers, MLUtils
using Reactant, Enzyme
using Printf
using Random

rng = Xoshiro(42)
const dev = reactant_device()

model = Chain(
    Dense(1 => 64, gelu),
    Dense(64 => 64, gelu),
    Dense(64 => 1)
)
ps, st = Lux.setup(rng, model) |> dev
```
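As a quick sanity check (not part of the training loop), the model should map a `1×N` batch to a `1×N` output; a minimal sketch on plain CPU arrays, before anything is moved to the device:

```julia
# Hypothetical sanity check: run the model once with CPU parameters
# to confirm input/output shapes before involving Reactant.
ps_cpu, st_cpu = Lux.setup(Xoshiro(42), model)
ycheck, _ = model(rand(Float32, 1, 4), ps_cpu, st_cpu)
@assert size(ycheck) == (1, 4)
```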
…then, I initialise a `Matrix` which stores the data and the output…
```julia
function resample!(rng, x, σ, h)
    n = size(x, 2)
    x[1] = 0    # x is undef-initialised, so the first entry must be set explicitly
    @inbounds for i in 1:(n - 1)
        x[i + 1] = x[i] + √h * randn(rng) * σ
    end
    return x
end

x = Matrix{Float32}(undef, 1, 3200);
y = similar(x);
```
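For reference, the same random walk can also be produced without the scalar loop; a minimal vectorized sketch (the name `resample_vec!` is mine, not from any library), assuming the walk starts at zero as above:

```julia
# Hypothetical vectorized variant of resample!: fill x with scaled Gaussian
# increments, zero the first one, and accumulate them in place with cumsum!.
function resample_vec!(rng, x, σ, h)
    randn!(rng, x)          # standard normal increments
    x .*= σ * √h            # scale each increment
    x[1] = 0                # walk starts at zero
    return cumsum!(x, x; dims = 2)
end
```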
…finally, I train the model…
```julia
target(x) = log(1 + exp(-x))    # i.e. softplus(-x)

ts = Training.TrainState(model, ps, st, Adam(1f-3));

for epoch in 1:1_000
    resample!(rng, x, 0.1f0, 0.01f0)    # fresh data every epoch
    y .= target.(x)
    dataloader = DeviceIterator(dev, DataLoader((x, y); batchsize = 32))
    iterloss = 0f0
    for (xᵢ, yᵢ) in dataloader
        _, batchloss, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), ts)
        iterloss += batchloss / length(dataloader)    # mean loss over batches
    end
    if epoch % 100 == 0
        @printf("Epoch %d, loss = %.6f\n", epoch, iterloss)
    end
end
```
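For context, the loop above runs in global scope; a minimal sketch of the same loop wrapped in a function (so that `ts` and the loss accumulator are locals rather than non-const globals), assuming the definitions above:

```julia
# Hypothetical wrapper: identical logic to the loop above, but all mutable
# state is local to the function.
function train!(ts, x, y; epochs = 1_000, σ = 0.1f0, h = 0.01f0)
    for epoch in 1:epochs
        resample!(rng, x, σ, h)
        y .= target.(x)
        dataloader = DeviceIterator(dev, DataLoader((x, y); batchsize = 32))
        iterloss = 0f0
        for (xᵢ, yᵢ) in dataloader
            _, batchloss, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), ts)
            iterloss += batchloss / length(dataloader)
        end
        epoch % 100 == 0 && @printf("Epoch %d, loss = %.6f\n", epoch, iterloss)
    end
    return ts
end

ts = train!(ts, x, y)
```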
Is there a more efficient way to do this?