So I’m trying to reduce the execution time of the closure `loss` that is returned from `construct_loss`. I figured that multithreading might help. I am mostly interested in speeding up `ForwardDiff.gradient(loss, p_initial)`, but I have already run into some (scaling) problems with a plain `loss(p_initial)`.
`construct_loss` returns a closure (I call it `loss`) that runs in a single-threaded manner; `construct_loss_mt` is my attempt at parallelizing that thing. The closure returned from `construct_loss_mt` (called `loss_mt`) takes longer to run on `p_initial` (a plain `Vector{Float64}`) than `loss` does…
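To make the setup concrete, here is how the two closures are built and used (these are exactly the names from the full code below):

```julia
loss    = construct_loss(X)      # single-threaded closure over the data X
loss_mt = construct_loss_mt(X)   # multi-threaded variant, mathematically identical

loss(p_initial)                        # plain Float64 evaluation
ForwardDiff.gradient(loss, p_initial)  # the call I ultimately want to speed up
```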
I added some timing code to see how much time is spent in what I think are the critical passages of my code (see `startTimes[sampleId]` and `stopTimes[sampleId]`). Looking at the elapsed times in the console, I find that the multi-threaded `loss_mt` takes a multiple of the time that the single-threaded `loss` takes for the same two lines of code (`_diffCache .= X[sampleId]` and `_diffCache .-= pred`), which should be doing exactly the same thing mathematically.

Can someone explain why this is happening? I’m grateful for every new insight into this problem, and of course I would be most grateful for tips on how I could speed up the calculation of `loss`, even without using multithreading at all…
Code:

```julia
using PreallocationTools
using ForwardDiff
nSamples = 4
nFeatures = floor(Int64, 1e7)
X = [collect(LinRange(0, 0.5, nFeatures)) .+ sampleId for sampleId in 1:nSamples]
function construct_loss(X; chunksize=12)
    nSamples = length(X)
    nFeatures = length(first(X))
    # 1 thread -> one dualcache vector for the difference vector
    # (pass chunksize here too, so both variants use the same cache layout)
    diffCache_d = dualcache(zeros(nFeatures), chunksize)
    # timestamps as UInt64, matching what time_ns() returns
    startTimes = zeros(UInt64, nSamples)
    stopTimes = similar(startTimes)
    function loss(p)
        pred = sin(p[1] * p[2])
        _diffCache = get_tmp(diffCache_d, p)  # Float64 or Dual buffer, depending on p
        l = zero(eltype(p))
        tSplit = time_ns()
        for sampleId in 1:nSamples
            startTimes[sampleId] = time_ns()
            _diffCache .= X[sampleId]
            _diffCache .-= pred
            stopTimes[sampleId] = time_ns()
            l += sum(abs2, _diffCache)
        end
        println("-"^40)
        for (sampleId, startTime, stopTime) in zip(1:nSamples, startTimes, stopTimes)
            println("(single-threaded) sample $(sampleId):\n\tstart: $((startTime - tSplit) / 1e9),\n\tstop: $((stopTime - tSplit) / 1e9),\n\t > difference: $((stopTime - startTime) / 1e9)")
        end
        return l
    end
end
function construct_loss_mt(X; chunksize=12)
    nSamples = length(X)
    nFeatures = length(first(X))
    # 1 dualcache vector per thread
    diffCaches_d = [dualcache(zeros(nFeatures), chunksize) for _ in 1:Threads.nthreads()]
    # 1 partial-loss element per thread
    partialLosses_d = dualcache(zeros(Threads.nthreads()), chunksize)
    startTimes = zeros(UInt64, nSamples)
    stopTimes = similar(startTimes)
    threadIds = fill(Int64(-1), nSamples)
    function loss(p)
        pred = sin(p[1] * p[2])
        _pl = get_tmp(partialLosses_d, p)
        _pl .= 0
        startTimes .= 0
        stopTimes .= 0
        threadIds .= -1
        tSplit = time_ns()
        # :static pins each iteration to a fixed thread; with the default
        # dynamic scheduler (Julia >= 1.8), indexing buffers by threadid()
        # would not be safe because tasks can migrate between threads
        Threads.@threads :static for sampleId in 1:nSamples
            tId = Threads.threadid()
            _diffCache = get_tmp(diffCaches_d[tId], p)
            startTimes[sampleId] = time_ns()
            _diffCache .= X[sampleId]
            _diffCache .-= pred
            stopTimes[sampleId] = time_ns()
            _pl[tId] += sum(abs2, _diffCache)
            threadIds[sampleId] = tId
        end
        println("-"^40)
        for (sampleId, startTime, stopTime, threadId) in zip(1:nSamples, startTimes, stopTimes, threadIds)
            println("(multi-threaded) sample $(sampleId):\n\tthread $(threadId),\n\tstart: $((startTime - tSplit) / 1e9),\n\tstop: $((stopTime - tSplit) / 1e9),\n\t > difference: $((stopTime - startTime) / 1e9)")
        end
        return sum(_pl)
    end
end
loss = construct_loss(X)
loss_mt = construct_loss_mt(X)
p_initial = [2.0, 3.0]
@info "evaluation on Float64 arrays"
@assert isapprox(loss(p_initial), loss_mt(p_initial)) "loss and loss_mt do not evaluate to the same value!"
@info "ForwardDiff evaluation"
@assert all(isapprox.(ForwardDiff.gradient(loss, p_initial), ForwardDiff.gradient(loss_mt, p_initial))) "The gradients are not the same!"
```
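For reference, here is a minimal sketch of how the end-to-end timings can be compared, assuming the BenchmarkTools package is available (note that the `println` instrumentation inside the closures will print on every benchmark run):

```julia
using BenchmarkTools

# plain Float64 evaluation of both closures
@btime loss($p_initial)
@btime loss_mt($p_initial)

# the gradient evaluations I actually want to speed up
@btime ForwardDiff.gradient($loss, $p_initial)
@btime ForwardDiff.gradient($loss_mt, $p_initial)
```

The `$` interpolation keeps BenchmarkTools from measuring global-variable access on top of the actual work.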