Interrupting Flux.jl training without killing kernel

Hello. My informal Flux workflow involves running training routines such as the following in the REPL or Jupyter cells:

# example training loop (assumes a generator G, discriminator D, parameters θD,
# optimiser opt, the data `realbatches`, and helpers such as latentpoints and
# discriminator_accuracy are already defined)
using Flux, Flux.Losses, Printf   # Printf provides @printf

function train_discriminator()
    showloss = Flux.throttle(2) do batch, loss
        accuracy = discriminator_accuracy()
        @printf "batch %d: loss = %.5f, accuracy = %.1f%%\n" batch loss 100accuracy
        loss < 0.1 && Flux.stop()
    end
    for (i, real) in enumerate(realbatches)
        fake = G(latentpoints(batchsize))
        loss, ∇D = Flux.withgradient(θD) do
            binarycrossentropy(D(real), 1) + binarycrossentropy(D(fake), 0)
        end
        showloss(i, loss)
        Flux.update!(opt, θD, ∇D)
    end
end

I often want to interrupt the training with Ctrl+C (or by pressing I twice in notebooks) for various reasons:

  • noticing mistakes after running code
  • exiting training early
  • getting impatient if training takes too long
  • etc…

However, most of the time, interrupting the evaluation of train_discriminator() will crash my session:

julia> train_discriminator()
batch 1: loss = 1.37961, accuracy = 64.1%
batch 4: loss = 1.34973, accuracy = 64.8%
batch 7: loss = 1.31434, accuracy = 65.4%
batch 10: loss = 1.26995, accuracy = 66.0%
batch 13: loss = 1.23204, accuracy = 70.2%
^Cfatal: error thrown and no exception handler available.
InterruptException()
sigatomic_end at ./c.jl:437 [inlined]
task_done_hook at ./task.jl:542
jl_apply_generic at /Users/jollywatt/.julia/juliaup/julia-1.7.2+0~x64/lib/julia/libjulia-internal.1.7.dylib (unknown line)
jl_finish_task at /Users/jollywatt/.julia/juliaup/julia-1.7.2+0~x64/lib/julia/libjulia-internal.1.7.dylib (unknown line)
start_task at /Users/jollywatt/.julia/juliaup/julia-1.7.2+0~x64/lib/julia/libjulia-internal.1.7.dylib (unknown line)

[exits to shell / restarts notebook]

Losing the Julia session and having to reload everything (and retrain, if I haven't been regularly saving models with BSON) is a pain and slows my workflow down.

Surrounding the critical loops in try … catch … end blocks does not seem to help.
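
For example, a wrapper along these lines (a sketch of what I tried) still produces the same fatal error instead of a catchable InterruptException:

# sketch of the try/catch wrapper I tried; the fatal error above still
# occurs before the catch block ever runs
try
    train_discriminator()
catch e
    e isa InterruptException || rethrow()
    @info "training interrupted"
end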


How can I prevent my Julia kernel from dying when interrupting these loops? Is there a better way of interactively stopping Flux training loops?


There are issues with properly handling SIGINT. See for example Ctrl-C does not work when running multi-threaded code · Issue #35524 · JuliaLang/julia · GitHub

I wonder if this is a multi-threading issue. Does it still crash when you start with julia --threads=1?

Also, is it possible to turn this into a self-contained example?

Using single-threaded mode julia --threads=1 doesn’t seem to help.
I’ve struggled to make an MWE, since the simple toy examples I wrote don’t seem to have any problems with interruption.
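
For instance, a plain compute loop like the following (no Flux involved) interrupts cleanly with an ordinary InterruptException at the REPL:

# toy loop with no Flux: Ctrl+C here throws a normal InterruptException
# instead of killing the session
function toyloop()
    s = 0.0
    for i in 1:10^6
        s += sum(abs2, randn(1000))
    end
    return s
end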

However, with this watered-down version of my original code (a GAN for the MNIST database, along the lines of the usual Keras or Flux examples), pressing Ctrl+C will usually kill the session:

Not-so-minimal working example

Running on Julia 1.7.2 / Flux v0.13.0.

using MLDatasets: MNIST
using Flux

function loaddata(hp)
    images = reshape(MNIST.traintensor(Float32), 28, 28, 1, :)
    Flux.Data.DataLoader(images; hp.batchsize)
end

Discriminator(hp) = Chain(
    Conv((3,3), 1=>32, leakyrelu; stride=2, pad=2), # 14×14
    MaxPool((2,2)), # 7×7
    Conv((3,3), 32=>64, leakyrelu; stride=2, pad=1), # 4×4
    MaxPool((2,2)), # 2×2

    Flux.flatten, # 256

    Dense(256 => 2),
    softmax,
)

Generator(hp) = Chain(
    Dense(hp.latentdim => 7*7*128),
    x -> leakyrelu.(x, 0.2),

    x -> reshape(x, 7, 7, 128, :),
    
    ConvTranspose((4, 4), 128=>128; stride=2, pad=SamePad()),
    x -> leakyrelu.(x, 0.2),

    ConvTranspose((4, 4), 128=>128; stride=2, pad=SamePad()),
    x -> leakyrelu.(x, 0.2),
    
    ConvTranspose((7, 7), 128=>1, σ; pad=SamePad()),
)

latentpoints(hp, n=hp.batchsize) = randn(Float32, hp.latentdim, n)

function train_discriminator_step!(real, G, D, opt, hp)
    fake = G(latentpoints(hp))
    θ = Flux.params(D)
    σH = Flux.Losses.logitbinarycrossentropy
    loss, grad = Flux.withgradient(θ) do
        σH(D(real), 1f0) + σH(D(fake), 0f0)
    end
    Flux.update!(opt, θ, grad)
    return loss
end

function train_discriminator!(realbatches, G, D, opt, hp)
    showprog = Flux.throttle(2) do loss, i
        println("batch $i: $loss")
    end
    for (i, real) in enumerate(realbatches)
        loss = train_discriminator_step!(real, G, D, opt, hp)
        showprog(loss, i)
    end
end


hp = (; latentdim=100, batchsize=128)
batches = loaddata(hp)

D = Discriminator(hp)
G = Generator(hp)

opt = ADAM()

train_discriminator!(batches, G, D, opt, hp) # try interrupting this

You could try inserting yield() into your training loop to give the scheduler a chance to butt in. If that still doesn’t work, sleep might help.
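
Something like this, perhaps (using the loop from your example; the exact placement is just a guess):

# hypothetical: yield once per batch so the scheduler has a chance to
# deliver the InterruptException to this task
for (i, real) in enumerate(realbatches)
    loss = train_discriminator_step!(real, G, D, opt, hp)
    showprog(loss, i)
    yield()        # or sleep(0.01) if yield alone doesn't do it
end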

I haven’t been able to prevent Julia from crashing by using Base.yield() in the training loop… I still experience crashes 75% of the time. (Tested on macOS and Ubuntu.)

IIRC Jupyter will forcefully kill the kernel if you interrupt multiple times in quick succession. Likewise for the REPL and Ctrl+C if a SIGINT is already being processed. So you may have to give it a few seconds, and if that doesn’t work then more needs to be done to make the loop interruptible (what exactly, I’m not sure, but do try sleep if you haven’t already).


You are quite right. Unfortunately, what I am experiencing occurs immediately after even a single Ctrl+C press (at least, in the REPL). Adding a sleep step in the inner loop also doesn’t seem to work…

I’m curious as to whether this is specific to my environment. Have you tried running the MWE?

I tried it in the REPL a couple of times. Once I hit fatal: error thrown and no exception handler available. Curiously, the session didn’t crash. Here’s the full stack trace:

julia> train_discriminator!(batches, G, D, opt, hp) # try interrupting this
batch 1: 1.448166
batch 5: 1.4481609
batch 9: 1.4481605
batch 13: 1.4481593
batch 17: 1.4481587
batch 21: 1.4481585
batch 25: 1.4481581
^Cfatal: error thrown and no exception handler available.
InterruptException()
sigatomic_end at ./c.jl:437 [inlined]
task_done_hook at ./task.jl:542
jl_apply_generic at /Applications/Julia-1.7.app/Contents/Resources/julia/lib/julia/libjulia-internal.1.dylib (unknown line)
jl_finish_task at /Applications/Julia-1.7.app/Contents/Resources/julia/lib/julia/libjulia-internal.1.dylib (unknown line)
start_task at /Applications/Julia-1.7.app/Contents/Resources/julia/lib/julia/libjulia-internal.1.dylib (unknown line)
┌ Warning: temp cleanup
│   exception =
│    schedule: Task not runnable
│    Stacktrace:
│      [1] error(s::String)
│        @ Base ./error.jl:33
│      [2] enq_work(t::Task)
│        @ Base ./task.jl:628
│      [3] yield
│        @ ./task.jl:739 [inlined]
│      [4] yield
│        @ ./task.jl:737 [inlined]
│      [5] Channel{Tuple{String, Vector{String}, Vector{String}}}(func::Base.Filesystem.var"#31#34"{String}, size::Int64; taskref::Nothing, spawn::Bool)
│        @ Base ./channels.jl:138
│      [6] Channel (repeats 2 times)
│        @ ./channels.jl:131 [inlined]
│      [7] #walkdir#30
│        @ ./file.jl:953 [inlined]
│      [8] prepare_for_deletion(path::String)
│        @ Base.Filesystem ./file.jl:497
│      [9] temp_cleanup_purge(; force::Bool)
│        @ Base.Filesystem ./file.jl:532
│     [10] (::Base.var"#838#839")()
│        @ Base ./initdefs.jl:329
│     [11] _atexit()
│        @ Base ./initdefs.jl:350
└ @ Base.Filesystem file.jl:537
ERROR: TaskFailedException

    nested task error: schedule: Task not runnable
    Stacktrace:
      [1] error(s::String)
        @ Base ./error.jl:33
      [2] schedule(t::Task, arg::Any; error::Bool)
        @ Base ./task.jl:697
      [3] schedule
        @ ./task.jl:697 [inlined]
      [4] uv_writecb_task(req::Ptr{Nothing}, status::Int32)
        @ Base ./stream.jl:1110
      [5] process_events
        @ ./libuv.jl:104 [inlined]
      [6] wait()
        @ Base ./task.jl:838
      [7] wait(c::Base.GenericCondition{Base.Threads.SpinLock})
        @ Base ./condition.jl:123
      [8] _wait(t::Task)
        @ Base ./task.jl:293
      [9] wait
        @ ./task.jl:332 [inlined]
     [10] threading_run(func::Function)
        @ Base.Threads ./threadingconstructs.jl:38
     [11] macro expansion
        @ ./threadingconstructs.jl:97 [inlined]
     [12] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::Array{Float32, 3}, alpha::Float32, beta::Float32)
        @ NNlib ~/.julia/packages/NNlib/hydo3/src/impl/conv_im2col.jl:146
     [13] ∇conv_data_im2col!
        @ ~/.julia/packages/NNlib/hydo3/src/impl/conv_im2col.jl:125 [inlined]
     [14] (::NNlib.var"#271#275"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
        @ NNlib ./threadingconstructs.jl:178
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:381
  [2] macro expansion
    @ ./task.jl:400 [inlined]
  [3] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/hydo3/src/conv.jl:222
  [4] ∇conv_data!
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:211 [inlined]
  [5] ∇conv_data!(y::Array{Float32, 4}, x::Array{Float32, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/hydo3/src/conv.jl:145
  [6] ∇conv_data!
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:145 [inlined]
  [7] #∇conv_data#198
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:99 [inlined]
  [8] ∇conv_data
    @ ~/.julia/packages/NNlib/hydo3/src/conv.jl:98 [inlined]
  [9] (::ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::Array{Float32, 4})
    @ Flux ~/.julia/packages/Flux/18YZE/src/layers/conv.jl:286
 [10] macro expansion
    @ ~/.julia/packages/Flux/18YZE/src/layers/basic.jl:53 [inlined]
 [11] applychain(layers::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#37#41", var"#38#42", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#39#43", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#40#44", ConvTranspose{2, 4, typeof(σ), Array{Float32, 4}, Vector{Float32}}}, x::Matrix{Float32})
    @ Flux ~/.julia/packages/Flux/18YZE/src/layers/basic.jl:53
 [12] Chain
    @ ~/.julia/packages/Flux/18YZE/src/layers/basic.jl:51 [inlined]
 [13] train_discriminator_step!(real::Array{Float32, 4}, G::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#37#41", var"#38#42", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#39#43", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#40#44", ConvTranspose{2, 4, typeof(σ), Array{Float32, 4}, Vector{Float32}}}}, D::Chain{Tuple{Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}}, opt::ADAM, hp::NamedTuple{(:latentdim, :batchsize), Tuple{Int64, Int64}})
    @ Main ./REPL[300]:2
 [14] train_discriminator!(realbatches::MLUtils.DataLoader{Array{Float32, 4}, Random._GLOBAL_RNG}, G::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#37#41", var"#38#42", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#39#43", ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, var"#40#44", ConvTranspose{2, 4, typeof(σ), Array{Float32, 4}, Vector{Float32}}}}, D::Chain{Tuple{Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, Conv{2, 4, typeof(leakyrelu), Array{Float32, 4}, Vector{Float32}}, MaxPool{2, 4}, typeof(Flux.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(softmax)}}, opt::ADAM, hp::NamedTuple{(:latentdim, :batchsize), Tuple{Int64, Int64}})
    @ Main ./REPL[301]:6
 [15] top-level scope
    @ REPL[309]:1
 [16] top-level scope
    @ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52