I am doing supervised learning with Flux and want to add a regularization term to the loss function, following the official documentation.
I have tried two ways of adding the regularization term. Adding the penalty directly inside the loss function never finishes, whereas using `WeightDecay(0.001)` the whole run takes only ~2 seconds. Why does this happen?
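As far as I understand, the two should be mathematically equivalent: `WeightDecay(λ)` adds `λ .* w` to the gradient of each parameter array `w`, which is exactly the gradient of the explicit penalty `λ/2 * sum(abs2, w)` that `pen_l2` computes below. A quick standalone sanity check of that equivalence (my own snippet, not from the docs):

```julia
using Flux

λ = 0.001f0
w = randn(Float32, 1, 2)
# Gradient of the explicit penalty λ/2 * Σ wᵢ² ...
g = Flux.gradient(w -> λ / 2 * sum(abs2, w), w)[1]
# ... matches the λ .* w term that WeightDecay(λ) adds to the gradient.
g ≈ λ .* w  # true
```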
Here is the fast code:
```julia
using Flux, Random, Statistics, Distributions, LinearAlgebra, Plots, ProgressMeter

function data_generation(; N = 1000)
    μ1, μ2 = [2.0f0, 0.0f0], [-2.0f0, 0.0f0]
    σ1, σ2 = 0.5f0, 0.5f0
    data1 = rand(MvNormal(μ1, σ1 * I), N)
    data2 = rand(MvNormal(μ2, σ2 * I), N)
    labels1 = fill(0, N)
    labels2 = fill(1, N)
    plt = scatter(data1[1, :], data1[2, :], label = "class1")
    scatter!(plt, data2[1, :], data2[2, :], label = "class2")
    display(plt)
    # shuffle
    data = hcat(data1, data2)
    labels = vcat(labels1, labels2)
    idxs = shuffle(1:2N)
    data = data[:, idxs]
    labels = labels[idxs]
    # split (only the first N of the 2N shuffled samples are used)
    train_data = [(data[:, i], labels[i]) for i in 1:300]
    test_data = [(data[:, i], labels[i]) for i in 301:N]
    return train_data, test_data
end

# validation
function accuracy(data, model)
    correct = 0
    total = 0
    for (x, y) in data
        pred = model(x)[1]
        if round(Int64, pred) == y
            correct += 1
        end
        total += 1
    end
    return correct / total
end

function train(; epochs = 1000)
    # data
    train_data, test_data = data_generation()
    # model and loss (sigmoid applied once, by the Chain)
    model = Chain(Dense(2, 1; init = Flux.glorot_normal), sigmoid)
    loss(y_hat, y) = sum((y_hat .- y) .^ 2)
    println("Initial train accuracy: ", accuracy(train_data, model))
    println("Initial test accuracy: ", accuracy(test_data, model))
    # training
    train_accuracy = zeros(epochs)
    test_accuracy = zeros(epochs)
    opt = Flux.setup(OptimiserChain(WeightDecay(0.001), Adam()), model)  ## fast
    # opt = Flux.setup(Adam(), model)  ## used in the slow version below
    pen_l2(x::AbstractArray) = sum(abs2, x) / 2
    @showprogress for epoch in 1:epochs
        for data in train_data
            input, label = data
            grads = Flux.gradient(model) do m
                result = m(input)
                # penalty = sum(pen_l2, Flux.params(m))
                loss(result, label)  # + 0.001f0 * penalty
            end
            Flux.update!(opt, model, grads[1])
        end
        train_accuracy[epoch] = accuracy(train_data, model)
        test_accuracy[epoch] = accuracy(test_data, model)
    end
    return train_accuracy, test_accuracy
end

@time train_accuracy, test_accuracy = train()
```
And here is the slow code (`data_generation` and `accuracy` are the same as above):
```julia
using Flux, Random, Statistics, Distributions, LinearAlgebra, Plots, ProgressMeter

function train(; epochs = 1000)
    # data
    train_data, test_data = data_generation()
    # model and loss (sigmoid applied once, by the Chain)
    model = Chain(Dense(2, 1; init = Flux.glorot_normal), sigmoid)
    loss(y_hat, y) = sum((y_hat .- y) .^ 2)
    println("Initial train accuracy: ", accuracy(train_data, model))
    println("Initial test accuracy: ", accuracy(test_data, model))
    # training
    train_accuracy = zeros(epochs)
    test_accuracy = zeros(epochs)
    # opt = Flux.setup(OptimiserChain(WeightDecay(0.001), Adam()), model)  ## fast
    opt = Flux.setup(Adam(), model)  ## too slow together with the explicit penalty
    pen_l2(x::AbstractArray) = sum(abs2, x) / 2
    @showprogress for epoch in 1:epochs
        for data in train_data
            input, label = data
            grads = Flux.gradient(model) do m
                result = m(input)
                # the penalty is now computed inside the gradient via Flux.params
                penalty = sum(pen_l2, Flux.params(m))
                loss(result, label) + 0.001f0 * penalty
            end
            Flux.update!(opt, model, grads[1])
        end
        train_accuracy[epoch] = accuracy(train_data, model)
        test_accuracy[epoch] = accuracy(test_data, model)
    end
    return train_accuracy, test_accuracy
end

@time train_accuracy, test_accuracy = train()
```
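For reference, the only other formulation I can think of computes the penalty without `Flux.params`, by reading the arrays of this particular `Chain` directly. This is just a sketch of the gradient call (reusing `model`, `input`, `label`, `loss`, and `pen_l2` from above); I have not benchmarked it:

```julia
grads = Flux.gradient(model) do m
    result = m(input)
    dense = m[1]  # the Dense layer of this specific Chain
    penalty = pen_l2(dense.weight) + pen_l2(dense.bias)
    loss(result, label) + 0.001f0 * penalty
end
```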
Thank you in advance.