How can I make this O(n^2) function faster?

I have a collection of interval events. Each interval has a (physical) location, a start time, and a stop time. There’s also a non-negative kernel function, e.g. kernel(dist) = ifelse(dist <= 100, 1, 0).

For each interval i, I want to find
sum(kernel(dist(i, j)) for j in <intervals ongoing when i ends>). Here’s a very simple but inefficient implementation:

function interval_counts(locs, starts, stops)
    n = length(locs)
    return map(1:n) do i
        inds_of_active_intervals = [j for j in 1:n if starts[j] < stops[i] <= stops[j]]
        distances = [distance(locs[i], locs[j]) for j in inds_of_active_intervals]
        return sum(kernel, distances)
    end
end

(Typically there are on the order of a million intervals, and at any time about 500 of them are ongoing.)

I keep feeling like there’s a simple way to make this much faster, probably relating to sorting intervals by their start or stop times. But I’ve not found any meaningful algorithmic speedup. (I have found that combining the body of the anonymous function in the example code into one line avoids allocating the index and distance vectors, but it’s not a big effect.)

Any suggestions?


The most obvious way of speeding this up would be to have some sort of index or search tree of the intervals so that you do not have to search through all of them every time to determine which are active. The simplest example I could think of would be to sort the intervals and stop searching once you’ve passed through the relevant set. Even better would be to sort the intervals, store the indices of some pre-defined times in a Dict, then start searching from there.
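A minimal sketch of the sort-and-stop idea, under one extra assumption that is not in the question: a known upper bound `maxdur` on interval length. With that bound, every interval active at time `t` must have started in `[t - maxdur, t)`, which is a contiguous slice once the intervals are sorted by start time, so each query only scans that slice. `active_sum` and `maxdur` are hypothetical names; `kernel` and `distance` are passed in.

```julia
# Sort by start time, then for each query time t scan only the starts in
# [t - maxdur, t). Assumes (hypothetically) no interval is longer than maxdur.
function active_sum(locs, starts, stops, maxdur, kernel, distance)
    perm = sortperm(starts)
    s_starts = starts[perm]                          # start times, ascending
    n = length(locs)
    out = zeros(n)
    for i in 1:n
        t = stops[i]
        lo = searchsortedfirst(s_starts, t - maxdur) # first start >= t - maxdur
        hi = searchsortedfirst(s_starts, t) - 1      # last start < t
        for k in lo:hi
            j = perm[k]
            if t <= stops[j]                         # j is still active at t
                out[i] += kernel(distance(locs[i], locs[j]))
            end
        end
    end
    return out
end
```

This only helps if `maxdur` is not much longer than a typical interval; a few very long intervals widen every slice, which is the case an interval tree handles better.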

To be honest, I usually find indexing and search tree schemes like this to be a huge pain in the neck to implement. It’s also usually not obvious to me what the absolute “best” such scheme would be, and I can usually think of many.

If anyone knows of a search tree or similar package for handling general cases easily, I would be very interested in that myself.

You can also, of course, eliminate the allocations of both the index vector and the distance vector. This should be easy here: just replace the [ and ] of each comprehension with ( and ). That turns each comprehension into a generator, which produces the values lazily instead of allocating an array.
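A minimal sketch of the generator version, assuming 2-D locations stored as `(x, y)` tuples; the `distance` and `kernel` definitions below are stand-ins, since the question doesn't show them. It is still O(n²), but no temporary index or distance arrays are built.

```julia
# Stand-in definitions (hypothetical): 2-D tuple locations, box kernel from the question.
distance(a, b) = hypot(a[1] - b[1], a[2] - b[2])
kernel(dist) = ifelse(dist <= 100, 1.0, 0.0)

function interval_counts(locs, starts, stops)
    n = length(locs)
    return map(1:n) do i
        ti, li = stops[i], locs[i]
        # generator instead of two comprehensions; init handles an empty sum
        sum(kernel(distance(li, locs[j])) for j in 1:n if starts[j] < ti <= stops[j]; init = 0.0)
    end
end
```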

You can easily eliminate both allocations with something like

s = 0.0
istop = stops[i]
iloc = locs[i]
@inbounds for j in 1:n
    if starts[j] < istop ≤ stops[j]
        s += kernel(distance(iloc, locs[j]))
    end
end
return s

As @ExpandingMan says, the next step would be to avoid searching the whole list of intervals for every i. The typical data structure for this is an interval tree, and there is already a Julia package for it: IntervalTrees.jl.
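For illustration only, here is a minimal hand-rolled centered interval tree (not the IntervalTrees.jl API) that answers the same stabbing query as the original loop, `starts[j] < t <= stops[j]`. All names are hypothetical, and it assumes `starts[j] < stops[j]` for every interval. Splitting each node at the median stop time keeps the depth O(log n), so a query costs O(log n + k) for k reported intervals.

```julia
struct ITNode
    center::Float64
    by_start::Vector{Int}           # intervals containing center, start ascending
    by_stop::Vector{Int}            # the same intervals, stop descending
    left::Union{ITNode,Nothing}     # intervals entirely before center (stop < center)
    right::Union{ITNode,Nothing}    # intervals entirely after center (start >= center)
end

function build_tree(idx::Vector{Int}, starts, stops)
    isempty(idx) && return nothing
    # The median stop time is contained in its own interval, so every node holds
    # at least one interval and the recursion terminates.
    center = Float64(sort(stops[idx])[cld(length(idx), 2)])
    here  = [j for j in idx if starts[j] < center <= stops[j]]
    left  = [j for j in idx if stops[j] < center]
    right = [j for j in idx if starts[j] >= center]
    return ITNode(center, sort(here; by = j -> starts[j]),
                  sort(here; by = j -> stops[j], rev = true),
                  build_tree(left, starts, stops), build_tree(right, starts, stops))
end
build_tree(starts, stops) = build_tree(collect(eachindex(starts)), starts, stops)

# Push every j with starts[j] < t <= stops[j] into out.
function stab!(out, node, t, starts, stops)
    node === nothing && return out
    if t <= node.center
        # intervals here have stop >= center >= t, so they contain t iff start < t
        for j in node.by_start
            starts[j] < t || break
            push!(out, j)
        end
        stab!(out, node.left, t, starts, stops)
    else
        # intervals here have start < center < t, so they contain t iff stop >= t
        for j in node.by_stop
            stops[j] >= t || break
            push!(out, j)
        end
        stab!(out, node.right, t, starts, stops)
    end
    return out
end

function interval_counts_tree(locs, starts, stops, kernel, distance)
    tree = build_tree(starts, stops)
    buf = Int[]
    return map(eachindex(locs)) do i
        empty!(buf)
        stab!(buf, tree, stops[i], starts, stops)
        sum(j -> kernel(distance(locs[i], locs[j])), buf; init = 0.0)
    end
end
```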

To get even greater speedups, you would have to look at a deeper level … maybe there is some kind of fast-multipole approximation that you can use for your underlying problem.


Thanks! I hadn’t known about IntervalTrees, and am experimenting with it right now.

I’m not following your last point about a “fast-multipole approximation” - any chance you could elaborate on that?


A common numerical problem is computing pairwise interactions between N bodies, where the interactions (your kernel) decay with distance. Any straightforward method is O(N²), quadratic in the number of bodies, because every pair is evaluated.

One of the numerical breakthroughs of the past half-century, however, has been the realization that such computations can often be performed approximately (to any desired accuracy) in less time, O(N log N) or even O(N). The basic reason is that you don’t need to compute every interaction between distant bodies individually, because for distant interactions you can lump nearby bodies together. The most famous such algorithm is the fast multipole method.

So, if your kernel decays with distance, you might be able to use such an algorithm for your underlying problem, whatever that is.

On the other hand, if your kernel(dist) = ifelse(dist <= 100, 1, 0) example function is typical, you can perhaps restrict your search to locations close to the current location. This is called a fixed-radius nearest-neighbor search, and there are various data structures and algorithms to perform such queries efficiently.
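A minimal sketch of one standard fixed-radius structure, a cell list (uniform grid) in 2-D, assuming locations are `(x, y)` tuples. With square cells of side `r`, any point within `r` of a query lies in the query's own cell or one of the 8 surrounding cells, so a query only checks those 9 cells. The function names are hypothetical.

```julia
# Cell index of point p for cells of side r.
cellof(p, r) = (floor(Int, p[1] / r), floor(Int, p[2] / r))

# Bin point indices by cell.
function build_grid(points, r)
    grid = Dict{Tuple{Int,Int},Vector{Int}}()
    for (k, p) in enumerate(points)
        push!(get!(Vector{Int}, grid, cellof(p, r)), k)
    end
    return grid
end

# Indices of all points within distance r of q (checks the 3x3 block of cells around q).
function neighbors_within(grid, points, q, r)
    cx, cy = cellof(q, r)
    out = Int[]
    for dx in -1:1, dy in -1:1
        for k in get(grid, (cx + dx, cy + dy), Int[])
            p = points[k]
            hypot(p[1] - q[1], p[2] - q[2]) <= r && push!(out, k)
        end
    end
    return out
end
```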


You can get fast queries for points within a fixed radius via e.g.

However, to get good performance you need a hybrid data structure that combines interval trees and metric trees. I don’t know what this structure would need to look like (and I would not be surprised if this turned out to be a fun little research problem).
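One simple hybrid, offered as a sketch rather than an answer to that open question: sweep over the start/stop events in time order, and keep only the currently active intervals in a uniform spatial grid with cell side `R`, where the kernel is assumed negligible beyond distance `R`. Each stop then only examines active intervals in the 9 nearby cells. It assumes 2-D tuple locations and `starts[i] < stops[i]`; `interval_sums` and its arguments are hypothetical names.

```julia
function interval_sums(locs, starts, stops, kernel, R)
    n = length(locs)
    cell(p) = (floor(Int, p[1] / R), floor(Int, p[2] / R))
    # (time, kind, index), kind 0 = stop, 1 = start: at equal times stops sort first,
    # so an interval starting exactly at t is not yet active at t.
    events = sort!([[(stops[i], 0, i) for i in 1:n]; [(starts[i], 1, i) for i in 1:n]])
    grid = Dict{Tuple{Int,Int},Vector{Int}}()  # cell => indices of active intervals
    out = zeros(n)
    pending = Int[]  # intervals evaluated at time `tlast`, removed from the grid lazily
    tlast = -Inf
    for (t, kind, i) in events
        if kind == 1 || t != tlast
            # all stops up to here have been evaluated, so they can leave the grid
            for j in pending
                v = grid[cell(locs[j])]
                deleteat!(v, findfirst(==(j), v))
            end
            empty!(pending)
        end
        if kind == 1
            push!(get!(Vector{Int}, grid, cell(locs[i])), i)
        else
            tlast = t
            cx, cy = cell(locs[i])
            for dx in -1:1, dy in -1:1, j in get(grid, (cx + dx, cy + dy), Int[])
                d = hypot(locs[i][1] - locs[j][1], locs[i][2] - locs[j][2])
                d <= R && (out[i] += kernel(d))   # kernel truncated at R
            end
            push!(pending, i)
        end
    end
    return out
end
```

With about 500 intervals active at once, each query only touches the active intervals near the stopping interval rather than all n, at the cost of truncating the kernel at `R`.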


Much appreciated, thank you! The kernel I’ll actually deploy looks more like dist -> exp(-dist^2); the scaling is such that for all intents and purposes each interval only “interacts” with a dozen or so nearby intervals. So a fixed-radius approach would probably work well. (In particular interval locations are spread over roughly a 20 by 20 kilometer area, and interactions are negligible over distances > 150m.)
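To pick the radius for a fixed-radius search with a Gaussian kernel, one can solve for where the kernel drops below a tolerance ε. Writing the kernel with an explicit length scale ℓ (an assumption, since the scaling isn't stated):

$$
k(d) = e^{-(d/\ell)^2} < \varepsilon \iff d > r_c = \ell \sqrt{\ln(1/\varepsilon)}
$$

For example, with ε = 10⁻⁴, ln(10⁴) ≈ 9.21 and r_c ≈ 3.03 ℓ; a hypothetical ℓ = 50 m would give r_c ≈ 152 m, in line with the 150 m figure above.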

This sounds like you can use a fast Gauss transform, which you can compute in linear time.

There is some ancient Julia code for this (jwmerrill/FastGaussTransforms.jl on GitHub, "Fast algorithm for repeated evaluation of the convolution of a point set with a gaussian kernel"), which probably doesn’t work in Julia 1.0 but might be useful as a starting point.


Hi, this is my first post on this forum. I was interested in Julia and was reading over here.

Your problem is usually solved by sorting the timestamps of the intervals. You don’t (necessarily) need a complex interval tree because you can process the intervals in order.

Basically, the idea is to consider all timestamps, i.e. the times at which an interval begins or ends (there are at most 2n), and go through them in order. When you encounter an interval that begins, you add it to a data structure. When you encounter an interval that ends, you remove it from the data structure and compute the answer for it. Unlike with an interval tree, this data structure can be very simple. More on that later.

So, for example, in your function you can use a simple ordered set that can count how many of its elements are less than a given value. If you change the function later, you may need a slightly different data structure, but the approach stays more or less the same.
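A minimal sketch of such a set, assuming (as in the pseudo-code) 1-D numeric locations that are all known in advance, so each can be mapped to its rank: a Fenwick (binary indexed) tree of multiplicities over those ranks, giving O(log n) insert, remove, and count-less-than. `CountingSet`, `add!`, and `remove!` are hypothetical names; `countless` mirrors the pseudo-code's `myset.countless`.

```julia
struct CountingSet
    vals::Vector{Float64}   # sorted distinct values that may ever be inserted
    tree::Vector{Int}       # Fenwick tree of multiplicities, indexed by rank
end

function CountingSet(possible)
    vals = sort!(unique(Float64.(possible)))
    return CountingSet(vals, zeros(Int, length(vals)))
end

# Add δ to the multiplicity of x (x must be one of s.vals).
function update!(s::CountingSet, x, δ)
    k = searchsortedfirst(s.vals, x)
    while k <= length(s.tree)
        s.tree[k] += δ
        k += k & -k          # next Fenwick node covering rank k
    end
    return s
end
add!(s::CountingSet, x) = update!(s, x, 1)
remove!(s::CountingSet, x) = update!(s, x, -1)

# Number of stored elements strictly less than x (x need not be in s.vals).
function countless(s::CountingSet, x)
    k = searchsortedfirst(s.vals, x) - 1   # number of ranks with value < x
    c = 0
    while k > 0
        c += s.tree[k]
        k -= k & -k          # drop the lowest set bit
    end
    return c
end
```

In the sweep, the number of active locations in `[l - 100, l + 100)` is then `countless(s, l + 100) - countless(s, l - 100)`, as in the pseudo-code.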

Pseudo-code (I don’t know Julia at the moment):

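# Sweep over all start/stop timestamps (at most 2n) in time order. At equal times,
# stops sort before starts, matching the question's starts[j] < stops[i] <= stops[j].
# Assumes 1-D locations (as in this sketch) and starts[i] < stops[i].
# locs could be stored inside the event tuples for better cache access.
function interval_counts(locs::AbstractVector{<:Real}, starts, stops; radius = 100)
    n = length(locs)
    # (time, kind, index) with kind 0 = stop, 1 = start
    events = sort!([[(stops[i], 0, i) for i in 1:n]; [(starts[i], 1, i) for i in 1:n]])
    # Sorted locations of the active intervals, standing in for the "custom set".
    # insert!/deleteat! cost O(#active) (~500 here); a balanced tree makes them O(log n).
    active = eltype(locs)[]
    counts = zeros(Int, n)
    pending = Int[]   # intervals that ended at time `tlast`; removed once that time is done
    tlast = -Inf
    for (t, kind, i) in events
        if kind == 1 || t != tlast
            for j in pending
                deleteat!(active, searchsortedfirst(active, locs[j]))
            end
            empty!(pending)
        end
        l = locs[i]
        if kind == 1
            insert!(active, searchsortedfirst(active, l), l)
        else
            tlast = t
            # active locations in [l - radius, l + radius], including interval i itself
            counts[i] = searchsortedlast(active, l + radius) - searchsortedfirst(active, l - radius) + 1
            push!(pending, i)
            # For a weighted sum instead of a count, use a structure that supports
            # sumless(value) (prefix sums) in O(log n).
        end
    end
    return counts
end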

This algorithm runs in O(n log n) time and gives an exact answer (O(n log n) for sorting, plus 2n steps in the for loop making a total of 4n queries of O(log n) each). Depending on many things, you can use different data structures. For example, if the distance is really small and the locations are integers in a small range, you can create an array representing that range: when you insert an interval, add 1 to every location it reaches, and when you remove an interval, read the value at its position to see how many intervals are currently within reach.
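A minimal sketch of that array idea, assuming integer locations in `1:L` and an integer reach `r` (both hypothetical parameters): `counts[x]` holds the number of active intervals within `r` of position `x`, so inserting or removing an interval costs O(r) and a query is a single array lookup.

```julia
struct ReachCounts
    counts::Vector{Int}   # counts[x] = active intervals whose location is within r of x
    r::Int
end
ReachCounts(L::Integer, r::Integer) = ReachCounts(zeros(Int, L), Int(r))

# Add δ to every position reached from location x, clipped to 1:L. O(r).
function spread!(rc::ReachCounts, x::Integer, δ::Integer)
    for y in max(1, x - rc.r):min(length(rc.counts), x + rc.r)
        rc.counts[y] += δ
    end
    return rc
end
insert_loc!(rc::ReachCounts, x) = spread!(rc, x, 1)
remove_loc!(rc::ReachCounts, x) = spread!(rc, x, -1)
within_reach(rc::ReachCounts, x) = rc.counts[x]   # O(1) query
```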

I hope this helps,
