Loop unrolling for type stability

I (naively) thought that the compiler would do some basic loop unrolling. This is particularly important for generating type-stable code. Is there an easy way to make the following code snippet type stable?

abstract type A end
struct Ax <: A end
struct Ay <: A end
struct Az <: A end

bar(a::Ax) = println("Ax")
bar(a::Ay) = println("Ay")
bar(a::Az) = println("Az")

function foo()
    for a in (Ax(), Ay(), Az())
        bar(a)
    end
end

foo()

The @code_warntype output is:

julia> @code_warntype foo()
MethodInstance for foo()
  from foo() in Main at <filetree>
Arguments
  #self#::Core.Const(foo)
Locals
  @_2::Union{Nothing, Tuple{Union{Ax, Ay, Az}, Int64}}
  a::Union{Ax, Ay, Az}
Body::Nothing
1 ─ %1  = Main.Ax()::Core.Const(Ax())
│   %2  = Main.Ay()::Core.Const(Ay())
│   %3  = Main.Az()::Core.Const(Az())
│   %4  = Core.tuple(%1, %2, %3)::Core.Const((Ax(), Ay(), Az()))
│         (@_2 = Base.iterate(%4))
│   %6  = (@_2::Core.Const((Ax(), 2)) === nothing)::Core.Const(false)
│   %7  = Base.not_int(%6)::Core.Const(true)
└──       goto #4 if not %7
2 ┄ %9  = @_2::Tuple{Union{Ax, Ay, Az}, Int64}
│         (a = Core.getfield(%9, 1))
│   %11 = Core.getfield(%9, 2)::Int64
│         Main.bar(a)
│         (@_2 = Base.iterate(%4, %11))
│   %14 = (@_2 === nothing)::Bool
│   %15 = Base.not_int(%14)::Bool
└──       goto #4 if not %15
3 ─       goto #2
4 ┄       return nothing

where a::Union{Ax, Ay, Az} and Tuple{Union{Ax, Ay, Az}, Int64} are highlighted in red (not type stable).

Obviously, I could manually unroll the loop…

function foo_manual()
    bar(Ax())
    bar(Ay())
    bar(Az())
end

…which is indeed type stable. But this doesn't scale to larger problems (plus it doesn't seem very “Julian”).

Thanks in advance!
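One common Julia idiom that avoids writing out each call by hand is a recursive helper built on Base.tail: it handles the first element and recurses on the rest, so each call sees a concrete, shorter tuple type. This technique is not one proposed in this thread; it is a sketch under the assumption that the tuple is short. The names unrolled_each, label and foo_recursive are hypothetical, and label returns a string instead of printing so the result can be checked.

```julia
abstract type A end
struct Ax <: A end
struct Ay <: A end
struct Az <: A end

# Hypothetical stand-in for bar that returns a value instead of printing
label(::Ax) = "Ax"
label(::Ay) = "Ay"
label(::Az) = "Az"

# Base case: no elements left to visit
unrolled_each(f, ::Tuple{}) = nothing

# Recursive case: apply f to the first element, then recurse on the rest.
# For a short concrete tuple type, first(t) has a concrete type at every
# level, so each f call is dispatched on a concrete argument type.
function unrolled_each(f, t::Tuple)
    f(first(t))
    unrolled_each(f, Base.tail(t))
end

function foo_recursive()
    seen = String[]
    unrolled_each(a -> push!(seen, label(a)), (Ax(), Ay(), Az()))
    return seen
end
```

Calling foo_recursive() collects the labels in tuple order. For very long tuples, the compiler's recursion limits may stop it from specializing every level.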

This package does the “manual” unrolling for you: Unrolled.jl

map with Tuple arguments tends to be the easiest way.
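A minimal sketch of the map suggestion, using Test.@inferred to check that the compiler infers a concrete result. map over a Tuple returns a Tuple, and for short tuples Base resolves each element call separately (it falls back to a generic path for very long tuples). The function value is a hypothetical stand-in for bar that returns a different type for each struct, which makes the inference check meaningful; values_of_all is also an illustrative name.

```julia
using Test

abstract type A end
struct Ax <: A end
struct Ay <: A end
struct Az <: A end

# Hypothetical: each method returns a different concrete type
value(::Ax) = 1
value(::Ay) = 2.0
value(::Az) = "three"

# map over a Tuple returns a Tuple with one f call per element
values_of_all() = map(value, (Ax(), Ay(), Az()))

# @inferred throws if the inferred return type differs from the runtime type
result = @inferred values_of_all()
```

If the same work were done with a for loop that collects into a container, the element types would be merged into a Union or an abstract type. With map, each position keeps its own type.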



Thanks! I thought the same, but it doesn’t appear to work with this example:

using Unrolled

abstract type A end
struct Ax <: A end
struct Ay <: A end
struct Az <: A end

bar(a::Ax) = println("Ax")
bar(a::Ay) = println("Ay")
bar(a::Az) = println("Az")

@unroll function foo()
    for a in (Ax(), Ay(), Az())
        bar(a)
    end
end
julia> @code_warntype foo()
MethodInstance for foo()
  from foo() in Main at <>
Arguments
  #self#::Core.Const(foo)
Locals
  @_2::Union{Nothing, Tuple{Union{Ax, Ay, Az}, Int64}}
  a::Union{Ax, Ay, Az}
Body::Nothing
1 ─ %1  = Main.Ax()::Core.Const(Ax())
│   %2  = Main.Ay()::Core.Const(Ay())
│   %3  = Main.Az()::Core.Const(Az())
│   %4  = Core.tuple(%1, %2, %3)::Core.Const((Ax(), Ay(), Az()))
│         (@_2 = Base.iterate(%4))
│   %6  = (@_2::Core.Const((Ax(), 2)) === nothing)::Core.Const(false)
│   %7  = Base.not_int(%6)::Core.Const(true)
└──       goto #4 if not %7
2 ┄ %9  = @_2::Tuple{Union{Ax, Ay, Az}, Int64}
│         (a = Core.getfield(%9, 1))
│   %11 = Core.getfield(%9, 2)::Int64
│         Main.bar(a)
│         (@_2 = Base.iterate(%4, %11))
│   %14 = (@_2 === nothing)::Bool
│   %15 = Base.not_int(%14)::Bool
└──       goto #4 if not %15
3 ─       goto #2
4 ┄       return nothing

Same red highlighting as before.

This works:

julia> @unroll function foo(seq)
           @unroll for a in seq
               bar(a)
           end
       end
foo_unrolled_expansion_##314 (generic function with 1 method)

julia> @code_warntype foo((Ax(), Ay(), Az()))
MethodInstance for foo(::Tuple{Ax, Ay, Az})
  from foo(seq) in Main at /home/jishnu/.julia/packages/Unrolled/nMVH3/src/Unrolled.jl:127
Arguments
  #self#::Core.Const(foo)
  seq::Core.Const((Ax(), Ay(), Az()))
Locals
  a@_3::Ax
  a@_4::Ay
  a@_5::Az
Body::Nothing
1 ─ %1  = Base.getindex(seq, 1)::Core.Const(Ax())
│         (a@_3 = %1)
│         Main.bar(a@_3)
│   %4  = Base.getindex(seq, 2)::Core.Const(Ay())
│         (a@_4 = %4)
│         Main.bar(a@_4)
│   %7  = Base.getindex(seq, 3)::Core.Const(Az())
│         (a@_5 = %7)
│         Main.bar(a@_5)
│   %10 = Main.nothing::Core.Const(nothing)
└──       return %10


Aha! The map approach works!

abstract type A end
struct Ax <: A end
struct Ay <: A end
struct Az <: A end

bar(a::Ax) = println("Ax")
bar(a::Ay) = println("Ay")
bar(a::Az) = println("Az")

function foo()
    map(bar, (Ax(), Ay(), Az()))  # map over a Tuple calls bar once per element, each with a concrete argument type
    return
end

foo()
julia> @code_warntype foo()
MethodInstance for foo()
  from foo() in <>
Arguments
  #self#::Core.Const(foo)
Body::Nothing
1 ─ %1 = Main.Ax()::Core.Const(Ax())
│   %2 = Main.Ay()::Core.Const(Ay())
│   %3 = Main.Az()::Core.Const(Az())
│   %4 = Core.tuple(%1, %2, %3)::Core.Const((Ax(), Ay(), Az()))
│        Main.map(Main.bar, %4)
└──      return nothing


Ah, I see. The trick is to pass the sequence as a function argument. I didn't catch that nuance at first.

Thanks!
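Passing the tuple as an argument matters because its type, Tuple{Ax, Ay, Az}, carries both the length and every element type, and that type is what a code-generating macro or generated function can see at compile time. Below is a simplified sketch of the idea using a @generated function. It is not Unrolled.jl's actual code, and unroll_apply, label and foo_generated are hypothetical names.

```julia
abstract type A end
struct Ax <: A end
struct Ay <: A end
struct Az <: A end

# Hypothetical stand-in for bar that returns a value instead of printing
label(::Ax) = "Ax"
label(::Ay) = "Ay"
label(::Az) = "Az"

# Inside the generator, `seq` is bound to the argument's *type*
# (e.g. Tuple{Ax, Ay, Az}), so its length is known at generation time.
@generated function unroll_apply(f, seq::Tuple)
    n = fieldcount(seq)
    # Emit one straight-line call per position, each with a literal index
    calls = [:(f(seq[$i])) for i in 1:n]
    return Expr(:block, calls..., nothing)
end

function foo_generated()
    seen = String[]
    unroll_apply(a -> push!(seen, label(a)), (Ax(), Ay(), Az()))
    return seen
end
```

For a three-element tuple the generated body is just f(seq[1]); f(seq[2]); f(seq[3]), with no loop left for the compiler to widen. If the tuple were a literal inside the function body instead of an argument, nothing about its type would be available to the generator.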