Prefetch data in a separate thread (or why adding a single print leads to a 2x speedup)

I am training a neural network where a significant amount of time is spent on data preparation on the CPU (data augmentation). This results in relatively low overall GPU utilization because the GPU has to wait for the data. I want to make an iterator which prefetches the next mini-batch while the GPU is updating the neural network.

In the code below I have a minimal working example of an iterator, PrefetchDataIter, which prefetches the data on a separate thread. Unfortunately, this task (function get below) is not executed while the main loop is sleeping (simulating work on the GPU). The get function takes about 2 seconds on my computer, and the sleep time is 2 seconds too. For 10 iterations the total time is about 40 seconds. However, if I add a println to this function, the task gets executed during the sleep call and the code is about twice as fast.

Is there a way to tell Julia that now is a good time to run any pending tasks?

I use Knet, but the code showing this issue is independent of any neural network package.

Program output without print:

start sleep
end sleep
[ repeated 9 times ]
 41.415683 seconds (52.30 k allocations: 15.053 GiB, 1.48% gc time, 0.13% compilation time)

Program output with print:

start work
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
end sleep
 24.652950 seconds (52.36 k allocations: 15.053 GiB, 3.21% gc time, 0.27% compilation time)

Notice that “start work” is between “start sleep” and “end sleep”.
I have run the code multiple times to exclude the compilation time.
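A rough back-of-the-envelope estimate, assuming each call to get takes about 2 s, the loop body sleeps 2 s, and there are N = 10 batches: without overlap every iteration pays for both, while with perfect overlap only the first batch is computed before the loop body starts, so the ideal speedup is a bit under 2x.

```latex
\begin{aligned}
T_{\text{serial}}  &\approx N\,(t_{\text{get}} + t_{\text{sleep}}) = 10 \times (2 + 2)\,\text{s} = 40\,\text{s},\\
T_{\text{overlap}} &\approx t_{\text{get}} + N \max(t_{\text{get}}, t_{\text{sleep}}) = 2\,\text{s} + 10 \times 2\,\text{s} = 22\,\text{s}.
\end{aligned}
```

Both estimates are close to the measured 41.4 s and 24.7 s above.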

I am using Julia 1.6.0 on Linux (official binaries).
Below is the full code:

import Base.iterate


# some computation which takes about 2 seconds
# for n = 50 and size(A) = (2000,2000)
function get(A,n)
    # adding this print has a surprising impact on the scheduling of the task
#    println("start work")
    for i = 1:n
        A = A*A';
        A = A/maximum(A)
    end
    return A
end

# "slow" iterator for 10 matrices of size sz
struct DataIter
    sz::NTuple{2,Int}
end

function Base.iterate(d::DataIter,index=0)
    if index == 10
        return nothing
    end
    A = randn(Float32,d.sz)
    B = get(A,50);
    return (B,index+1)
end


# Iterator to prefetch data on a separate thread
mutable struct PrefetchDataIter{T}
    iter::T
    task::Union{Task,Nothing}
end

PrefetchDataIter(iter) = PrefetchDataIter(iter,nothing)

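function Base.iterate(d::PrefetchDataIter,args...)
    if d.task === nothing
        # first call: compute the first element synchronously
        out = iterate(d.iter,args...)
    else
        # later calls: wait for the element prefetched by the background task
        out = fetch(d.task)
    end

    if out === nothing
        return nothing
    else
        next,state = out

        # start computing the following element while the caller uses `next`
        d.task = Threads.@spawn iterate(d.iter,state)
        return (next,state)
    end
end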

sz = (2000,2000)

@time for B in PrefetchDataIter(DataIter(sz))
    println("start sleep")
    sleep(2) # simulating work on GPU
    println("end sleep")
end

My guess is that, since everything in the main loop is I/O (print and sleep), the tasks in iterate are often scheduled on the same thread the main loop is running on (which, from the scheduler’s perspective, is not doing anything). However, this means that sleep(2) cannot return when it is supposed to unless the other task yields, e.g., by printing something.
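The effect described above can be reproduced on a single thread. This is an illustrative sketch, not code from the thread; busy_wait and measure_sleep are hypothetical helpers. A task started with @async runs on the same thread as the task that created it, so if it spins without yielding, the parent's sleep(0.1) cannot return until the spinning task has finished. Adding yield() inside the loop (the workaround suggested here) lets the parent resume on time.

```julia
# Spin for `seconds`; optionally yield on every iteration.
function busy_wait(seconds; cooperative::Bool = false)
    t0 = time()
    while time() - t0 < seconds
        cooperative && yield()   # give other tasks on this thread a chance to run
    end
end

# Start a busy task on the current thread, then measure how long sleep(0.1) really takes.
function measure_sleep(; cooperative::Bool)
    task = @async busy_wait(1.0; cooperative = cooperative)  # @async tasks stay on this thread
    t0 = time()
    sleep(0.1)                   # stands in for the "GPU work" in the question
    elapsed = time() - t0
    wait(task)
    return elapsed
end

println("non-cooperative: ", measure_sleep(cooperative = false))
println("cooperative:     ", measure_sleep(cooperative = true))
```

The non-cooperative measurement is at least 1 s because the parent only resumes after the busy task ends; the cooperative one should be close to 0.1 s, although exact timings depend on the machine.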

This is just speculation, but one way to work around this may be to yield periodically in DataIter. For example:

function get(A,n)
    for i = 1:n
        A = A*A';
        A = A/maximum(A)
        yield()  # add this
    end
    return A
end

By the way, I think DataIter can be written much more simply using a Channel. See the first example in “Pattern for managing thread local storage?” (post #2 by tkf).
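A Channel-based prefetching iterator might look like the sketch below. It is an illustration, not the code from the linked post, and prefetch is a hypothetical name. Channel(f, size; spawn = true) runs f in a task that may be scheduled on another thread; the channel is closed automatically when f returns (or closed with the error if f throws), so iterating over it ends once the source iterator is exhausted. The buffer size bounds how many finished batches can wait ahead of the consumer. If the producer's work never yields, the scheduling issue discussed in this thread may still apply.

```julia
# Hypothetical helper: iterate `iter` in a background task and hand the
# elements to the consumer through a buffered Channel.
function prefetch(iter; size::Int = 1)
    # `spawn = true` lets the producer task run on any available thread.
    return Channel{eltype(iter)}(size; spawn = true) do ch
        for x in iter
            put!(ch, x)   # blocks while the buffer is full
        end
    end                   # the channel is closed when this function returns
end

# Usage with the iterator from the question:
# for B in prefetch(DataIter(sz))
#     sleep(2)   # simulating work on the GPU
# end
```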


Thanks a lot for your insight! You are right that calling yield() does get the tasks to run in parallel here. I am still a bit puzzled why Julia does not schedule the work on a different thread (I start Julia with 2 threads: julia -t 2), but I will use your approach.

Good to know that inserting yield() fixed the problem. I agree that it’s a performance pitfall of the current scheduler. But the upside of making your function “cooperative” (i.e., adding some yield() calls) is that it should also perform well (I think) if you create nthreads() tasks in PrefetchDataIter. Something like this is required to maximize throughput on the CPU side, at least for now.
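One way to run several producer tasks is sketched below, assuming batches may arrive out of order (often acceptable for shuffled mini-batches, but worth checking for your use case). parallel_prefetch and make_batch are hypothetical names; make_batch(i) stands for whatever computes batch i, e.g. a cooperative version of get with yield() calls inside. Worker tasks pull batch indices from a shared jobs channel and push results into the output channel, which is closed once all workers have finished.

```julia
using Base.Threads: @spawn, nthreads

# Hypothetical helper: compute `nbatches` batches with `ntasks` worker tasks.
# Results are pushed as soon as they are ready, so their order is not preserved.
function parallel_prefetch(make_batch, nbatches::Int;
                           ntasks::Int = nthreads(), size::Int = ntasks)
    # Queue of batch indices shared by all workers.
    jobs = Channel{Int}(nbatches)
    foreach(i -> put!(jobs, i), 1:nbatches)
    close(jobs)   # workers stop once the queue is drained

    # The outer task waits for all workers via @sync; when it returns, the
    # output channel is closed, and if a worker throws, the failure is
    # forwarded to the consumer when it next takes from the channel.
    return Channel{Any}(size; spawn = true) do out
        @sync for _ in 1:ntasks
            @spawn for i in jobs
                put!(out, make_batch(i))   # blocks while the buffer is full
            end
        end
    end
end
```

For example, `for B in parallel_prefetch(i -> get(randn(Float32, sz), 50), 10)` would consume 10 batches computed by nthreads() workers.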


I can’t give you an answer to the original question, but have you seen DataLoaders.jl? It is made for prefetching batches on background threads and might be useful to you.


Thanks a lot Lorzenzo! I was indeed unaware of DataLoaders.jl, which looks like precisely what I need. Thanks a lot for sharing your work!