Loading resampled data for PINN in Reactant + Enzyme + Lux

I am implementing a neural network using Lux and Reactant + Enzyme. At each epoch, I need to resample the data, divide it into batches, and take a training step. For now, I am working on a CPU. What is the best way to do this efficiently?

A MWE of my current implementation is the following. I first define a model…

using Lux, Optimisers, MLUtils
using Reactant, Enzyme
using Printf
using Random
rng = Xoshiro(42)

const dev = reactant_device()

model = Chain(
    Dense(1 => 64, gelu),
    Dense(64 => 64, gelu),
    Dense(64 => 1)
)

ps, st = Lux.setup(rng, model) |> dev

…then, I define the resampler and initialise the matrices that store the inputs and the targets…

# Resample x in place as a Gaussian random walk: xᵢ₊₁ = xᵢ + σ√h ξᵢ.
function resample!(x, σ, h)
    n = length(x)                  # x is 1×N; linear indexing walks along the row
    x[1] = 0f0                     # seed the walk; x is allocated with undef below
    @inbounds for i in 1:(n - 1)
        x[i + 1] = x[i] + σ * √h * randn(rng, Float32)
    end
    return x
end

x = Matrix{Float32}(undef, 1, 3200);
y = similar(x);

…finally, I train the model…

target(x) = log1p(exp(-x))   # softplus(-x); log1p is more accurate when exp(-x) is small

ts = Training.TrainState(model, ps, st, Adam(1f-3));
for epoch in 1:1_000
    resample!(x, 0.1f0, 0.01f0)   # σ = 0.1, step size h = 0.01
    y .= target.(x)
    dataloader = DeviceIterator(dev, DataLoader((x, y); batchsize = 32))

    iterloss = 0f0
    for (xᵢ, yᵢ) in dataloader
        # single_train_step! returns (grads, loss, stats, train_state)
        _, batchloss, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), ts)
        iterloss += batchloss / length(dataloader)   # mean batch loss over the epoch
    end

    if epoch % 100 == 0
        @printf("Epoch %d, loss = %.6f\n", epoch, iterloss)
    end
end

Is there a more efficient way to do this?

You could compile the whole dataloader loop, which I'm pretty sure Lux has a method for. Also pay attention to FAQs | Reactant.jl, especially the part on garbage collection; you may need it.
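
For example, something along these lines; the every-10-epochs interval here is just a placeholder, see the FAQ for the actual guidance:

for epoch in 1:1_000
    # ... resample, build the dataloader, and train as above ...

    # Julia's GC cannot see device-side memory pressure, so triggering a
    # collection now and then releases buffers held by dead Reactant arrays.
    if epoch % 10 == 0
        GC.gc()
    end
end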

This looks fine to me. If you think there is some bottleneck, you can wrap the whole epoch loop in Reactant.with_profiler(); that should give you a lot of insight.
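
Something like this, assuming the profiler takes an output directory for the trace ("./profiler_traces/" is an arbitrary path), which you can then open in a trace viewer such as Perfetto:

# Record a profiler trace covering the whole training run.
Reactant.with_profiler("./profiler_traces/") do
    for epoch in 1:1_000
        # ... resample, build the dataloader, and train as above ...
    end
end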

You could compile the whole dataloader loop, which I'm pretty sure Lux has a method for,

That part is not implemented, though we have plans for it. The main trick there is to stack the batches for n iterations and use a traced loop; we then apply some fancy AD loop optimizations in the backend.
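
To illustrate just the data-side stacking with the MWE's shapes (nbatch and the reshape layout here are illustrative assumptions, not the planned API):

# Stack all batches of one epoch into a single array so that one compiled
# call could consume them with a traced loop, instead of dispatching 100
# separate host-side training steps.
nbatch = 100                                  # 3200 samples / batchsize 32
xstacked = reshape(x, 1, 32, nbatch) |> dev   # (features, batchsize, nbatch)
ystacked = reshape(y, 1, 32, nbatch) |> dev
# Inside a compiled function, the traced loop would then look roughly like
#     Reactant.@trace for k in 1:nbatch
#         # train on (xstacked[:, :, k], ystacked[:, :, k])
#     end
# which is what the planned Lux API would generate for you.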