Lux re-implementation plateauing unexpectedly

I’ve been using the UCI HIGGS dataset to build an MLP classifier, but I keep getting stuck in what looks like a local minimum: training plateaus at a cross-entropy loss of around 0.5 (on both the test and validation sets). I’ve already tried many different tactics: different architectures, optimizers, learning rates and momentum scheduling, stochastic kicks, batch homogenization, batch size annealing… the list goes on and on, to no avail. Only recently have I managed to get a model to overfit instead of plateauing (though it only started overfitting at the plateau point).
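(For context on why 0.5 looks like a meaningful floor rather than a degenerate solution: a constant predictor on roughly class-balanced labels can’t do better than the label entropy, ≈ 0.69 nats. A quick sanity check in plain Julia — the 0.53 signal fraction is my rough figure for this split, not something from the paper:

```julia
# Cross-entropy of a constant prediction p on labels with positive rate q.
# It is minimized at p = q, where it equals the entropy of the labels.
bce_const(p, q) = -(q * log(p) + (1 - q) * log(1 - p))

bce_const(0.5, 0.5)    # log(2) ≈ 0.693, perfectly balanced classes
bce_const(0.53, 0.53)  # ≈ 0.691, roughly the signal fraction I see
```

So a plateau at 0.5 means the network has learned something nontrivial and then stopped improving, rather than collapsing to the trivial constant predictor.)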

For this reason, I’ve tried to reimplement the algorithm from the original paper, which provides Python source code. I believe I’ve carried the details over to a Lux implementation faithfully, but I’m seeing the same problem. The paper reports fairly easy success, so I suspect I’ve messed up the Lux side (this is my first experience with it).

By all outward appearances everything is working correctly, with my GPU utilized as expected and few log messages appearing on compilation. That said, my main suspects would be the @compact model definition, the learning-rate-decay / momentum-schedule updates, or something in the training loop around Training.single_train_step!. I’ve tried commenting out the validation loop and related code, but that made no difference to the plateau.
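One thing that is easy to check offline is the schedule arithmetic itself. With 2.6 M training rows and batch size 100 (the numbers from the listing below), the per-batch decay factor of 1.0000002 only amounts to about half a percent per epoch — a small sketch replaying just that arithmetic:

```julia
# Replays the per-batch learning-rate decay from the training loop:
# 2.6M rows / batch size 100 = 26_000 decay steps per epoch.
function decayed_lr(lr; steps = 2_600_000 ÷ 100, factor = 1.0000002, floor = 1e-6)
    for _ in 1:steps
        lr = max(lr / factor, floor)
    end
    return lr
end

decayed_lr(0.05)  # ≈ 0.0497 after one epoch, i.e. ~0.5% decay per epoch
```

At that rate the 1e-6 floor would only be reached after roughly 2000 epochs, so at least the learning rate isn’t collapsing early in my 1000-epoch runs.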

Reimplementation of Python Source Code
# First 2.6M events, as in the paper; column 1 is the label and
# columns 2:22 are the 21 low-level features (23:29 hold the 7 high-level ones)
x_paper = Matrix{Float32}(higgs[1:2_600_000, 2:22]) |> permutedims  # 21 × N
y_paper = Vector{Float32}(higgs[1:2_600_000, 1]) |> permutedims     # 1 × N

const PAPERBATCHSIZE = 100
dim = 300  # hidden-layer width

papermodel = @compact(
        # per-layer Gaussian init scales copied from the paper's code
        w1 = 0.1f0 * randn(Float32, dim, 21),   b1 = zeros(Float32, dim, 1),
        w2 = 0.05f0 * randn(Float32, dim, dim), b2 = zeros(Float32, dim, 1),
        w3 = 0.05f0 * randn(Float32, dim, dim), b3 = zeros(Float32, dim, 1),
        w4 = 0.05f0 * randn(Float32, dim, dim), b4 = zeros(Float32, dim, 1),
        w5 = 0.001f0 * randn(Float32, 1, dim),  b5 = zeros(Float32, 1, 1),
        act = tanh
    ) do x
        embed = act.(w1 * x .+ b1)      # four tanh hidden layers of width `dim`
        embed = act.(w2 * embed .+ b2)
        embed = act.(w3 * embed .+ b3)
        embed = act.(w4 * embed .+ b4)
        @return sigmoid.(w5 * embed .+ b5)  # sigmoid output for binary targets
    end

function train_papermodel!(model, x_train, y_train, x_val, y_val; 
                           features=21, epochs=200, lr=0.05, report=20)
    
    momentum = range(0.9, 0.99, length=200) # (odd) momentum schedule
    
    # training state initialization
    params, state = Lux.setup(Random.default_rng(), model) .|> dev  # dev = device mover, defined elsewhere
    optimizer = OptimiserChain(
        WeightDecay(1e-5),
        Momentum(lr, 0.9)
    )
    train_state = Training.TrainState(model, params, state, optimizer)
    
    # main data loader
    loader = DataLoader((x_train, y_train);
        batchsize=PAPERBATCHSIZE,
        shuffle=true,
        parallel=true,
        partial=false
    ) |> dev

    # validation loop accessories
    valoader = DataLoader((x_val, y_val);
        batchsize=10000,
        parallel=true,
        partial=false
    ) |> dev

    xv, yv = first(valoader)
    # compiled forward pass for validation, in inference mode
    val_step_compiled = @compile model(xv, train_state.parameters, Lux.testmode(train_state.states))
    valbce = BinaryCrossEntropyLoss()
    
    # main loop
    for epoch in 1:epochs
        epoch ≤ 200 && Optimisers.adjust!(train_state, rho=momentum[epoch]) # momentum schedule
        
        # training loop
        trainloss, batches = zero(Float32), zero(Float32)
        for (x_gpu, y_gpu) in loader
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(),
                BinaryCrossEntropyLoss(),
                (x_gpu, y_gpu),
                train_state
            )
            trainloss += loss
            batches += Float32(1)

            lr > 1e-6 && Optimisers.adjust!(train_state, eta=max(lr /= 1.0000002, 1e-6)) # learning rate decay
        end

        # validation loop
        if iszero(epoch % report) || isone(epoch) || epoch == epochs
            valstate = Lux.testmode(train_state.states)
            valoss, val_batches = zero(Float32), zero(Float32)
            for (x_gpu, y_gpu) in valoader
                ŷ, _ = val_step_compiled(x_gpu, train_state.parameters, valstate)
                valoss += valbce(Matrix(ŷ)[:], Matrix(y_gpu)[:])
                val_batches += Float32(1)
            end
            println("Epoch $epoch | Loss = $(trainloss / batches) | Val Loss = $(valoss / val_batches)")
        end
        
    end
    model, train_state
end

pmod, ptstate = train_papermodel!(papermodel, x_paper, y_paper, x_val, y_val; report=5, lr=0.05, epochs=1000);
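In case it helps anyone reproduce the adjust! behavior in isolation, here is a tiny Optimisers-only check I would run (no Lux or GPU involved; the values are just illustrative). The idea is to confirm that adjust! on a chained optimizer state actually reaches the Momentum rule — with eta forced to 0, an update step should leave the parameters untouched:

```julia
using Optimisers

ps  = (w = ones(Float32, 2),)
opt = OptimiserChain(WeightDecay(1f-5), Momentum(0.05f0, 0.9f0))
st  = Optimisers.setup(opt, ps)

Optimisers.adjust!(st; eta = 0.0)  # should hit the Momentum leaf in the chain
st, ps = Optimisers.update!(st, ps, (w = ones(Float32, 2),))
ps.w  # still all ones iff eta really became 0
```

I haven’t verified whether the TrainState method of adjust! forwards keyword arguments the same way as the plain state-tree method, so that would be one thing worth double-checking against my loop above.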

I’ve been following the online tutorials, and I don’t see any structural differences in what I’ve written, but I could certainly still be missing something. I’m hoping someone can tell me whether there is a bug in my implementation that would lead to the model appearing to train at first but quickly plateauing in the way I’ve described.

Note: While I would be happy to hear hints on hyperparameter tuning, architecture ideas, etc., my main issue is that the algorithm above should work as is (it should be a straight copy of the authors’ algorithm), so what I mainly need to know is if/where I’ve made a Lux coding error.

If it’s of any interest, I’ve been training on the first 2.6 million events from the dataset (as per the paper) and validating on the 100,000 events after that. The data is “normalized and scaled” the way the authors used it (it comes that way from UCI). I’ve tried a couple of tricks with the data that haven’t worked (e.g. balancing the classes and/or alternating homogeneous batches).

I appreciate any and all help.
