I am training a neural network where a significant amount of time is spent on data preparation (data augmentation) on the CPU. This results in relatively low overall GPU utilization because the GPU has to wait for the data. I want to write an iterator that prefetches the next mini-batch while the GPU is updating the neural network.
In the code below I have a minimal working example of an iterator, `PrefetchDataIter`, which prefetches the data on a separate thread. Unfortunately, this task (the function `get` below) is not executed while the main loop is sleeping (simulating work on the GPU). The `get` function takes about 2 seconds on my computer and the sleep time is 2 seconds too, so for 10 iterations the total time is about 40 seconds. However, if I add a `println` to this function, the task runs during the `sleep` call and the code is about twice as fast.
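As a rough, idealized model (it ignores scheduling and allocation overhead), the two timings reported further down (about 41 s and 25 s) match "no overlap" versus "full overlap" between `get` and `sleep`. The variable names are only for illustration:

```julia
# Idealized timing model; ignores scheduling, GC and allocation overhead.
t_work  = 2.0   # seconds per call to `get`
t_sleep = 2.0   # seconds of simulated GPU work per iteration
n       = 10    # number of mini-batches

# No overlap: every iteration first waits for `get`, then sleeps.
t_serial = n * (t_work + t_sleep)               # 40.0

# Full overlap: only the first batch is prepared up front; afterwards
# each `get` runs while the main loop is sleeping.
t_overlap = t_work + n * max(t_work, t_sleep)   # 22.0

println("serial: ", t_serial, " s, overlapped: ", t_overlap, " s")
```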
Is there a way to tell Julia that now is a good time to run any pending tasks?
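For reference, Base does have `yield()`, which switches to the scheduler so that other runnable tasks get a chance to run. The sketch below illustrates it with `@async` (which keeps the task on the current thread); note that `sleep` already yields to the scheduler in the same way, so adding `yield()` to the loop in the code below may not change anything by itself:

```julia
events = String[]

# `@async` schedules `t` on the current thread but does not switch to it yet.
t = @async push!(events, "task ran")
push!(events, "before yield")

yield()   # let the scheduler run other pending tasks, such as `t`
push!(events, "after yield")

wait(t)   # make sure `t` has finished before inspecting `events`
println(events)
```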
I use Knet, but the code showing this issue is independent of any neural network package.
Program output without the print:
```
start sleep
end sleep
[ repeated 9 times ]
41.415683 seconds (52.30 k allocations: 15.053 GiB, 1.48% gc time, 0.13% compilation time)
```
Program output with the print:
```
start work
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
start work
end sleep
start sleep
end sleep
24.652950 seconds (52.36 k allocations: 15.053 GiB, 3.21% gc time, 0.27% compilation time)
```
Notice that “start work” is between “start sleep” and “end sleep”.
I have run the code multiple times to exclude the compilation time.
I am using Julia 1.6.0 on Linux (official binaries).
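A detail worth reporting alongside the version is the thread count. `Threads.@spawn` can only run a task in parallel with the main loop if Julia was started with more than one thread (for example `julia -t 4` or `JULIA_NUM_THREADS=4`); with a single thread, the spawned task only runs when the main task yields. A quick check using only Base:

```julia
using Base.Threads: nthreads, threadid

println("Julia ", VERSION, " running with ", nthreads(), " thread(s)")

# Which thread does a spawned task actually end up on?
t = Threads.@spawn threadid()
println("main task on thread ", threadid(), ", spawned task on thread ", fetch(t))
```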
Below is the full code:
```julia
import Base.iterate

# some computation which takes about 2 seconds
# for n = 50 and size(A) = (2000,2000)
function get(A,n)
    # adding this print has a surprising impact on the scheduling of the task
    # println("start work")
    for i = 1:n
        A = A*A'
        A = A/maximum(A)
    end
    return A
end

# "slow" iterator for 10 matrices of size sz
struct DataIter
    sz::NTuple{2,Int}
end

function Base.iterate(d::DataIter,index=0)
    if index == 10
        return nothing
    end
    A = randn(Float32,d.sz)
    B = get(A,50)
    return (B,index+1)
end

# Iterator to prefetch data on a separate thread
mutable struct PrefetchDataIter{T}
    iter::T
    task::Union{Task,Nothing}
end

PrefetchDataIter(iter) = PrefetchDataIter(iter,nothing)

function Base.iterate(d::PrefetchDataIter,args...)
    if d.task === nothing
        # first call: compute the first element on the current task
        out = iterate(d.iter,args...)
    else
        # later calls: wait for the element prefetched by the previous call
        out = fetch(d.task)
    end
    if out === nothing
        return nothing
    else
        next,state = out
        # start preparing the following element in the background
        d.task = Threads.@spawn iterate(d.iter,state)
        return (next,state)
    end
end

sz = (2000,2000)
@time for B in PrefetchDataIter(DataIter(sz))
    println("start sleep")
    sleep(2) # simulating work on GPU
    println("end sleep")
end
```
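A common alternative to a hand-written prefetching iterator is a buffered `Channel` filled by a producer task. The helper `prefetch` below is hypothetical (not part of Base or Knet) and is only a sketch: it keeps up to `n` items ready ahead of the consumer, and with `spawn = true` the producer may run on another thread, but only if Julia was started with more than one thread.

```julia
# Hypothetical helper: wrap any iterator so that up to `n` items are
# produced ahead of time by a separate task feeding a buffered Channel.
function prefetch(iter; n::Integer = 1)
    return Channel{Any}(n; spawn = true) do ch
        for x in iter
            put!(ch, x)  # blocks while `n` items are already waiting
        end
        # the channel is closed automatically when this function returns
    end
end

# Toy usage: the generator stands in for the expensive data preparation.
batches = prefetch((x^2 for x in 1:5); n = 2)
results = Int[]
for b in batches
    push!(results, b)  # the consumer (e.g. the GPU step) would go here
end
println(results)
```

With the code from the post, the loop would read `for B in prefetch(DataIter(sz))`. Whether the producer actually overlaps with `sleep(2)` still depends on the thread count, so this does not by itself explain the `println` effect.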