Machine learning with regularization using Flux is too slow?

I am trying to do supervised learning with Flux and want to add a regularization term to the loss function, following the official documentation.

I have tried two ways of adding the regularization term. Adding the penalty directly inside the loss function does not finish, whereas with WeightDecay(0.001) in the optimiser the whole calculation takes only ~2 seconds. Why does this happen?

Here is the fast code:

using Flux, Random, Statistics, Distributions, LinearAlgebra, Plots,  ProgressMeter

function data_generation(;N = 1000)
    μ1, μ2 = [2.0f0, 0.0f0], [-2.0f0, 0.0f0]  
    σ1, σ2 = 0.5f0, 0.5f0  

    data1 = rand(MvNormal(μ1, σ1*I), N)
    data2 = rand(MvNormal(μ2, σ2*I), N)
    labels1 = fill(0, N)
    labels2 = fill(1, N)

    plt = scatter(data1[1, :], data1[2, :], label="class1")
    scatter!(plt, data2[1, :], data2[2, :], label="class2")
    display(plt)

    # shuffle
    data = hcat(data1, data2)
    labels = vcat(labels1, labels2)
    idxs = shuffle(1:2N)
    data = data[:, idxs]
    labels = labels[idxs]

    # split data
    train_data = [(data[:, i], labels[i]) for i in 1:300]
    test_data = [(data[:, i], labels[i]) for i in 301:N]

    return train_data, test_data
end


# validation
function accuracy(data, model) 
    correct = 0
    total = 0
    for (x, y) in data
        pred = model(x)[1]
        if round(Int64, pred) == y
            correct += 1
        end
        total += 1
    end

    return correct / total
end


function train(;epochs = 1000)
    # data
    train_data, test_data = data_generation() 

    # model and loss
    model = Chain(Dense(2, 1, σ, init=Flux.glorot_normal), sigmoid) 
    loss(y_hat, y) = sum((y_hat .- y).^2) 
    println("Initial accuracy: ", accuracy(train_data, model))
    println("Initial accuracy: ", accuracy(test_data, model))

    # training
    train_accuracy = zeros(epochs)
    test_accuracy = zeros(epochs)
    opt = Flux.setup(OptimiserChain(WeightDecay(0.001), Adam()), model)  ## fast
    # opt = Flux.setup(Adam(), model)  ## too slow

    pen_l2(x::AbstractArray) = sum(abs2, x)/2
    @showprogress for epoch in 1:epochs
        for data in train_data
            input, label = data
            grads = Flux.gradient(model) do m
                result = m(input)
                #penalty = sum(pen_l2, Flux.params(m))
                loss(result, label) #+ 0.001f0 * penalty
            end

            Flux.update!(opt, model, grads[1])
        end
        train_accuracy[epoch] = accuracy(train_data, model)
        test_accuracy[epoch] = accuracy(test_data, model)
    end

    return train_accuracy, test_accuracy
end

@time train_accuracy, test_accuracy = train()

And here is the slow code:

using Flux, Random, Statistics, Distributions, LinearAlgebra, Plots, ProgressMeter

function train(;epochs = 1000)
    # data
    train_data, test_data = data_generation() 

    # model and loss
    model = Chain(Dense(2, 1, σ, init=Flux.glorot_normal), sigmoid) 
    loss(y_hat, y) = sum((y_hat .- y).^2) 
    println("Initial accuracy: ", accuracy(train_data, model))
    println("Initial accuracy: ", accuracy(test_data, model))

    # training
    train_accuracy = zeros(epochs)
    test_accuracy = zeros(epochs)
    #opt = Flux.setup(OptimiserChain(WeightDecay(0.001), Adam()), model)  ## fast
    opt = Flux.setup(Adam(), model)  ## too slow

    pen_l2(x::AbstractArray) = sum(abs2, x)/2
    @showprogress for epoch in 1:epochs
        for data in train_data
            input, label = data
            grads = Flux.gradient(model) do m
                result = m(input)
                penalty = sum(pen_l2, Flux.params(m))
                loss(result, label) + 0.001f0 * penalty
            end

            Flux.update!(opt, model, grads[1])
        end
        train_accuracy[epoch] = accuracy(train_data, model)
        test_accuracy[epoch] = accuracy(test_data, model)
    end

    return train_accuracy, test_accuracy
end

@time train_accuracy, test_accuracy = train()

Thank you in advance.

I have the same question!
This is especially relevant when the regularization is not on all the parameters but only on a specific layer, so the WeightDecay approach in the optimiser does not apply (?).

It seems that Flux.params(m) is taking very long. Does anyone have a better way to regularize?

The easiest solution here is to use the code in RFC: add `total` by mcabbott · Pull Request #57 · FluxML/Optimisers.jl · GitHub for a faster way to sum parameters. Layer-specific regularization can be done by calling total on each piece and adding the results together.
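
For instance, here is a minimal sketch of layer-specific L2 regularization, assuming the total function from that PR is available as Optimisers.total; the model, data and penalty weight below are just placeholders:

using Flux, Optimisers

pen_l2(x::AbstractArray) = sum(abs2, x) / 2

model = Chain(Dense(2 => 16, relu), Dense(16 => 1, sigmoid))

# sum pen_l2 over every trainable array of the first layer only
loss(m, x, y) = sum(abs2, m(x) .- y) + 0.001f0 * Optimisers.total(pen_l2, m[1])

x, y = rand(Float32, 2, 32), rand(Float32, 1, 32)
grads = Flux.gradient(m -> loss(m, x, y), model)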

Thanks for the ref.

If I just want to apply regularization to the weights of specific layers, can I just use

weights_regularizer = lw * Flux.Functors.fmap(pen_l2, layer.weight) # kind of similar to what is done in total
# or even simpler but is it correct?
weights_regularizer = lw * pen_l2(layer.weight) 

I think this is what I was looking for: no need to call params or fmap, and it is fast.

However, the Flux.modules call is quite slow in my example! (though not quite as slow as Flux.params)

Oh no! In fact:

ERROR: Flux.modules is not at present differentiable, sorry

So why is it in the docs as a solution for regularization? That is a bit confusing.

Flux.modules was created in the days when implicit parameters (Flux.params) were the recommended way to take gradients. It looks like in the switch over to explicit parameters, Flux.modules was never properly updated. It should be deprecated and removed from the docs.

The second option you mention can work as long as layer refers to the argument being differentiated.

function loss(m, x, y)
    l = # some loss term
    p = pen_l2(m[1].weight)

    return l + lw * p
end
Notice that I am using m to get the weight.
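
Putting it together, here is a minimal end-to-end sketch of one training step in the explicit-gradient style, with the penalty restricted to the first layer's weights; the toy model, data and lw value here are assumptions:

using Flux

model = Chain(Dense(2 => 1, σ))
pen_l2(x::AbstractArray) = sum(abs2, x) / 2
lw = 0.001f0

opt = Flux.setup(Adam(), model)

x, y = rand(Float32, 2), 1f0

grads = Flux.gradient(model) do m
    l = sum(abs2, m(x) .- y)     # data term
    p = pen_l2(m[1].weight)      # penalty on the first layer's weights only
    l + lw * p
end

Flux.update!(opt, model, grads[1])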
