Why is `collect()` faster than a for loop in this numerical scheme? (corrected!)

For completeness, the following is sufficient to make the original code generic and type-stable:

import Base.isless


# Immutable container for one problem instance. Besides the two input vectors it
# caches two element-wise derived vectors so they are computed only once.
struct MyDataFrame{T<:Real}
    m::Int              # number of points, taken from length(xs)
    xs::Vector{T}       # first input vector (stored as passed, not copied)
    ys::Vector{T}       # second input vector (stored as passed, not copied)
    xys::Vector{T}      # element-wise product xs .* ys
    omxs::Vector{T}     # element-wise complement 1 .- xs

    # Inner constructor: derive the cached vectors from xs and ys.
    # Broadcasting raises DimensionMismatch when the two lengths differ.
    function MyDataFrame(xs::Vector{T}, ys::Vector{T}) where T
        products = xs .* ys
        complements = 1 .- xs

        return new{T}(length(xs), xs, ys, products, complements)
    end
end


# Element type of the working vector built inside manipulatedata(): one
# immutable record per data point, with every numeric field sharing type T.
struct DataPoint{T<:Real}
    j::Int    # position of the point in the originating MyDataFrame (1-based)
    x::T      # x value of the point; manipulatedata() copies it unchanged
    y::T      # y value; rewritten by manipulatedata() each round
    xy::T     # = x * y initially; also the key DataPoints are ordered by
    omx::T    # = 1 - x
end


# Order DataPoints by their `xy` field alone, so that ordering-based operations
# on DataPoints (e.g. the findmax() call in manipulatedata) compare on that value.
function Base.isless(lhs::DataPoint, rhs::DataPoint)
    return isless(lhs.xy, rhs.xy)
end



# Benchmark kernel: repeatedly pick the DataPoint with the largest `xy`, rescale
# every other remaining point relative to it, and remove it -- until no points
# are left.
#
# Arguments:
#   mdf  - instance data; it is only read, never modified.
#   coll - selects how the surviving points are rewritten in each round:
#            true  -> remove the best point first, then rebuild all survivors
#                     with a generator + collect() (author's note: faster);
#            false -> update the points before and after the best one with two
#                     index loops, then remove it (author's note: slower).
#
# Returns `nothing`; the working vector is local and is discarded.
#
# Note: the default is written `coll::Bool = true`. The earlier spelling
# `coll = true::Bool` only type-asserted the literal `true` and left the
# parameter itself untyped.
function manipulatedata(mdf::MyDataFrame, coll::Bool = true)
    # Convert the user-supplied MyDataFrame into a vector of DataPoints.
    # The vector is built in increasing `j`, and deleteat! keeps relative order,
    # so "index < idx_best" and "c.j < dp_best.j" select the same points below.
    datapoints = [
        DataPoint(j, mdf.xs[j], mdf.ys[j], mdf.xys[j], mdf.omxs[j]) for j in 1:mdf.m
    ]

    # One point is removed per round, so exactly mdf.m rounds empty the vector.
    for _ in 1:mdf.m
        dp_best, idx_best = findmax(datapoints)

        if coll
            deleteat!(datapoints, idx_best)

            # Rebuild every survivor in one pass; points that preceded the best
            # one are scaled by its `omx`, points that followed are shifted by
            # its `xy`.
            datapoints[:] = collect(
                DataPoint(
                    c.j,
                    c.x,
                    c.j < dp_best.j ? c.y * dp_best.omx : c.y - dp_best.xy,
                    c.j < dp_best.j ? c.xy * dp_best.omx : c.xy - c.x * dp_best.xy,
                    c.omx
                )
                for c in datapoints
            )
        else
            # Points before the best one: scale y and xy by the best point's omx.
            for i in 1:(idx_best - 1)
                dp = datapoints[i]
                datapoints[i] = DataPoint(
                    dp.j,
                    dp.x,
                    dp.y * dp_best.omx,
                    dp.xy * dp_best.omx,
                    dp.omx
                )
            end
            # Points after the best one: shift y and xy by the best point's xy.
            for i in (idx_best + 1):length(datapoints)
                dp = datapoints[i]
                datapoints[i] = DataPoint(
                    dp.j,
                    dp.x,
                    dp.y - dp_best.xy,
                    dp.xy - dp.x * dp_best.xy,
                    dp.omx
                )
            end

            deleteat!(datapoints, idx_best)
        end
    end

    return nothing
end


# Time both variants of manipulatedata() on the same random instance of size m.
#
# Both variants are first run once on a tiny instance so that the @time reports
# below measure execution rather than first-call JIT compilation, which would
# otherwise be charged mostly to whichever variant happens to be timed first.
#
# Prints two @time reports; returns nothing.
function timer(m)
    # Warm-up: trigger compilation for the same argument types used below.
    warmup = MyDataFrame(rand(2), rand(2))
    manipulatedata(warmup, true)
    manipulatedata(warmup, false)

    rectlist = MyDataFrame(rand(m), rand(m))

    @time r1 = manipulatedata(rectlist, true)
    @time r2 = manipulatedata(rectlist, false)
    # manipulatedata() returns nothing, so this only confirms both calls did so.
    @assert r1 == r2
end

# Run the benchmark on a random instance with 2000 points.
timer(2000)
2 Likes