Why is `collect()` faster than a for loop in this numerical scheme? (corrected!)

I am a PhD student in Computer Science, and my work is in Operations Research, an area that values efficient algorithms. So I have some background from classes on topics like complexity of algorithms, algorithm design, and computer architecture. But I have also learned a lot lurking in forums, especially this one, XD, particularly about how some things work on more modern machines, as architecture is always evolving.

Note, however, that I made many small changes and delivered all of them together, so you may be misattributing the source of the improvements. Removing the multiple small allocations did help a lot, and it came not from the integrated loop but from: separating the two approaches into their own functions, using @code_lowered to inspect the lowered code of the loop version, and discovering that for some reason dp_best was being boxed (that is, a pointer to it was being allocated on the heap), which was creating many small allocations. Then I followed a rarely useful Performance Tip and got the compiler to understand that the variable never changed type, so the allocations disappeared. That is why I recommended those three steps to you at the start of this thread.
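To illustrate the boxing issue described above, here is a minimal sketch (with made-up function names, not the thread's actual code): a variable that is reassigned and also captured by a closure gets lowered through a heap-allocated Core.Box, which produces many small allocations; the Performance Tips fix is to give the closure a variable that is only assigned once.

```julia
# `best` is reassigned in the loop AND captured by the closure passed to
# `map`, so lowering boxes it: every read/write of `best` goes through a
# heap-allocated Core.Box, allocating even inside the plain loop.
# (`@code_lowered scale_by_max_boxed` shows the Core.Box.)
function scale_by_max_boxed(xs::Vector{Float64})
    best = xs[1]
    for x in xs
        best = max(best, x)
    end
    return map(x -> x / best, xs)   # closure captures `best`
end

# Fix suggested by the Performance Tips ("performance of captured
# variables"): rebind the finished value with `let`, so the closure
# captures a local that never changes, which the compiler keeps unboxed.
function scale_by_max_unboxed(xs::Vector{Float64})
    best = xs[1]
    for x in xs
        best = max(best, x)
    end
    return let b = best
        map(x -> x / b, xs)
    end
end
```

Both functions compute the same result; only the lowered code differs.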


Interesting. We are actually in the same boat, although you’ve been riding it for longer than I have. I am doing an MS in operations research, but my program is housed in an industrial engineering department rather than CS, and our coursework and research focus almost entirely on theoretical issues. We model a problem, create a solution algorithm, prove its correctness, and analyze its computational efficiency in terms of order of complexity. If implementation is considered at all, it is only to the extent that we can produce a plot showing that the algorithm’s actual runtime is, in fact, O(n^2) or whatever. (Many graduate without ever writing any code at all.)

This shows in my code: In this thread, my original implementation was theoretically “efficient” in the narrow sense that it was O(n^2), but was very practically inefficient because I had nobody around to bring these implementation details to my attention. I still think it’s better to start with a solid theoretical grasp of computational complexity and then self-study the implementation side than the opposite, so I am happy with my program overall, but yeah, I have a long way to go.


For completeness, this is sufficient to make the original code generic and type stable:

import Base.isless


# User-facing struct that holds the instance data
struct MyDataFrame{T<:Real}
    m::Int
    xs::Vector{T}
    ys::Vector{T}
    xys::Vector{T}      # = xs .* ys
    omxs::Vector{T}     # = 1 .- xs

    function MyDataFrame(xs::Vector{T}, ys::Vector{T}) where T
        m = length(xs)

        return new{T}(m, xs, ys, xs .* ys, 1 .- xs)
    end
end


# Struct used internally in manipulatedata()
struct DataPoint{T<:Real}
    j::Int
    x::T
    y::T
    xy::T     # = x * y
    omx::T    # = 1 - x
end


# Overload isless() so that DataPoints sort on xy
isless(c1::DataPoint, c2::DataPoint) = isless(c1.xy, c2.xy)



function manipulatedata(mdf::MyDataFrame, coll::Bool = true)
    # Convert the user-supplied MyDataFrame into a vector of DataPoints
    datapoints = [DataPoint(j, mdf.xs[j], mdf.ys[j], mdf.xys[j], mdf.omxs[j]) for j in 1:mdf.m]
    
    for _ in 1:mdf.m
        dp_best, idx_best = findmax(datapoints)
    
        if coll     # This way is faster
            deleteat!(datapoints, idx_best)

            datapoints[:] = collect(
                DataPoint(
                    c.j,
                    c.x,
                    c.j < dp_best.j ? c.y * dp_best.omx : c.y - dp_best.xy,
                    c.j < dp_best.j ? c.xy * dp_best.omx : c.xy - c.x * dp_best.xy,
                    c.omx
                )
                for c in datapoints
            )
        else         # This way is slower
            for i in 1:idx_best-1
                datapoints[i] =
                    DataPoint(
                        datapoints[i].j,
                        datapoints[i].x,
                        datapoints[i].y * dp_best.omx,
                        datapoints[i].xy * dp_best.omx,
                        datapoints[i].omx
                    )
            end
            for i in idx_best+1:length(datapoints)
                datapoints[i] =
                    DataPoint(
                        datapoints[i].j,
                        datapoints[i].x,
                        datapoints[i].y - dp_best.xy,
                        datapoints[i].xy - datapoints[i].x * dp_best.xy,
                        datapoints[i].omx
                    )
            end
        
            deleteat!(datapoints, idx_best)
        end
    end

    return nothing
end


function timer(m)
    rectlist = MyDataFrame(rand(m), rand(m))

    @time r1 = manipulatedata(rectlist, true)
    @time r2 = manipulatedata(rectlist, false)
    @assert r1 == r2
end

timer(2000)

Isn’t this wasteful? The rhs allocates an array, it doesn’t directly write into datapoints. So it first allocates, and then copies into datapoints. Isn’t it better to just re-bind datapoints instead?

Of course it is. To be clear, the code I posted only makes the minimal changes needed to make the original code generic and type stable. And when you time it, you find that the wasteful collect alternative is indeed slower than the loop alternative, since the loop version no longer has to deal with boxed variables.
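A small sketch of the difference being discussed (illustrative values, not the thread's exact code): `v[:] = collect(gen)` first materializes a temporary array and then copies every element into the existing storage, while plain rebinding allocates once and simply points the name at the new array.

```julia
v = collect(1.0:5.0)            # [1.0, 2.0, 3.0, 4.0, 5.0]

# In-place assignment: `collect` allocates a temporary array, then
# `v[:] = ...` copies it element-by-element into v's existing storage,
# so there is one allocation plus one extra copy.
v[:] = collect(2x for x in v)   # v is now [2.0, 4.0, 6.0, 8.0, 10.0]

# Rebinding: the comprehension allocates one new array and the name `v`
# is simply bound to it afterwards -- no second copy.
v = [2x for x in v]             # v is now [4.0, 8.0, 12.0, 16.0, 20.0]
```

Rebinding only works when nothing else needs to see the mutation through the old binding, which is the case in the function above.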


I’d like to add to the list

  • Performance tips section (already mentioned)
  • Profile your code and/or run JET.jl
  • Know that tools like Cthulhu.jl and SnoopCompile.jl exist
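For the first two bullets, here is a minimal sketch of how one might invoke those tools on a deliberately type-unstable function (the function is made up for illustration; JET.jl is assumed to be installed for the commented part):

```julia
# A deliberately type-unstable function: `x` may be Int or Float64
# depending on `flag`, so its inferred type is a Union.
function unstable(flag::Bool)
    x = 1            # starts as Int
    if flag
        x = 1.0      # ...may become Float64
    end
    return x
end

# Interactive inspection with built-in tools (prints annotated IR;
# look for red `Union` / `Any` annotations):
#   @code_warntype unstable(true)
#
# With JET.jl installed (assumption: the package is in your environment):
#   using JET
#   @report_opt unstable(true)
#
# Statistical profiling with the Profile standard library:
#   using Profile
#   @profile manipulatedata(rectlist, true)
#   Profile.print()
```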