Preferred pattern for generators?

Hi. First thread.

We all know the generator pattern from Python and elsewhere.

def mygenerator():
    for i in range(10):
        yield i

I know the “proper” way to do this is to write a struct with a method on Base.iterate, but sometimes you don’t need or want to implement a struct for something that is more easily described in terms of procedures. I’ve seen the pattern with Channels floating around online, and I know closures are also a popular strategy for this in other languages, e.g.:

function countto(n)
    count = 0
    () -> (count >= n ? nothing : count += 1)
end


function closuretest()
    a_cool_million = countto(1_000_000)
    while (i = a_cool_million()) !== nothing end
end


function countto2(n)
    Channel(ctype=Int) do c
        count = 0
        while count >= n
            put!(c, count += 1)
        end
    end
end

function channeltest()
    for i in countto2(1_000_000) end
end

I like the channel version a little more because it automatically provides an interface for for loops and, to my eye, the code reads more easily with the loop than with the closure.

Then I ran the code a few times to see how each would fare in terms of performance:

@time closuretest()
@time closuretest()
@time closuretest()
@time closuretest()
println()
@time channeltest()
@time channeltest()
@time channeltest()
@time channeltest()

(did the extra runs to make sure the JIT was warm)

  0.064392 seconds (2.01 M allocations: 30.837 MiB)
  0.051788 seconds (2.00 M allocations: 30.510 MiB, 3.74% gc time)
  0.049927 seconds (2.00 M allocations: 30.510 MiB, 7.13% gc time)
  0.046646 seconds (2.00 M allocations: 30.510 MiB, 0.60% gc time)

  0.189093 seconds (331.03 k allocations: 16.900 MiB, 1.23% gc time)
  0.000072 seconds (31 allocations: 2.266 KiB)
  0.000045 seconds (31 allocations: 2.266 KiB)
  0.000045 seconds (31 allocations: 2.266 KiB)

So, yeah, I wasn’t expecting that. The channel generator pattern is apparently better from a performance perspective, which I wouldn’t have predicted, since it has to hit the scheduler on every iteration (or does it?).

Anyway, I don’t put too much stock in the current poor performance of closures, since that can be fixed. I’m mostly just wondering which pattern is considered more idiomatic in Julia.

Check out ResumableFunctions.jl.

2 Likes

First, you’ve got a bug in your Channel-based implementation, making it iterate over an empty set. That is why you get such good performance. Of all the implementations that I tried, the Channel-based one is actually by far the worst in terms of performance (although it is easy to write).

I’m too new to Julia to comment on the “idiomaticity” of all the options. However, here are a few tests that give an idea of the performance of various possible implementations. These tests were run using julia -O3.

using BenchmarkTools
N = 1_000_000

Standard Julia range iterator

Let’s start with Julia’s standard UnitRange to get a baseline:

julia> @btime for i in 1:$N end
  1.815 ns (0 allocations: 0 bytes)

Closure implementation

The main drawback of this one is that it does not follow Julia’s iteration interface, meaning it is not usable with standard algorithms. It is not very fast, but not too slow either.

function countto1(n)
    count = 0
    () -> (count >= n ? nothing : count += 1)
end
julia> @btime let itr = countto1($N)
    while (i = itr()) !== nothing end
end
  45.490 ms (1999492 allocations: 30.51 MiB)
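
If you really want to keep the closure, you can retrofit it onto the iteration interface with a small wrapper. Here is a minimal sketch (the ClosureIterator name is mine, not anything from Base):

struct ClosureIterator{F}
    f::F
end

# The closure itself carries all the state, so the iterator state is unused.
# A `nothing` return from the closure signals exhaustion.
function Base.iterate(it::ClosureIterator, state=nothing)
    x = it.f()
    x === nothing ? nothing : (x, nothing)
end

Base.IteratorSize(::Type{<:ClosureIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:ClosureIterator}) = Base.EltypeUnknown()

# Usage: sum(ClosureIterator(countto1(10))) == 55

This doesn’t make the closure any faster, but it does make it usable with for loops, sum, collect and friends.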

Channel implementation

This seems to be the recommended way of implementing an iterator now that produce and consume are deprecated. Easy to write, it follows the Julia interface for iterators and can be used everywhere. However, Channels seem to be extreeeeemely slow for this kind of task.

function countto2(n)
    Channel(ctype=Int) do c
        count = 0
        while count < n
            put!(c, count+=1)
        end
    end
end
julia> @btime for i in countto2($N) end
  1.300 s (3999523 allocations: 61.03 MiB)
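
As an aside, a buffered Channel lets the producer run ahead of the consumer and amortizes some of the task switches. A sketch, using the same ctype/csize keywords as above, with the buffer size picked arbitrarily and not benchmarked here:

function countto2_buffered(n)
    # csize > 0 gives the channel an internal buffer, so the producer
    # task is not suspended after every single put!.
    Channel(ctype=Int, csize=256) do c
        count = 0
        while count < n
            put!(c, count += 1)
        end
    end
end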

ResumableFunctions

As noted by @mohamed82008, ResumableFunctions.jl provides a macro-based way to generate faster iterators. As far as I understand it, it transforms your function into a finite state machine which is then used to implement the iteration interface. Again, it is very easy to write and produces standard iterators. And it is much faster than the Channel-based implementation (and even faster than the closure-based one), though not nearly as fast as the native range iterator.

using ResumableFunctions

@resumable function countto3(n)
    count = 0
    while count < n
        count += 1
        @yield count
    end
end
julia> @btime for i in countto3($N) end
  26.495 ms (999491 allocations: 15.25 MiB)

Custom type implementing the iterator interface

It is much less easy to write, but follows the same kind of logic as your closure-based one. Advantages: super fast and standard (it is probably how Base.UnitRange is implemented).

struct CountTo
    n :: Int
end

Base.iterate(c::CountTo, state=0) = state >= c.n ? nothing : (state+1, state+1)
Base.length(c::CountTo) = c.n
julia> @btime for i in CountTo($N) end
  1.883 ns (0 allocations: 0 bytes)
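
And since it follows the standard interface, it composes with everything in Base; for instance:

# CountTo plugs into generic code with no extra work:
sum(CountTo(100))                        # 5050
maximum(CountTo(7))                      # 7
collect(Iterators.take(CountTo(10), 3))  # [1, 2, 3]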
9 Likes

Nice write-up! Note however that the ~1.8 ns benchmarks are the result of compiler optimizations removing the loops. 1.8 ns is just a handful of clock cycles and obviously not enough for a million iterations.

1 Like

That would make perfect sense, since 1.8 ns is indeed a very short amount of time. On the other hand, I get consistent results when actually using the iterated values. For example:

julia> @assert sum(CountTo(N)) == div(N*(N+1), 2)
julia> @btime  sum(CountTo($N))
  5.047 ns (0 allocations: 0 bytes)
500000500000

Or is Julia smart enough to optimize such a loop too?


Do you know of a reliable way to find what gets actually computed (and what gets optimized away)? Running @code_native sum(CountTo(N)) produces a list of instructions which I do not feel confident enough to interpret correctly…

With ResumableFunctions you can also get it to optimize away the loop if you specify the type of n:

using BenchmarkTools
using ResumableFunctions

N = 1_000_000

@resumable function countto3(n::Int)
    count = 0
    while count < n
        count += 1
        @yield count
    end
end

@btime for i in countto3($N) end # 1.503 ns (0 allocations: 0 bytes)

This is because n becomes a field in the struct that @resumable generates (see https://benlauwens.github.io/ResumableFunctions.jl/stable/internals.html#Type-definition-1).
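
For intuition only, here is a hand-written analogue of that idea (hypothetical names, not the actual macro expansion): the state lives in a struct, and because n is declared ::Int it becomes a concretely typed field that the compiler can reason about.

# Illustrative stand-in for the state machine @resumable conceptually
# generates: mutable state in a struct, one call per resumption,
# `nothing` to signal exhaustion.
mutable struct CountToFSM
    n::Int      # concretely typed because the signature declared n::Int
    count::Int
end

CountToFSM(n::Int) = CountToFSM(n, 0)

function resume!(fsm::CountToFSM)
    fsm.count < fsm.n || return nothing   # finished: no more values
    fsm.count += 1                        # advance to the next state
    return fsm.count                      # the value that would be @yield-ed
end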

2 Likes

Again, just do a back-of-the-envelope calculation of a sum with 1_000_000 elements and compare it to 5 ns.

1 Like

Thanks for doing that full write-up! Really good stuff.

To your last point, that the iterate method is the fastest: I know it really is fast. It’s just such a PITA to implement a struct and an iterate method every dang time you want a customized lazy iterator! It’s what I do when I want to “do it properly,” for library code or whatever, but I spend too much time trying to find ways to avoid it when it’s not a significant performance bottleneck. Implementing a struct and an iterate method is just… My brain can’t spare the CPU cycles when I’m in a hurry!

I’ll have to look more seriously at resumable functions.

Regarding the bug in my channel implementation: story of my life. My brain naturally seems to think “while” means “until” in a programming context. Ugh.

If your iterator happens to be some kind of “state transition” that ignores the input, you can use IterTools.iterated. It looks like it’s also compiled away, like CountTo.

julia> using IterTools

julia> @btime sum(Iterators.take(iterated(x -> x + 1, 1), $N))
  4.309 ns (0 allocations: 0 bytes)
500000500000

julia> @btime sum(CountTo($N))
  4.036 ns (0 allocations: 0 bytes)
500000500000

It could be useful if you don’t even want to write a function at top-level scope.
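
For example, with IterTools still loaded, something like this builds the iterator inline without any top-level definition (a quick sketch):

function powersum()
    # Powers of two built inline: 1, 2, 4, ..., 512.
    it = Iterators.take(iterated(x -> 2x, 1), 10)
    sum(it)  # 1023
end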

1 Like

As mentioned above, you can do a back-of-the-envelope calculation. A 3 GHz CPU means three clock cycles per nanosecond, and you’ll usually need at least one clock cycle per iteration, so a million iterations should take on the order of hundreds of microseconds, not a couple of nanoseconds.

Another method is to simply increase/decrease the problem size by a factor 10, and make sure that the elapsed time also changes by around a factor 10 (for a linear algorithm). If the time stays the same, you are likely not timing what you think you’re timing:

julia> @btime for i in CountTo($N) end
  1.460 ns (0 allocations: 0 bytes)

julia> @btime for i in CountTo(10 * $N) end
  1.462 ns (0 allocations: 0 bytes)

julia> @btime for i in countto3($N) end
  24.032 ms (999491 allocations: 15.25 MiB)

julia> @btime for i in countto3(10 * $N) end
  246.844 ms (9999491 allocations: 152.58 MiB)
3 Likes

I’m sorry, I must not have been clear. I did not mean to imply that Julia was really doing 1,000,000 operations in under 2 ns; I was rather trying to understand exactly what was computed during those 2 ns, and how.

Using @code_native seems to be my best bet to understand exactly what Julia does behind the scenes.

foo(N) = let s = 0
    for i in CountTo(N)
        s += i
    end
    s
end
julia> @code_native foo(1_000_000)
# I edited this output by hand to remove comments
        testq   %rdi, %rdi
        jle     L32
        leaq    -1(%rdi), %rdx
        leaq    -2(%rdi), %rax
        mulxq   %rax, %rax, %rcx
        shldq   $63, %rax, %rcx
        leaq    (%rcx,%rdi,2), %rax
        addq    $-1, %rax
        retq
L32:
        xorl    %eax, %eax
        retq
        nopw    %cs:(%rax,%rax)

From my (limited) understanding, Julia seems to recognize the pattern of the summation of a sequence in arithmetic progression, and replaces it by the adequate formula (hence the presence of a mulxq instruction).


In any case, I’m sorry because I think I made this discussion drift too far away from its original topic. To get back on track, here are updated benchmarks for all implementations proposed so far. In order for the timings to be meaningful, this time I used the following loop, which Julia does not seem to be able to optimize:

julia> @btime let s = 0
    for i in 1:$N
        s += i&1
    end
    @assert s == div($N, 2)
end
  61.049 μs (0 allocations: 0 bytes)

and this time it behaves as it should, with run times depending linearly on the number of iterations:

julia> @btime let s = 0
    for i in 1:$(10*N)
        s += i&1
    end
    @assert s == $N * 5
end
  612.104 μs (0 allocations: 0 bytes)

And here is the comparison of all options. Basically, all solutions perform as well as the standard implementation, with the notable exception of Channel- and closure-based implementations. I would thus (personally) consider ResumableFunctions.jl the best candidate here, except maybe if what we want to achieve is available in IterTools and/or Base.Iterators.

Generator            memory [MiB]     allocs   min time [µs]   avg time [µs]   max time [µs]   comment
struct                       0.00          0           60.24           61.23          174.18
IterTools                    0.00          0           61.05           61.90          138.94   as per @tkf’s comment
UnitRange                    0.00          0           61.03           61.96          151.89
ResumableFunctions           0.00          0           61.03           62.64          136.12   optimized, as per @tkoolen’s comment
closure                     45.75    2998470        91216.49        93338.52       156527.15
Channel                     91.53    5997990      1561830.68      1604676.03      1699951.08

If anyone is interested, the script used to generate this table is available here:

2 Likes

Yes, I usually use code_native (but with Intel syntax!). You can also use code_llvm if you prefer that. It is amazing how clever the compiler is. This example is relatively simple; it’s able to turn far more complicated expressions into constant-time formulas. FYI, I’ve annotated the code below with what happens.
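
For reference, Intel syntax can be requested directly, assuming a Julia version where @code_native accepts the keyword:

# AT&T is the default; this switches the operand order/style to Intel.
@code_native syntax=:intel foo(1_000_000)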

testq   %rdi, %rdi          ; test input N
jle     L32                 ; if N <= 0, jump to L32
leaq    -1(%rdi), %rdx      ; rdx = N - 1
leaq    -2(%rdi), %rax      ; rax = N - 2
mulxq   %rax, %rax, %rcx    ; rcx:rax = rdx * rax = (N-1)*(N-2)
shldq   $63, %rax, %rcx     ; rcx = rcx:rax / 2 = (N-1)*(N-2) / 2
leaq    (%rcx,%rdi,2), %rax ; rax = rcx + rdi * 2 = (N-1)*(N-2) / 2 + N*2
addq    $-1, %rax           ; rax = rax - 1 = (N-1)*(N-2) / 2 + N*2 - 1
retq
L32:
xorl    %eax, %eax          ; rax = 0
retq

The return value (N-1)*(N-2) / 2 + N*2 - 1 is the same as N*(N + 1) / 2 (expanding: (N^2 - 3N + 2)/2 + 2N - 1 = (N^2 + N)/2), so it doesn’t figure out the optimal formula, but close enough.

Also note that if N had been a constant (const N), the whole thing would instead be replaced by the constant 500000500000.
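
One way to see that (a sketch, assuming constant propagation kicks in; M is a hypothetical name chosen to avoid clashing with N above):

const M = 1_000_000   # argument now known at compile time
bar() = foo(M)
@code_native bar()    # expected: just materialize 500000500000 and return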

4 Likes

Thanks a lot!