Code in Threads.@threads takes longer than single-threaded code

So I’m trying to reduce the execution time of the closure loss that is returned from construct_loss.
I figured that multithreading might help me.
I am mostly interested in speeding up ForwardDiff.gradient(loss, p_initial), but I have already run into some (scaling) problems with a plain loss(p_initial).
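
For reference, the two calls I care about look roughly like this (just @time as a rough indicator rather than a proper benchmark; loss and p_initial are defined in the full script below):

@time loss(p_initial)                        # plain evaluation, already scales badly
@time ForwardDiff.gradient(loss, p_initial)  # the call I actually want to speed up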

construct_loss returns a closure (I call it loss) that runs single-threaded; construct_loss_mt is my attempt at parallelizing it.
The closure returned from construct_loss_mt (called loss_mt) takes longer to evaluate on p_initial (a plain Vector{Float64}) than the single-threaded loss does…
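
To rule out the obvious: I do start Julia with multiple threads (the thread count below is only an example), so @threads has workers available:

# Julia started with e.g. `julia --threads=4`
@show Threads.nthreads()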

I added some timing code to see how much time is spent in what I think are the critical parts of my code (see startTimes[sampleId] and stopTimes[sampleId]).
Looking at the elapsed times printed to the console, I find that the multi-threaded loss_mt needs a multiple of the time that the single-threaded loss needs for the same two lines of code, even though they should be doing exactly the same thing mathematically.
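
The two lines I mean are the broadcasts that fill the preallocated difference buffer; each printed "difference" is (stopTimes[sampleId] - startTimes[sampleId]) / 1e9, i.e. the wall-clock seconds spent in:

_diffCache .= X[sampleId]
_diffCache .-= pred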

Can someone explain why this is happening?

I’m grateful for any new insight into this problem, and of course I would be most grateful for tips on how I could speed up the calculation of loss, even without using multithreading at all…

Code:

using PreallocationTools
using ForwardDiff

nSamples = 4
nFeatures = floor(Int64, 1e7)


X = [collect(LinRange(0, 0.5, nFeatures)).+sampleId  for sampleId in 1:nSamples]


function construct_loss(X; chunksize=12)

    nSamples = length(X)
    nFeatures = length(first(X))

    # 1 thread -> one dualcache vector for the difference vector
    diffCache_d = dualcache(zeros(nFeatures), chunksize)

    startTimes = zeros(nSamples)
    stopTimes = similar(startTimes)

    function loss(p)
        pred = sin(p[1]*p[2])

        _diffCache = get_tmp(diffCache_d, p)

        l = zero(eltype(p))

        tSplit = time_ns()

        for sampleId in 1:nSamples
            startTimes[sampleId] = time_ns()
            _diffCache .= X[sampleId]
            _diffCache .-= pred
            stopTimes[sampleId] = time_ns()
            l += sum(abs2, _diffCache)
        end

        println("-"^40)

        for (sampleId, startTime, stopTime) in zip(1:nSamples, startTimes, stopTimes)
            println("(single-threaded) sample $(sampleId):\n\tstart: $((startTime-tSplit)/1e9),\n\tstop: $((stopTime-tSplit)/1e9),\n\t > difference: $((stopTime-startTime)/1e9)")
        end

        return l
    end
end


function construct_loss_mt(X; chunksize=12)

    nSamples = length(X)
    nFeatures = length(first(X))

    # 1 dualcache vector per thread
    diffCaches_d = [dualcache(zeros(nFeatures), chunksize) for _ in 1:Threads.nthreads()]

    # 1 element per thread
    partialLosses_d = dualcache(zeros(Threads.nthreads()), chunksize)

    startTimes = zeros(UInt64, nSamples)
    stopTimes = similar(startTimes)
    threadIds = fill(Int64(-1), nSamples)

    function loss(p)
        pred = sin(p[1]*p[2])

        _pl = get_tmp(partialLosses_d, p)
        _pl .= 0

        startTimes .= 0
        stopTimes .= 0
        threadIds .= -1

        tSplit = time_ns()

        Threads.@threads for sampleId in 1:nSamples

            tId = Threads.threadid()
            _diffCache = get_tmp(diffCaches_d[tId], p)

            startTimes[sampleId] = time_ns()
            _diffCache .= X[sampleId]
            _diffCache .-= pred
            stopTimes[sampleId] = time_ns()
            _pl[tId] += sum(abs2, _diffCache)

            threadIds[sampleId] = tId
        end
        
        println("-"^40)

        for (sampleId, startTime, stopTime, threadId) in zip(1:nSamples, startTimes, stopTimes, threadIds)
            println("(multi-threaded) sample $(sampleId):\n\tthread $(threadId),\n\tstart: $((startTime-tSplit)/1e9),\n\tstop: $((stopTime-tSplit)/1e9),\n\t > difference: $((stopTime-startTime)/1e9)")
        end

        return sum(_pl)
    end
end


loss = construct_loss(X)
loss_mt = construct_loss_mt(X)

p_initial = [2.0, 3.0]

@info "evaluation on Float64 arrays"

@assert isapprox(loss(p_initial), loss_mt(p_initial)) "The loss and loss_mt do not evaluate to the same value!"

@info "ForwardDiff evaluation"

@assert isapprox.(ForwardDiff.gradient(loss, p_initial), ForwardDiff.gradient(loss_mt, p_initial)) |> all "The gradients are not the same!"