Macro to reorder conditional nested loop

I have a loop of the form

for j = 1:n
    for k = max(1, j-d_u):min(j-u-1, m)
        f(j, k)
    end
end

and it turns out that computing the limits on k in each iteration adds a lot of overhead. The loop can be rewritten by shifting the conditions from k to j, which removes that overhead. Is there a macro somewhere that does this automatically?

The result that I’m looking for in this case is

for j = max(d_u + 2, 1):min(m+u, n)
    for k = j-d_u:j-u-1
        f(j, k)
    end
end
for j = max(d_u + 2, 1, m+u+1):n
    for k = j-d_u:m
        f(j, k)
    end
end
for j=1:min(d_u + 1, m+u, n)
    for k = 1:j-u-1
        f(j, k)
    end
end
for j=max(1, m+u+1):min(d_u + 1, n)
    for k = 1:m
        f(j, k)
    end
end

Not that I’m aware of …
Further, maybe I’m missing something, but none of your rewrites are the “same” loop:

julia> f(a, b) = @show (a, b)
f (generic function with 1 method)

julia> n = 5; m = 4; d_u = 3; u = 1
1

# The original loop
julia> for j=1:n
               for k = max(1,j-d_u):min(j-u-1,m)
                   f(j, k)
               end
       end
(a, b) = (3, 1)
(a, b) = (4, 1)
(a, b) = (4, 2)
(a, b) = (5, 2)
(a, b) = (5, 3)

# The rewrites
julia> for j = max(d_u + 2, 1):min(m+u, n)
           for k = j-d_u:j-u-1
               f(j, k)
           end
       end
(a, b) = (5, 2)
(a, b) = (5, 3)

julia> for j = max(d_u + 2, 1, m+u+1):n
           for k = j-d_u:m
               f(j, k)
           end
       end

julia> for j=1:min(d_u + 1, m+u, n)
           for k = 1:j-u-1
               f(j, k)
           end
       end
(a, b) = (3, 1)
(a, b) = (4, 1)
(a, b) = (4, 2)

julia> for j=max(1, m+u+1):min(d_u + 1, n)
           for k = 1:m
               f(j, k)
           end
       end

I think the idea is that the four loops combined should be equivalent to the original single loop, which your example indeed illustrates.
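
For completeness, here is a minimal check, using the same example values, that the four rewritten loops visit exactly the same (j, k) pairs as the original loop (only the order differs):

# Collect the (j, k) pairs visited by the original loop and by the four
# rewritten loops, then compare them after sorting.
n, m, d_u, u = 5, 4, 3, 1

original = [(j, k) for j in 1:n for k in max(1, j - d_u):min(j - u - 1, m)]

rewritten = vcat(
    [(j, k) for j in max(d_u + 2, 1):min(m + u, n)     for k in (j - d_u):(j - u - 1)],
    [(j, k) for j in max(d_u + 2, 1, m + u + 1):n      for k in (j - d_u):m],
    [(j, k) for j in 1:min(d_u + 1, m + u, n)          for k in 1:(j - u - 1)],
    [(j, k) for j in max(1, m + u + 1):min(d_u + 1, n) for k in 1:m],
)

sort(original) == sort(rewritten)  # true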


Indeed, the idea is to move the branches from the inner loop to the outer loop. Taken together, the four loops do the same as the original one. I am unsure whether this is still an issue in more recent Julia versions.
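
To make the pattern explicit, here is a rough sketch of how the four outer ranges fall out of the inner bounds max(1, j - d_u):min(j - u - 1, m): each bound has two candidates, and each combination of active candidates gives one j-range. The helper name split_jranges is made up for illustration; there is no such macro or function that I know of.

# Sketch only: boundary values of j where both candidates coincide can be
# assigned to either range; here they follow the split from the first post.
function split_jranges(n, m, d_u, u)
    return (
        # lower bound j - d_u, upper bound j - u - 1
        (max(d_u + 2, 1):min(m + u, n),         j -> (j - d_u):(j - u - 1)),
        # lower bound j - d_u, upper bound m
        (max(d_u + 2, 1, m + u + 1):n,          j -> (j - d_u):m),
        # lower bound 1, upper bound j - u - 1
        (1:min(d_u + 1, m + u, n),              j -> 1:(j - u - 1)),
        # lower bound 1, upper bound m
        (max(1, m + u + 1):min(d_u + 1, n),     j -> 1:m),
    )
end

# Usage: visits the same (j, k) pairs as the original loop, in a different order.
for (jrange, krange) in split_jranges(5, 4, 3, 1)
    for j in jrange, k in krange(j)
        @show (j, k)
    end
end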

At least on 1.10.4 I can still find plenty of situations where the version with four loops is faster than one with the single loop.

Some benchmarks
using BenchmarkTools

function one_loop(f, T, n, m, d_u, u)
    s = T(0)
    for j = 1:n
        for k = max(1, j - d_u):min(j - u - 1, m)
            s += f(j, k)
            # accumulate function outputs, so that the
            # function call does not get compiled away
        end
    end
    return s
end

function four_loops(f, T, n, m, d_u, u)
    s = T(0)
    for j = max(d_u + 2, 1):min(m+u, n)
        for k = j-d_u:j-u-1
            s += f(j, k)
        end
    end
    for j = max(d_u + 2, 1, m+u+1):n
        for k = j-d_u:m
            s += f(j, k)
        end
    end
    for j=1:min(d_u + 1, m+u, n)
        for k = 1:j-u-1
            s += f(j, k)
        end
    end
    for j=max(1, m+u+1):min(d_u + 1, n)
        for k = 1:m
            s += f(j, k)
        end
    end
    return s
end

f(x, y) = 1
@btime one_loop($f, Int, 100000, 100, 50, 20)     
#    36.600 μs (0 allocations: 0 bytes); 3000
@btime four_loops($f, Int, 100000, 100, 50, 20)   
#   100.100 μs (0 allocations: 0 bytes); 3000

@btime one_loop($f, Float64, 100000, 100, 50, 20)     
#   91.400 μs (0 allocations: 0 bytes); 3000.0
@btime four_loops($f, Float64, 100000, 100, 50, 20)   
#   44.300 μs (0 allocations: 0 bytes); 3000.0

f(x, y) = 1.
@btime one_loop($f, Float64, 100000, 100, 50, 20)     
#   91.400 μs (0 allocations: 0 bytes); 3000.0
@btime four_loops($f, Float64, 100000, 100, 50, 20)   
#   44.300 μs (0 allocations: 0 bytes); 3000.0

f(x, y) = x^2 + y
@btime one_loop($f, Int, 100000, 100, 50, 20)
#   76.200 μs (0 allocations: 0 bytes); 25064000
@btime four_loops($f, Int, 100000, 100, 50, 20)
#   52.900 μs (0 allocations: 0 bytes); 25064000

@btime one_loop($f, Float64, 100000, 100, 50, 20)
#   213.500 μs (0 allocations: 0 bytes); 2.5064e7
@btime four_loops($f, Float64, 100000, 100, 50, 20)
#    55.400 μs (0 allocations: 0 bytes); 2.5064e7

f(x, y) = sqrt(x^2 + y)
@btime one_loop($f, Float64, 100000, 100, 50, 20)
#   120.500 μs (0 allocations: 0 bytes); 258799.84936223726
@btime four_loops($f, Float64, 100000, 100, 50, 20)
#    70.600 μs (0 allocations: 0 bytes); 258799.84936223738

A number of these results don’t really make sense to me, though. Why is the four-loops version with an Int accumulator so much slower than the Float64 equivalent, while the opposite is true for the one-loop version? Why is the one-loop version for f(x, y) = x^2 + y affected by the type of s, while the four-loops version is not? And why is the Float64 one-loop version for f(x, y) = x^2 + y slower than the one for the more expensive f(x, y) = sqrt(x^2 + y)?
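
One way to start digging into these differences (just a suggestion, not an answer): compare the LLVM code generated for the two versions and check whether the inner loops get SIMD-vectorized, which is the usual first suspect when a trivially cheap body like f(x, y) = 1 behaves so differently for Int and Float64 accumulators.

using InteractiveUtils  # provides @code_llvm outside the REPL

f(x, y) = 1
@code_llvm debuginfo=:none one_loop(f, Int, 100000, 100, 50, 20)
@code_llvm debuginfo=:none four_loops(f, Int, 100000, 100, 50, 20)
@code_llvm debuginfo=:none one_loop(f, Float64, 100000, 100, 50, 20)
@code_llvm debuginfo=:none four_loops(f, Float64, 100000, 100, 50, 20)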
