How to use Flux.stop()

I am trying to learn how to use Flux.stop() in a callback. My problem is that, as far as I can tell, my call to Flux.stop() has literally no effect; see the following example:

The setup:

using Flux
x_train = randn(Float32, 28, 28, 1, 10_000)
labels_train = Flux.onehotbatch(rand(0:9, 10_000), 0:9)
loader_train = Flux.DataLoader((x_train, labels_train), batchsize=100)
model = Chain(
    Conv((5, 5), 1=>8, pad=2, stride=2, relu),
    Conv((3, 3), 8=>16, pad=1, stride=2, relu),
    Conv((3, 3), 16=>32, pad=1, stride=2, relu),
    GlobalMeanPool(),
    flatten,
    Dense(32, 10)  # raw logits: logitcrossentropy applies the softmax itself
) # 6_346 parameters
loss = (x, y)->Flux.Losses.logitcrossentropy(model(x), y)
opt = ADAM()
parameters = params(model)

time_limit = 1 # Seconds

With that setup, let's try looping over the dataloader:

times = Dict(:initial => time(), :elapsed => time())
for (xs, labels) in loader_train
    Flux.train!(loss, parameters, [(xs, labels)], opt, cb = function cb()
        global times[:elapsed] = time() - times[:initial]
        "Time elapsed = $(round(times[:elapsed], digits=2))" |> println
        if times[:elapsed] > time_limit
            println("Trained for $(round(times[:elapsed], digits=2)), stopping")
            Flux.stop()
        end
    end
    )
end

This produces the following output:

Time elapsed = 0.12
Time elapsed = 0.13
Time elapsed = 0.14
Time elapsed = 0.15
Time elapsed = 0.16
Time elapsed = 0.17
Time elapsed = 0.18
Time elapsed = 0.19
Time elapsed = 0.2
Time elapsed = 0.21
Time elapsed = 0.22
Time elapsed = 0.23
Time elapsed = 0.24
Time elapsed = 0.25
Time elapsed = 0.26
Time elapsed = 0.27
Time elapsed = 0.28
Time elapsed = 0.29
Time elapsed = 0.3
Time elapsed = 0.31
Time elapsed = 0.32
Time elapsed = 0.33
Time elapsed = 0.34
Time elapsed = 0.35
Time elapsed = 0.36
Time elapsed = 0.37
Time elapsed = 0.38
Time elapsed = 0.38
Time elapsed = 0.39
Time elapsed = 0.4
Time elapsed = 0.41
Time elapsed = 0.42
Time elapsed = 0.43
Time elapsed = 0.44
Time elapsed = 0.45
Time elapsed = 0.45
Time elapsed = 0.46
Time elapsed = 0.47
Time elapsed = 0.48
Time elapsed = 0.49
Time elapsed = 0.5
Time elapsed = 0.51
Time elapsed = 0.52
Time elapsed = 0.53
Time elapsed = 0.54
Time elapsed = 0.55
Time elapsed = 0.56
Time elapsed = 0.57
Time elapsed = 0.58
Time elapsed = 0.58
Time elapsed = 0.59
Time elapsed = 0.6
Time elapsed = 0.61
Time elapsed = 0.62
Time elapsed = 0.63
Time elapsed = 0.64
Time elapsed = 0.65
Time elapsed = 0.66
Time elapsed = 0.67
Time elapsed = 0.67
Time elapsed = 0.68
Time elapsed = 0.69
Time elapsed = 0.7
Time elapsed = 0.71
Time elapsed = 0.84
Time elapsed = 0.85
Time elapsed = 0.86
Time elapsed = 0.86
Time elapsed = 0.87
Time elapsed = 0.88
Time elapsed = 0.89
Time elapsed = 0.9
Time elapsed = 0.91
Time elapsed = 0.92
Time elapsed = 0.93
Time elapsed = 0.94
Time elapsed = 0.95
Time elapsed = 0.95
Time elapsed = 0.96
Time elapsed = 0.97
Time elapsed = 0.98
Time elapsed = 0.99
Time elapsed = 1.0
Time elapsed = 1.01
Trained for 1.01, stopping
Time elapsed = 1.02
Trained for 1.02, stopping
Time elapsed = 1.03
Trained for 1.03, stopping
Time elapsed = 1.04
Trained for 1.04, stopping
Time elapsed = 1.05
Trained for 1.05, stopping
Time elapsed = 1.05
Trained for 1.05, stopping
Time elapsed = 1.06
Trained for 1.06, stopping
Time elapsed = 1.07
Trained for 1.07, stopping
Time elapsed = 1.08
Trained for 1.08, stopping
Time elapsed = 1.09
Trained for 1.09, stopping
Time elapsed = 1.1
Trained for 1.1, stopping
Time elapsed = 1.11
Trained for 1.11, stopping
Time elapsed = 1.12
Trained for 1.12, stopping
Time elapsed = 1.13
Trained for 1.13, stopping
Time elapsed = 1.13
Trained for 1.13, stopping
Time elapsed = 1.14
Trained for 1.14, stopping
Time elapsed = 1.15
Trained for 1.15, stopping

So no stopping occurs: nothing changes once the elapsed time exceeds the time limit.

So let's try passing all the data as a single batch (repeated three times):

times = Dict(:initial => time(), :elapsed => time())
for _ in 1:3  # Repeat the for loop below 3 times
    for (xs, labels) in [(x_train[:, :, :, 1:10_000], labels_train[:, 1:10_000])]
        t₀ = time()
        Flux.train!(loss, parameters, [(xs, labels)], opt, cb = function cb()
            global times[:elapsed] = time() - times[:initial]
            "Time elapsed = $(round(times[:elapsed], digits=2))" |> println
            if times[:elapsed] > time_limit
                println("Trained for $(round(times[:elapsed], digits=2)), stopping")
                Flux.stop()
            end
        end
        )
    end
end

This produces the following output:

Time elapsed = 0.93
Time elapsed = 1.82
Trained for 1.82, stopping
Time elapsed = 2.6
Trained for 2.6, stopping

Again, no stopping occurs: after more than 1 second of training, nothing changes.

So it seems like the callback prints things just fine, but at no point actually stops the training. What is going wrong?

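I think Flux.stop() is not meant to stop any outer loop around Flux.train!(). It only ends the loop over minibatches inside the Flux.train!() call that is currently running. In both of your examples each call receives exactly one batch, so that call was about to return anyway, and your own for loop then simply calls Flux.train!() again.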

For your example, though, you do not need a loop around Flux.train!() at all. Instead, you can pass your dataloader directly to Flux.train!(); the loop over minibatches happens inside Flux.train!().
See the minimal example here: Training · Flux

By the way, Flux.train!() and Flux.stop() are not magic. The latter just throws a custom exception, which the former catches inside its minibatch loop and answers with a break: https://github.com/FluxML/Flux.jl/blob/ef04fda844ea05c45d8cc53448e3b25513a77617/src/optimise/train.jl#L63-L124
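To see why the outer loop keeps going, here is a minimal pure-Julia sketch of that mechanism. It does not use Flux: `MyStop`, `my_stop`, `my_train!`, `record!` and `stop_after` are hypothetical stand-ins, written under the assumption that the real functions behave as described above (stop throws, train! catches the exception inside its loop and breaks).

```julia
# Hypothetical stand-ins for Flux.stop() / Flux.train!(), mirroring the
# mechanism described above: stop() throws, train!() catches and breaks.
struct MyStop <: Exception end
my_stop() = throw(MyStop())

function my_train!(step!, data, cb)
    for d in data
        try
            step!(d)          # stands in for gradient + optimiser update
            cb()              # the user callback may call my_stop()
        catch ex
            if ex isa MyStop
                break         # leaves only *this* loop over `data`
            else
                rethrow()
            end
        end
    end
    return nothing
end

const seen = Int[]                                      # batches "trained" so far
record!(b) = push!(seen, b)                             # one fake training step
stop_after(k) = () -> (length(seen) >= k && my_stop())  # callback factory

# Question's pattern: one batch per call, with the loop *around* the call.
for b in 1:10
    my_train!(record!, [b], stop_after(3))
end
n_outer = length(seen)    # 10: every stop only ended a one-batch call

# Answer's pattern: hand the whole iterator to the training function.
empty!(seen)
my_train!(record!, 1:10, stop_after(3))
n_inner = length(seen)    # 3: the stop ends the only loop there is
```

In the first pattern the callback does fire from the third batch on, exactly like the repeated "stopping" lines in the question's output, but each stop only ends a call that had no batches left.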

So if you replace your training loop (your second code block) with:

Flux.train!(loss, parameters, loader_train, opt, cb = function cb()
    global times[:elapsed] = time() - times[:initial]
    "Time elapsed = $(round(times[:elapsed], digits=2))" |> println
    if times[:elapsed] > time_limit
        println("Trained for $(round(times[:elapsed], digits=2)), stopping")
        Flux.stop()
    end
end
)

this produces what you want:

...
Time elapsed = 0.94
Time elapsed = 0.95
Time elapsed = 0.97
Time elapsed = 1.01
Trained for 1.01, stopping
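If you do need a loop around Flux.train!() (for example, one call per epoch), one option is to check the deadline in the outer loop yourself, since a Flux.stop() inside one call does not propagate out of it. This is a sketch only: `train_with_deadline!`, `max_epochs` and `your_callback` are hypothetical names, and the commented usage assumes the same implicit-parameters `Flux.train!(loss, parameters, data, opt, cb = ...)` call used in this thread.

```julia
# Hypothetical helper: the outer loop checks the deadline itself, because a
# Flux.stop() inside one train! call does not propagate out of that call.
function train_with_deadline!(train_epoch!, time_limit; max_epochs = 100)
    t0 = time()
    epochs = 0
    # Re-check the elapsed time before starting every new epoch
    while epochs < max_epochs && time() - t0 <= time_limit
        train_epoch!()
        epochs += 1
    end
    return epochs   # number of epochs that were started
end

# Possible use with the objects from the question (not run here);
# `your_callback` is a hypothetical name for the time-limit callback:
# train_with_deadline!(time_limit) do
#     Flux.train!(loss, parameters, loader_train, opt, cb = your_callback)
# end
```

The callback still ends the current epoch early via Flux.stop(); the `while` condition is what prevents the next epoch from starting.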