Slow code with Union and Box

How can I improve this? I want to keep the Union{A,B,C,D} but make it faster. JET says c is captured and also there is runtime dispatch.

abstract type T end
struct A <: T x::Int end
struct B <: T x::Int end
struct C <: T x::Int end
struct D <: T x::Int end

f() = rand((A,B,C,D))(1)
function g(n)
    local c::Int = 0
    foreach(1:n) do _
        y = f()::Union{A,B,C,D}
        if y isa A
            c += 1
        elseif y isa B
            c += 2
        elseif y isa C
            c += 3
        elseif y isa D
            c += 4
        else
            error("unreachable")
        end

    end
    c
end
@report_opt g(10)
MethodInstance for g(::Int64)
  from g(n) @ Main 
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #65::var"#65#66"
  c@_4::Core.Box
  @_5::Int64
  c@_6::Union{}
Body::Int64
1 ─       Core.NewvarNode(:(#65))
│         (c@_4 = Core.Box())
│   %3  = c@_4::Core.Box
│         (@_5 = 0)
│   %5  = (@_5::Core.Const(0) isa Main.Int)::Core.Const(true)
└──       goto #3 if not %5
2 ─       goto #4
3 ─       Core.Const(:(Base.convert(Main.Int, @_5)))
└──       Core.Const(:(@_5 = Core.typeassert(%8, Main.Int)))
4 ┄ %10 = @_5::Core.Const(0)
│         Core.setfield!(%3, :contents, %10)
│         (#65 = %new(Main.:(var"#65#66"), c@_4))
│   %13 = #65::var"#65#66"
│   %14 = (1:n)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         Main.foreach(%13, %14)
│   %16 = Core.isdefined(c@_4, :contents)::Bool
└──       goto #6 if not %16
5 ─       goto #7
6 ─       Core.NewvarNode(:(c@_6))
└──       c@_6
7 ┄ %21 = Core.getfield(c@_4, :contents)::Any
│   %22 = Core.typeassert(%21, Main.Int)::Int64
└──       return %22

I feel like there are quite a few ways to make this faster, mostly by rewriting it completely differently.

I’m assuming this is a toy example. Could you share a few more details about the real use case?

Reassigning c, which is captured by the do-block closure, is responsible for the boxing. The annotation does introduce typeasserts to restore type inference immediately afterwards. Wholly hypothetically, I would guess that a typed box would make this type stable. For now, you’d have to avoid closures: you could rewrite the foreach as a for loop, since the do-block closure can’t be used anywhere else. Not sure if that’ll help the runtime dispatch; I can’t tell visually where it is, but the rand((A,B,C,D)) and ::Union{A,B,C,D} are right over the Union-splitting limit.

Replacing foreach with for does cure the boxing. As for types, the red bit is y::Union{A,B,C,D}. Replacing

f() = (rand((A,B,C,D))(1))::Union{A,B,C,D}

with

f() = (rand((A,B,C))(1))::Union{A,B,C}

seems to solve that problem, but if there are 4 things it’s not really a good solution.

abstract type T end
struct A <: T x::Int end
struct B <: T x::Int end
struct C <: T x::Int end
struct D <: T x::Int end

f() = (rand((A,B,C,D))(1))::Union{A,B,C,D}
function g(n)
    local c::Int = 0
    for _ in 1:n
        y = f()::Union{A,B,C,D}
        if y isa A
            c += 1
        elseif y isa B
            c += 2
        elseif y isa C
            c += 3
        elseif y isa D
            c += 4
        else
            error("unreachable")
        end

    end
    c
end
@code_warntype g(10)
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  @_3::Union{Nothing, Tuple{Int64, Int64}}
  c::Int64
  y::Union{A, B, C, D}
  @_6::Int64
  @_7::Int64
  @_8::Int64
  @_9::Int64
  @_10::Int64
Body::Int64
1 ──       Core.NewvarNode(:(@_3))
│          Core.NewvarNode(:(c))
│          (@_6 = 0)
│    %4  = (@_6::Core.Const(0) isa Main.Int)::Core.Const(true)
└───       goto #3 if not %4
2 ──       goto #4
3 ──       Core.Const(:(Base.convert(Main.Int, @_6)))
└───       Core.Const(:(@_6 = Core.typeassert(%7, Main.Int)))
4 ┄─       (c = @_6::Core.Const(0))
│    %10 = (1:n)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_3 = Base.iterate(%10))
│    %12 = (@_3 === nothing)::Bool
│    %13 = Base.not_int(%12)::Bool
└───       goto #28 if not %13
5 ┄─ %15 = @_3::Tuple{Int64, Int64}
│          Core.getfield(%15, 1)
│    %17 = Core.getfield(%15, 2)::Int64
│    %18 = Main.f()::Union{A, B, C, D}
│    %19 = Core.apply_type(Main.Union, Main.A, Main.B, Main.C, Main.D)::Core.Const(Union{A, B, C, D})
│          (y = Core.typeassert(%18, %19))
│    %21 = (y isa Main.A)::Bool
└───       goto #10 if not %21
6 ── %23 = (c + 1)::Int64
│          (@_7 = %23)
│    %25 = (@_7 isa Main.Int)::Core.Const(true)
└───       goto #8 if not %25
7 ──       goto #9
8 ──       Core.Const(:(Base.convert(Main.Int, @_7)))
└───       Core.Const(:(@_7 = Core.typeassert(%28, Main.Int)))
9 ┄─       (c = @_7)
└───       goto #26
10 ─ %32 = (y::Union{B, C, D} isa Main.B)::Bool
└───       goto #15 if not %32
11 ─ %34 = (c + 2)::Int64
│          (@_8 = %34)
│    %36 = (@_8 isa Main.Int)::Core.Const(true)
└───       goto #13 if not %36
12 ─       goto #14
13 ─       Core.Const(:(Base.convert(Main.Int, @_8)))
└───       Core.Const(:(@_8 = Core.typeassert(%39, Main.Int)))
14 ┄       (c = @_8)
└───       goto #26
15 ─ %43 = (y::Union{C, D} isa Main.C)::Bool
└───       goto #20 if not %43
16 ─ %45 = (c + 3)::Int64
│          (@_9 = %45)
│    %47 = (@_9 isa Main.Int)::Core.Const(true)
└───       goto #18 if not %47
17 ─       goto #19
18 ─       Core.Const(:(Base.convert(Main.Int, @_9)))
└───       Core.Const(:(@_9 = Core.typeassert(%50, Main.Int)))
19 ┄       (c = @_9)
└───       goto #26
20 ─ %54 = (y::D isa Main.D)::Core.Const(true)
└───       goto #25 if not %54
21 ─ %56 = (c + 4)::Int64
│          (@_10 = %56)
│    %58 = (@_10 isa Main.Int)::Core.Const(true)
└───       goto #23 if not %58
22 ─       goto #24
23 ─       Core.Const(:(Base.convert(Main.Int, @_10)))
└───       Core.Const(:(@_10 = Core.typeassert(%61, Main.Int)))
24 ┄       (c = @_10)
└───       goto #26
25 ─       Core.Const(:(Main.error("unreachable")))
26 ┄       (@_3 = Base.iterate(%10, %17))
│    %67 = (@_3 === nothing)::Bool
│    %68 = Base.not_int(%67)::Bool
└───       goto #28 if not %68
27 ─       goto #5
28 ┄       return c

It’s because you write c += 1 inside the closure body. The rebinding requires c to be boxed.

You can instead write

function g(n)
    local c = Ref{Int}(0)
    foreach(1:n) do _
        y = f()::Union{A,B,C,D}
        if y isa A
            c[] += 1
        elseif y isa B
            c[] += 2
        elseif y isa C
            c[] += 3
        elseif y isa D
            c[] += 4
        else
            error("unreachable")
        end

    end
    c[]  # return the Int stored in the Ref, not the Ref itself
end

That way there’s no rebinding.
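
For anyone who wants to see the difference directly, here is a minimal sketch that inspects what each kind of closure actually stores. The names make_boxed_counter and make_ref_counter are made up for illustration and are not from the code above; only Base is assumed.

```julia
# Rebinding a captured variable (c += 1) forces the closure to hold it in a Core.Box.
function make_boxed_counter()
    c = 0
    increment = () -> (c += 1)   # rebinds c inside the closure
    return increment
end

# Mutating the contents of a Ref never rebinds c, so the captured field stays concretely typed.
function make_ref_counter()
    c = Ref(0)
    increment = () -> (c[] += 1) # mutates the Ref, c itself is never reassigned
    return increment
end

# Closures are structs under the hood; fieldtypes shows how the capture is stored.
boxed_fields = fieldtypes(typeof(make_boxed_counter()))
ref_fields = fieldtypes(typeof(make_ref_counter()))

println("rebinding capture: ", boxed_fields)
println("Ref capture:       ", ref_fields)
```

The first closure should report a Core.Box field and the second a Base.RefValue{Int64} field, which lines up with the Locals sections printed in this thread.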

Putting c in a Ref removes the y problem, but I don’t understand why.

abstract type T end
struct A <: T x::Int end
struct B <: T x::Int end
struct C <: T x::Int end
struct D <: T x::Int end

f() = (rand((A,B,C,D))(1))::Union{A,B,C,D}
function g(n)
    local c = Ref{Int}(0)
    foreach(1:n) do _
        y = f()::Union{A,B,C,D}
        if y isa A
            c[] += 1
        elseif y isa B
            c[] += 2
        elseif y isa C
            c[] += 3
        elseif y isa D
            c[] += 4
        else
            error("unreachable")
        end
    end
    c
end
@code_warntype g(10)
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #91::var"#91#92"{Base.RefValue{Int64}}
  c::Base.RefValue{Int64}
Body::Base.RefValue{Int64}
1 ─ %1 = Core.apply_type(Main.Ref, Main.Int)::Core.Const(Ref{Int64})
│        (c = (%1)(0))
│   %3 = Main.:(var"#91#92")::Core.Const(var"#91#92")
│   %4 = Core.typeof(c)::Core.Const(Base.RefValue{Int64})
│   %5 = Core.apply_type(%3, %4)::Core.Const(var"#91#92"{Base.RefValue{Int64}})
│        (#91 = %new(%5, c))
│   %7 = #91::var"#91#92"{Base.RefValue{Int64}}
│   %8 = (1:n)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│        Main.foreach(%7, %8)
└──      return c

function h(n)
    local c = 0
    for _ in 1:n
        y = f()::Union{A,B,C,D}
        if y isa A
            c += 1
        elseif y isa B
            c += 2
        elseif y isa C
            c += 3
        elseif y isa D
            c += 4
        else
            error("unreachable")
        end
    end
    c
end
@code_warntype h(10)
Arguments
  #self#::Core.Const(h)
  n::Int64
Locals
  @_3::Union{Nothing, Tuple{Int64, Int64}}
  c::Int64
  y::Union{A, B, C, D}
Body::Int64
1 ──       (c = 0)
│    %2  = (1:n)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_3 = Base.iterate(%2))
│    %4  = (@_3 === nothing)::Bool
│    %5  = Base.not_int(%4)::Bool
└───       goto #13 if not %5
2 ┄─ %7  = @_3::Tuple{Int64, Int64}
│          Core.getfield(%7, 1)
│    %9  = Core.getfield(%7, 2)::Int64
│    %10 = Main.f()::Union{A, B, C, D}
│    %11 = Core.apply_type(Main.Union, Main.A, Main.B, Main.C, Main.D)::Core.Const(Union{A, B, C, D})
│          (y = Core.typeassert(%10, %11))
│    %13 = (y isa Main.A)::Bool
└───       goto #4 if not %13
3 ──       (c = c + 1)
└───       goto #11
4 ── %17 = (y::Union{B, C, D} isa Main.B)::Bool
└───       goto #6 if not %17
5 ──       (c = c + 2)
└───       goto #11
6 ── %21 = (y::Union{C, D} isa Main.C)::Bool
└───       goto #8 if not %21
7 ──       (c = c + 3)
└───       goto #11
8 ── %25 = (y::D isa Main.D)::Core.Const(true)
└───       goto #10 if not %25
9 ──       (c = c + 4)
└───       goto #11
10 ─       Core.Const(:(Main.error("unreachable")))
11 ┄       (@_3 = Base.iterate(%2, %9))
│    %31 = (@_3 === nothing)::Bool
│    %32 = Base.not_int(%31)::Bool
└───       goto #13 if not %32
12 ─       goto #2
13 ┄       return c

The red there isn’t a problem. So long as the union is fully split, the compiler knows what’s going on and there’s no performance penalty or dynamism.

hmm, we (as a community) have been here: Allocation and slow down when # of types involved increase (slower than C++ virtual methods) - #58 by peremato

This is almost 20x faster than the code posted by @Mason, but it cheats :smiley:

using MixedStructTypes

abstract type T end

@compact_structs S <: T begin
    struct A x::Int end
    struct B x::Int end
    struct C x::Int end
    struct D x::Int end
end

function h(t)
    if t === :A
        A(1)
    elseif t === :B
        B(1)
    elseif t === :C
        C(1)
    elseif t === :D
        D(1)
    else
        error("unreachable")
    end
end

f() = h(rand((:A,:B,:C,:D)))

function g(n)
    c = 0
    for _ in 1:n
        y = f()
        if kindof(y) === :A
            c += 1
        elseif kindof(y) === :B
            c += 2
        elseif kindof(y) === :C
            c += 3
        elseif kindof(y) === :D
            c += 4
        else
            error("unreachable")
        end
    end
    c
end

the bottleneck in all the other implementations is

f() = (rand((A,B,C,D))(1))

but rewriting it as above doesn’t work when you have a Union

Because c isn’t being reassigned anymore; you’re just mutating the assigned Ref instance.

I think the red hints at the slow runtime dispatch despite inferring the union; in other words, the compiler naming the types in the union does not imply the Union-splitting optimization.

That seems right. Fundamentally, rand((A,B,C,D)) is selecting a type at runtime, so using the result expression in a call must be runtime dispatch. There are too many types to involve the fast Union-split runtime dispatch, not that it would make f’s return type stable. Moreover, reducing the number of types doesn’t really help because rand((A,B)) is still inferred as DataType, not a union, so the instantiation is inferred as Any. The ::Union{A,B,C,D} in g just covers it up with a typeassert.
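
One way to check this, as a rough sketch: ask inference for the return type of f with and without the typeassert. The names f_plain and f_asserted are hypothetical, and Base.return_types is a reflection helper whose answer for the unasserted version may differ between Julia versions, so I am not claiming a specific result for that one.

```julia
abstract type T end
struct A <: T; x::Int; end
struct B <: T; x::Int; end
struct C <: T; x::Int; end
struct D <: T; x::Int; end

# The type is picked at runtime, so the constructor call is a dynamic call.
f_plain() = rand((A, B, C, D))(1)

# Same body, but the typeassert narrows whatever inference found to the union.
f_asserted() = (rand((A, B, C, D))(1))::Union{A,B,C,D}

# Inferred return types for the zero-argument methods.
plain_rt = first(Base.return_types(f_plain, Tuple{}))
asserted_rt = first(Base.return_types(f_asserted, Tuple{}))

println("without typeassert: ", plain_rt)
println("with typeassert:    ", asserted_rt)
```

If the first line prints something wider than the union, then the annotation is only narrowing the result after the fact; the call inside f is still resolved at runtime.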

Whoops, I was wrong. This also seems fine; it is only a tiny bit slower than the other implementation I proposed:

abstract type T end

struct A <: T x::Int end
struct B <: T x::Int end
struct C <: T x::Int end
struct D <: T x::Int end

function h(t)
    if t === :A
        A(1)
    elseif t === :B
        B(1)
    elseif t === :C
        C(1)
    elseif t === :D
        D(1)
    else
        error("unreachable")
    end
end

f() = h(rand((:A,:B,:C,:D)))

function g(n)
    c = 0
    for _ in 1:n
        y = f()
        if y isa A
            c += 1
        elseif y isa B
            c += 2
        elseif y isa C
            c += 3
        elseif y isa D
            c += 4
        else
            error("unreachable")
        end
    end
    c
end

despite what @code_warntype says

Unfortunately, the heuristics for when to color stuff red in @code_warntype are quite simple and do not reflect what the compiler is actually able to do with the type.

Besides replacing foreach with for to remove the box, the solution here is to change f to use if-else statements so that it infers properly as Union{A, B, C, D} instead of the current typejoin(A, B, C, D).

By the way, I’d also like to point out that foreach is kinda the “Wrong” function here, since what you’re really doing is a reduction. A more structured version of g would be:

g(n) = foldl(1:n, init=0) do c, _
    y = f()::Union{A,B,C,D}
    if y isa A
        c += 1
    elseif y isa B
        c += 2
    elseif y isa C
        c += 3
    elseif y isa D
        c += 4
    else
        error("unreachable")
    end
end

This doesn’t suffer from any closure-boxing problems because it’s a proper reduction with actual data locality, which helps make sure that e.g. this function could be parallelized without a race condition. The fully parallel-friendly version would look like

g(n) = mapreduce(+, 1:n, init=0) do _
    y = f()::Union{A,B,C,D}
    if y isa A
        1
    elseif y isa B
        2
    elseif y isa C
        3
    elseif y isa D
        4
    else
        error("unreachable")
    end
end

In general, any for loop with a pattern like

state = initial_state
for elem in itr
    state = some_function(state, f(elem))
end
state

can be re-written as

mapfoldl(f, some_function, itr; init=initial_state)

and if some_function is associative, then this should be parallelizable by just replacing mapfoldl with mapreduce.
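
To make the pattern concrete, here is a small worked example, assuming a made-up mapping function square and + as the combining function (neither comes from the code above). It only shows that the three spellings agree on the result, not that Base's mapreduce runs in parallel.

```julia
# Hypothetical stand-in for f(elem) in the pattern above.
square(x) = x^2

# The explicit loop: state = some_function(state, f(elem)).
function loop_sum(itr)
    state = 0
    for elem in itr
        state = state + square(elem)
    end
    return state
end

itr = 1:10
loop_result = loop_sum(itr)
foldl_result = mapfoldl(square, +, itr; init=0)    # strict left-to-right fold
reduce_result = mapreduce(square, +, itr; init=0)  # order unspecified, fine since + is associative

println((loop_result, foldl_result, reduce_result))
```

All three give 385 for 1:10, the sum of the first ten squares.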

How can you demonstrate that? With the original f I get Union{A,B,C,D} in code_warntype but JET complains

f() = (rand((A,B,C,D))(1))::Union{A,B,C,D}
function k(n)
    c = 0
    for _ in 1:n
        y = f()
        c += if y isa A
            1
        elseif y isa B
            2
        elseif y isa C
            3
        elseif y isa D
            4
        else
            error("unreachable")
        end
    end
    c
end
@code_warntype k(10)
Arguments
  #self#::Core.Const(k)
  n::Int64
Locals
  @_3::Union{Nothing, Tuple{Int64, Int64}}
  c::Int64
  y::Union{A, B, C, D}
  @_6::Int64
  @_7::Int64
  @_8::Int64
  @_9::Int64
Body::Int64
1 ──       (c = 0)
│    %2  = (1:n)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_3 = Base.iterate(%2))
│    %4  = (@_3 === nothing)::Bool
│    %5  = Base.not_int(%4)::Bool
└───       goto #16 if not %5
2 ┄─ %7  = @_3::Tuple{Int64, Int64}
│          Core.getfield(%7, 1)
│    %9  = Core.getfield(%7, 2)::Int64
│          (y = Main.f())
│    %11 = c::Int64
│    %12 = (y isa Main.A)::Bool
└───       goto #4 if not %12
3 ──       (@_6 = 1)
└───       goto #14
4 ── %16 = (y::Union{B, C, D} isa Main.B)::Bool
└───       goto #6 if not %16
5 ──       (@_7 = 2)
└───       goto #13
6 ── %20 = (y::Union{C, D} isa Main.C)::Bool
└───       goto #8 if not %20
7 ──       (@_8 = 3)
└───       goto #12
8 ── %24 = (y::D isa Main.D)::Core.Const(true)
└───       goto #10 if not %24
9 ──       (@_9 = 4)
└───       goto #11
10 ─       Core.Const(:(@_9 = Main.error("unreachable")))
11 ┄       (@_8 = @_9::Core.Const(4))
12 ┄       (@_7 = @_8)
13 ┄       (@_6 = @_7)
14 ┄ %32 = @_6::Int64
│          (c = %11 + %32)
│          (@_3 = Base.iterate(%2, %9))
│    %35 = (@_3 === nothing)::Bool
│    %36 = Base.not_int(%35)::Bool
└───       goto #16 if not %36
15 ─       goto #2
16 ┄       return c

If I instead use

f_if() = let T = rand((A,B,C,D))
    T === A ? A(1) :
    T === B ? B(1) :
    T === C ? C(1) :
    T === D ? D(1) :
    error("unreachable")
end

then I get y::Any from code_warntype and JET doesn’t complain.

Is code_warntype just giving the wrong answer?

To me this seems as simple as benchmarking the different approaches and understanding what is going on under the hood, as it is already well explained by @benny, @jakobnissen and @Mason.
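
A minimal sketch of such a comparison, using only Base's @elapsed so it runs without extra packages (BenchmarkTools.jl's @btime would give far more reliable numbers). The names g_boxed, g_loop and score are made up for this example, and no particular timing is claimed.

```julia
abstract type T end
struct A <: T; x::Int; end
struct B <: T; x::Int; end
struct C <: T; x::Int; end
struct D <: T; x::Int; end

f() = (rand((A, B, C, D))(1))::Union{A,B,C,D}

# Hypothetical helper: the same isa chain as in the thread, written as one expression.
score(y) = y isa A ? 1 : y isa B ? 2 : y isa C ? 3 : 4

# Closure version: c is rebound inside the do block, so it gets boxed.
function g_boxed(n)
    local c::Int = 0
    foreach(1:n) do _
        c += score(f())
    end
    return c
end

# Plain loop version: no closure, so no box.
function g_loop(n)
    c = 0
    for _ in 1:n
        c += score(f())
    end
    return c
end

g_boxed(10); g_loop(10)        # warm up so compilation is not included in the timings
n = 10^5
t_boxed = @elapsed g_boxed(n)
t_loop = @elapsed g_loop(n)
println("foreach + box: ", t_boxed, " s")
println("for loop:      ", t_loop, " s")
```

Single @elapsed runs are noisy, so treat the numbers as a rough indication only and repeat the measurement a few times before drawing conclusions.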

Anyway, yes, @code_warntype is lying

the problem @jakobnissen is referring to is that (A, B, C, D)::Tuple{DataType, DataType, DataType, DataType}.

If instead you did something like

julia> f() = let n = rand(1:4)
           if n == 1
               A(1)
           elseif n == 2
               B(1)
           elseif n == 3
               C(1)
           elseif n == 4
               D(1)
           else
               error("unreachable")
           end
       end;

then the compiler is better able to reason about what’s going on in f.


Edit: actually, it still seems confused on my machine.

@code_warntype says Any and @report_opt says No problems. Is @code_warntype just untrustworthy?

Yes, but this link refers to the various code_ macros assuming concrete type specialization on top-level arguments that wouldn’t actually be specialized.

Okay, that makes me think this wasn’t optimized after all.

f_if doesn’t have the ::Union{A,B,C,D} covering up the type inference anymore. I wouldn’t expect this, but check @report_opt f() and @report_opt f_if() just in case JET gave up on inferring it and didn’t flag that. The Limitations section of the README says “However, if the argument types for a call cannot be inferred, JET does not analyze the callee. Consequently, a report of No errors detected does not imply that your entire codebase is free of errors. To increase the confidence in JET’s results use @report_opt to make sure your code is inferrible”, but I was never really sure what that says about when dynamic dispatch is flagged.