Interrupting Flux.jl training without killing kernel

Hello. My informal Flux workflow involves running training routines such as the following in the REPL or Jupyter cells:

# example training loop
function train_discriminator()
    showloss = Flux.throttle(2) do batch, loss
        accuracy = discriminator_accuracy()
        @printf "batch %d: loss = %.5f, accuracy = %.1f%%\n" batch loss 100accuracy
        loss < 0.1 && Flux.stop()
    end
    for (i, real) in enumerate(realbatches)
        fake = G(latentpoints(batchsize))
        loss, ∇D = Flux.withgradient(θD) do
            binarycrossentropy(D(real), 1) + binarycrossentropy(D(fake), 0)
        end
        showloss(i, loss)
        Flux.update!(opt, θD, ∇D)
    end
end

I often want to interrupt training with Ctrl+C (or by pressing I twice in notebooks) for various reasons…

  • noticing mistakes after running code
  • exiting training early
  • getting impatient if training takes too long
  • etc…

However, most of the time, interrupting the evaluation of train_discriminator() will crash my session:

julia> train_discriminator()
batch 1: loss = 1.37961, accuracy = 64.1%
batch 4: loss = 1.34973, accuracy = 64.8%
batch 7: loss = 1.31434, accuracy = 65.4%
batch 10: loss = 1.26995, accuracy = 66.0%
batch 13: loss = 1.23204, accuracy = 70.2%
^Cfatal: error thrown and no exception handler available.
InterruptException()
sigatomic_end at ./c.jl:437 [inlined]
task_done_hook at ./task.jl:542
jl_apply_generic at /Users/jollywatt/.julia/juliaup/julia-1.7.2+0~x64/lib/julia/libjulia-internal.1.7.dylib (unknown line)
jl_finish_task at /Users/jollywatt/.julia/juliaup/julia-1.7.2+0~x64/lib/julia/libjulia-internal.1.7.dylib (unknown line)
start_task at /Users/jollywatt/.julia/juliaup/julia-1.7.2+0~x64/lib/julia/libjulia-internal.1.7.dylib (unknown line)

[exits to shell / restarts notebook]

Losing the Julia session and having to reload everything (and retrain, if I haven't been regularly saving models with BSON) is a pain and slows my workflow down.

Surrounding the critical loops in try … catch … end blocks does not seem to help.
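For concreteness, the wrapper I tried looks roughly like this (a simplified sketch around the training function above):

```julia
function train_safely()
    try
        train_discriminator()
    catch e
        # swallow only interrupts; let genuine errors propagate
        e isa InterruptException || rethrow()
        @info "Training interrupted"
    end
end
```

Pressing Ctrl+C still brings down the whole session instead of landing in the catch block.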


How can I prevent my Julia kernel from dying when interrupting these loops? Is there a better way of interactively stopping Flux training loops?

There are known issues with properly handling SIGINT in Julia. See, for example, Ctrl-C does not work when running multi-threaded code · Issue #35524 · JuliaLang/julia · GitHub

I wonder whether this is a multi-threading issue. Does it still crash when you start with julia --threads=1?

Also, is it possible to turn this into a self-contained example?

Using single-threaded mode (julia --threads=1) doesn’t seem to help.
I’ve struggled to make an MWE, since the simple toy examples I wrote don’t seem to have any problems with interruption.

However, with this watered-down version of my original code (a GAN for the MNIST database, e.g., using keras or Flux), pressing ⌃C will usually kill the session:

Not-so-minimal working example

Running on julia version 1.7.2 / Flux v0.13.0.

using MLDatasets: MNIST
using Flux

function loaddata(hp)
	images = reshape(MNIST.traintensor(Float32), 28, 28, 1, :)
	Flux.Data.DataLoader(images; hp.batchsize)
end

Discriminator(hp) = Chain(
    Conv((3,3), 1=>32, leakyrelu; stride=2, pad=2), # 14×14
    MaxPool((2,2)), # 7×7
    Conv((3,3), 32=>64, leakyrelu; stride=2, pad=1), # 4×4
    MaxPool((2,2)), # 2×2

    Flux.flatten, # 256

    Dense(256 => 2),
    softmax,
)

Generator(hp) = Chain(
    Dense(hp.latentdim => 7*7*128),
    x -> leakyrelu.(x, 0.2),

    x -> reshape(x, 7, 7, 128, :),
    
    ConvTranspose((4, 4), 128=>128; stride=2, pad=SamePad()),
    x -> leakyrelu.(x, 0.2),

    ConvTranspose((4, 4), 128=>128; stride=2, pad=SamePad()),
    x -> leakyrelu.(x, 0.2),
    
    ConvTranspose((7, 7), 128=>1, σ; pad=SamePad()),
)

latentpoints(hp, n=hp.batchsize) = randn(Float32, hp.latentdim, n)

function train_discriminator_step!(real, G, D, opt, hp)
	fake = G(latentpoints(hp))
	θ = Flux.params(D)
	σH = Flux.Losses.logitbinarycrossentropy
	loss, grad = Flux.withgradient(θ) do
		σH(D(real), 1f0) + σH(D(fake), 0f0)
	end
	Flux.update!(opt, θ, grad)
	return loss
end

function train_discriminator!(realbatches, G, D, opt, hp)
	showprog = Flux.throttle(2) do loss, i
		println("batch $i: $loss")
	end
	for (i, real) in enumerate(realbatches)
		loss = train_discriminator_step!(real, G, D, opt, hp)
		showprog(loss, i)
	end
end


hp = (; latentdim=100, batchsize=128)
batches = loaddata(hp)

D = Discriminator(hp)
G = Generator(hp)

opt = ADAM()

train_discriminator!(batches, G, D, opt, hp) # try interrupting this

You could try inserting yield into your training loop to give the scheduler a chance to butt in. If that still doesn’t work, sleep might.
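Something like this, using the loop from your first post (just a sketch):

```julia
for (i, real) in enumerate(realbatches)
    fake = G(latentpoints(batchsize))
    loss, ∇D = Flux.withgradient(θD) do
        binarycrossentropy(D(real), 1) + binarycrossentropy(D(fake), 0)
    end
    showloss(i, loss)
    Flux.update!(opt, θD, ∇D)
    yield()        # give the scheduler a chance to deliver the InterruptException
    # sleep(0.01)  # heavier-handed alternative if yield alone isn't enough
end
```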

I haven’t been able to prevent Julia from crashing by using Base.yield() in the training loop… I still experience crashes 75% of the time. (Tested on macOS and Ubuntu.)

IIRC Jupyter will forcefully kill the kernel if you interrupt multiple times in quick succession. Likewise for the REPL and Ctrl+C if a sigint is already being processed. So you may have to give it a few seconds, and if that doesn’t work then more needs to be done to make the loop interruptible (what I’m not sure, but do try sleep if you haven’t already).
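Another thing that might be worth a shot (untested on my end): deferring SIGINT during the multi-threaded step with Base.disable_sigint, so the interrupt can only be delivered at a safe point between batches:

```julia
for (i, real) in enumerate(realbatches)
    # defer Ctrl+C while the (threaded) gradient/update code is running
    loss = disable_sigint() do
        train_discriminator_step!(real, G, D, opt, hp)
    end
    showprog(loss, i)
    # an InterruptException raised here, between batches, should be catchable
end
```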

You are quite right. Unfortunately, what I am experiencing occurs immediately after even a single Ctrl+C press (at least, in the REPL). Adding a sleep step in the inner loop also doesn’t seem to work…

I’m curious as to whether this is specific to my environment. Have you tried running the MWE?

I tried it in the REPL a couple of times. Once I hit fatal: error thrown and no exception handler available. Curiously, the session didn’t crash. Here’s the full stack trace:
julia> train_discriminator!(batches, G, D, opt, hp) # try interrupting this
batch 1: 1.448166
batch 5: 1.4481609
batch 9: 1.4481605
batch 13: 1.4481593
batch 17: 1.4481587
batch 21: 1.4481585
batch 25: 1.4481581
^Cfatal: error thrown and no exception handler available.
InterruptException()
sigatomic_end at ./c.jl:437 [inlined]
task_done_hook at ./task.jl:542
jl_apply_generic at /Applications/Julia-1.7.app/Contents/Resources/julia/lib/julia/libjulia-internal.1.dylib (unknown line)
jl_finish_task at /Applications/Julia-1.7.app/Contents/Resources/julia/lib/julia/libjulia-internal.1.dylib (unknown line)
start_task at /Applications/Julia-1.7.app/Contents/Resources/julia/lib/julia/libjulia-internal.1.dylib (unknown line)
┌ Warning: temp cleanup
│   exception =
│    schedule: Task not runnable
│    Stacktrace:
│      [1] error(s::String)
│        @ Base ./error.jl:33
│      [2] enq_work(t::Task)
│        @ Base ./task.jl:628
│      [3] yield
│        @ ./task.jl:739 [inlined]
│      [4] yield
│        @ ./task.jl:737 [inlined]
│      [5] Channel{Tuple{String, Vector{String}, Vector{String}}}(func::Base.Filesystem.var"#31#34"{String}, size::Int64; taskref::Nothing, spawn::Bool)
│        @ Base ./channels.jl:138
│      [6] Channel (repeats 2 times)
│        @ ./channels.jl:131 [inlined]
│      [7] #walkdir#30
│        @ ./file.jl:953 [inlined]
│      [8] prepare_for_deletion(path::String)
│        @ Base.Filesystem ./file.jl:497
│      [9] temp_cleanup_purge(; force::Bool)
│        @ Base.Filesystem ./file.jl:532
│     [10] (::Base.var"#838#839")()
│        @ Base ./initdefs.jl:329
│     [11] _atexit()
│        @ Base ./initdefs.jl:350
└ @ Base.Filesystem file.jl:537
ERROR: TaskFailedException

    nested task error: schedule: Task not runnable
    Stacktrace:
      [1] error(s::String)
        @ Base ./error.jl:33
      [2] schedule(t::Task, arg::Any; error::Bool)
        @ Base ./task.jl:697
      [3] schedule
        @ ./task.jl:697 [inlined]
      [4] uv_writecb_task(req::Ptr{Nothing}, status::Int32)
        @ Base ./stream.jl:1110
      [5] process_events
        @ ./libuv.jl:104 [inlined]
      [6] wait()
        @ Base ./task.jl:838
      [7] wait(c::Base.GenericCondition{Base.Threads.SpinLock})
        @ Base ./condition.jl:123
      [8] _wait(t::Task)
        @ Base ./task.jl:293
      [9] wait
        @ ./task.jl:332 [inlined]
     [10] threading_run(func::Function)
        @ Base.Threads ./threadingconstructs.jl:38
     [11] macro expansion
        @ ./threadingconstructs.jl:97 [inlined]
     [12] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::Array{Float32, 3}, alpha::Float32, beta::Float32)
        @ NNlib ~/.julia/packages/NNlib/hydo3/src/impl/conv_im2col.jl:146
     [13] ∇conv_data_im2col!
        @ ~/.julia/packages/NNlib/hydo3/src/impl/conv_im2col.jl:125 [inlined]
     [14] (::NNlib.var"#271#275"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
        @ NNlib ./threadingconstructs.jl:178
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:381
  [2] macro expansion
    @ ./task.jl:400 [inlined]
  [3] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/hydo3/src/conv.jl:222
  [4] ∇conv_data!
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:211 [inlined]
  [5] ∇conv_data!(y::Array{Float32, 4}, x::Array{Float32, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/hydo3/src/conv.jl:145
  [6] ∇conv_data!
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:145 [inlined]
  [7] #∇conv_data#198
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:99 [inlined]
  [8] ∇conv_data
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:98 [inlined]
  [9] (::ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::Array{Float32, 4})
    @ Flux ~/.julia/packages/Flux/18YZE/src/layers/conv.jl:286
 [10] macro expansion
    @ ~/.julia/packages/Flux/18YZE/src/layers/basic.jl:53 [inlined]
 [11] applychain(layers::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#37#41", var"#38#42", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#39#43", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#40#44", ConvTranspose{2, 4, typeof(σ), Array{Float32, 4}, Vector{Float32}}}, x::Matrix{Float32})
    @ Flux ~/.julia/packages/Flux/18YZE/src/layers/basic.jl:53
 [12] Chain
    @ ~/.julia/packages/Flux/18YZE/src/layers/basic.jl:51 [inlined]
 [13] train_discriminator_step!(real::Array{Float32, 4}, G::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#37#41", var"#38#42", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#39#43", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#40#44", ConvTranspose{2, 4, typeof(σ), Array{Float32, 4}, Vector{Float32}}}}, D::Chain{Tuple{Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}}, opt::ADAM, hp::NamedTuple{(:latentdim, :batchsize), Tuple{Int64, Int64}})
    @ Main ./REPL[300]:2
 [14] train_discriminator!(realbatches::MLUtils.DataLoader{Array{Float32, 4}, Random._GLOBAL_RNG}, G::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#37#41", var"#38#42", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#39#43", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#40#44", ConvTranspose{2, 4, typeof(σ), Array{Float32, 4}, Vector{Float32}}}}, D::Chain{Tuple{Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}}, opt::ADAM, hp::NamedTuple{(:latentdim, :batchsize), Tuple{Int64, Int64}})
    @ Main ./REPL[301]:6
 [15] top-level scope
    @ REPL[309]:1
 [16] top-level scope
    @ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52