I am trying to learn how to use Flux.stop() in a callback. My problem is that, as far as I can tell, my call to Flux.stop() has literally no effect. See the following example:
The setup:
using Flux
x_train = randn(Float32, 28, 28, 1, 10_000)
labels_train = Flux.onehotbatch(rand(0:9, 10_000), 0:9)
loader_train = Flux.DataLoader((x_train, labels_train), batchsize=100)
model = Chain(
    Conv((5, 5), 1=>8, pad=2, stride=2, relu),
    Conv((3, 3), 8=>16, pad=1, stride=2, relu),
    Conv((3, 3), 16=>32, pad=1, stride=2, relu),
    GlobalMeanPool(),
    flatten,
    Dense(32, 10),
) # 6_346 parameters; no final softmax, because logitcrossentropy applies it internally
loss = (x, y)->Flux.Losses.logitcrossentropy(model(x), y)
opt = ADAM()
parameters = params(model)
time_limit = 1 # Seconds
With that setup, let's try looping over the DataLoader, calling Flux.train! once per batch:
times = Dict(:initial => time(), :elapsed => 0.0)
for (xs, labels) in loader_train
    Flux.train!(loss, parameters, [(xs, labels)], opt, cb = function cb()
        # Mutating the global Dict does not need a `global` declaration
        times[:elapsed] = time() - times[:initial]
        "Time elapsed = $(round(times[:elapsed], digits=2))" |> println
        if times[:elapsed] > time_limit
            println("Trained for $(round(times[:elapsed], digits=2)), stopping")
            Flux.stop()
        end
    end
    )
end
This produces the following output:
Time elapsed = 0.12
Time elapsed = 0.13
Time elapsed = 0.14
Time elapsed = 0.15
Time elapsed = 0.16
Time elapsed = 0.17
Time elapsed = 0.18
Time elapsed = 0.19
Time elapsed = 0.2
Time elapsed = 0.21
Time elapsed = 0.22
Time elapsed = 0.23
Time elapsed = 0.24
Time elapsed = 0.25
Time elapsed = 0.26
Time elapsed = 0.27
Time elapsed = 0.28
Time elapsed = 0.29
Time elapsed = 0.3
Time elapsed = 0.31
Time elapsed = 0.32
Time elapsed = 0.33
Time elapsed = 0.34
Time elapsed = 0.35
Time elapsed = 0.36
Time elapsed = 0.37
Time elapsed = 0.38
Time elapsed = 0.38
Time elapsed = 0.39
Time elapsed = 0.4
Time elapsed = 0.41
Time elapsed = 0.42
Time elapsed = 0.43
Time elapsed = 0.44
Time elapsed = 0.45
Time elapsed = 0.45
Time elapsed = 0.46
Time elapsed = 0.47
Time elapsed = 0.48
Time elapsed = 0.49
Time elapsed = 0.5
Time elapsed = 0.51
Time elapsed = 0.52
Time elapsed = 0.53
Time elapsed = 0.54
Time elapsed = 0.55
Time elapsed = 0.56
Time elapsed = 0.57
Time elapsed = 0.58
Time elapsed = 0.58
Time elapsed = 0.59
Time elapsed = 0.6
Time elapsed = 0.61
Time elapsed = 0.62
Time elapsed = 0.63
Time elapsed = 0.64
Time elapsed = 0.65
Time elapsed = 0.66
Time elapsed = 0.67
Time elapsed = 0.67
Time elapsed = 0.68
Time elapsed = 0.69
Time elapsed = 0.7
Time elapsed = 0.71
Time elapsed = 0.84
Time elapsed = 0.85
Time elapsed = 0.86
Time elapsed = 0.86
Time elapsed = 0.87
Time elapsed = 0.88
Time elapsed = 0.89
Time elapsed = 0.9
Time elapsed = 0.91
Time elapsed = 0.92
Time elapsed = 0.93
Time elapsed = 0.94
Time elapsed = 0.95
Time elapsed = 0.95
Time elapsed = 0.96
Time elapsed = 0.97
Time elapsed = 0.98
Time elapsed = 0.99
Time elapsed = 1.0
Time elapsed = 1.01
Trained for 1.01, stopping
Time elapsed = 1.02
Trained for 1.02, stopping
Time elapsed = 1.03
Trained for 1.03, stopping
Time elapsed = 1.04
Trained for 1.04, stopping
Time elapsed = 1.05
Trained for 1.05, stopping
Time elapsed = 1.05
Trained for 1.05, stopping
Time elapsed = 1.06
Trained for 1.06, stopping
Time elapsed = 1.07
Trained for 1.07, stopping
Time elapsed = 1.08
Trained for 1.08, stopping
Time elapsed = 1.09
Trained for 1.09, stopping
Time elapsed = 1.1
Trained for 1.1, stopping
Time elapsed = 1.11
Trained for 1.11, stopping
Time elapsed = 1.12
Trained for 1.12, stopping
Time elapsed = 1.13
Trained for 1.13, stopping
Time elapsed = 1.13
Trained for 1.13, stopping
Time elapsed = 1.14
Trained for 1.14, stopping
Time elapsed = 1.15
Trained for 1.15, stopping
So no stopping occurs: nothing changes once the elapsed time exceeds the time limit.
Next, let's pass all of the data to Flux.train! as a single batch, three times over:
times = Dict(:initial => time(), :elapsed => 0.0)
for _ in 1:3 # Repeat the inner loop 3 times
    # The inner loop has one iteration: a single "batch" holding all 10_000 samples
    for (xs, labels) in [(x_train[:, :, :, 1:10_000], labels_train[:, 1:10_000])]
        Flux.train!(loss, parameters, [(xs, labels)], opt, cb = function cb()
            times[:elapsed] = time() - times[:initial]
            "Time elapsed = $(round(times[:elapsed], digits=2))" |> println
            if times[:elapsed] > time_limit
                println("Trained for $(round(times[:elapsed], digits=2)), stopping")
                Flux.stop()
            end
        end
        )
    end
end
This produces the following output:
Time elapsed = 0.93
Time elapsed = 1.82
Trained for 1.82, stopping
Time elapsed = 2.6
Trained for 2.6, stopping
Again, no stopping occurs: once more than 1 second of training has elapsed, nothing changes.
So it seems like the callback prints things just fine, but at no point actually stops the training. What is going wrong?
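One possible explanation to check, assuming the implicit-parameter Flux.train! of Flux 0.12/0.13: Flux.stop() throws an exception, and train! catches it and breaks out of its own loop over the data argument. In both experiments above, every train! call receives a one-element vector [(xs, labels)], so that break can only end a call that was about to finish anyway, and the surrounding for loop carries on. The plain-Julia model below imitates this mechanism without Flux. MockStop, mock_stop, mock_train! and flag_and_stop are hypothetical stand-ins, not Flux functions, and the "time limit" is counted in steps rather than seconds so the result is deterministic.

```julia
# Hypothetical model of how train! (Flux 0.12/0.13 era) reacts to Flux.stop().
# None of these names are Flux API; they only imitate the control flow.

struct MockStop <: Exception end

# Like Flux.stop(): request a stop by throwing an exception.
mock_stop() = throw(MockStop())

# Like train!: loop over `data`, run the callback after each step, and treat
# MockStop as "break out of THIS call's loop over `data`".
function mock_train!(step, data; cb = () -> nothing)
    for d in data
        try
            step(d)
            cb()
        catch ex
            if ex isa MockStop
                break        # ends only this call; the caller's loop is untouched
            else
                rethrow()
            end
        end
    end
    return nothing
end

const LIMIT = 5   # hypothetical limit, counted in steps instead of seconds

# Pattern 1 (as in the question): one mock_train! call per batch.
steps_per_batch = Ref(0)
for batch in 1:10
    mock_train!(d -> (steps_per_batch[] += 1), [batch];
                cb = () -> steps_per_batch[] > LIMIT && mock_stop())
end
println(steps_per_batch[])    # 10: every batch still runs

# Pattern 2: hand all batches to a single call.
steps_single_call = Ref(0)
mock_train!(d -> (steps_single_call[] += 1), 1:10;
            cb = () -> steps_single_call[] > LIMIT && mock_stop())
println(steps_single_call[])  # 6: the stop ends the only call

# Pattern 3: keep one call per batch, but let the outer loop see the stop.
steps_with_flag = Ref(0)
stop_requested = Ref(false)
function flag_and_stop()
    if steps_with_flag[] > LIMIT
        stop_requested[] = true  # remember the request for the outer loop
        mock_stop()
    end
end
for batch in 1:10
    mock_train!(d -> (steps_with_flag[] += 1), [batch]; cb = flag_and_stop)
    stop_requested[] && break    # leave the outer loop as well
end
println(steps_with_flag[])    # 6
```

If this model matches the real behavior, pattern 1 reproduces the output above, and the last two patterns show two ways to make the stop cover the whole run: pass loader_train itself as the data argument of a single Flux.train! call, or set a flag in the callback and break the outer loop yourself.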