Access a tuple of structs using the values of an integer array

I have a tuple of heterogeneous structures:

struct Foo{T}
    a :: T
    b :: T
end

foos = (Foo(1,2), Foo(3.0,4.0))

and I have a 3D grid that is divided into regions, such that in a given region I want to use a specific structure from the tuple. To do so, I create a 3D array of integers holding the index of the structure to use at each grid point:

I = zeros(Int, (100,200,300))
@. I[1:50,:,:] = 1
@. I[51:100,:,:] = 2

so that at grid point (i, j, k) I can access the desired structure as foos[I[i,j,k]]. The MWE is the following:

function bar1(foos, I)
    @inbounds for i in eachindex(I)
        foo = foos[I[i]]
    end
    return nothing
end

function bar2(foos, I)
    @inbounds for i in eachindex(I)
        foo = I[i] == 1 ? foos[1] : foos[2]
    end
    return nothing
end

However, the bar1 function is much slower than bar2:

@btime bar1($foos, $I)   # 145.988 ms (12000000 allocations: 457.76 MiB)
@btime bar2($foos, $I)   # 1.497 ns (0 allocations: 0 bytes)

How can I fix the first function so that the code is both fast and generic?

The full code
using BenchmarkTools

struct Foo{T}
    a :: T
    b :: T
end

foos = (Foo(1,2), Foo(3.0,4.0))

I = zeros(Int, (100,200,300))
@. I[1:50,:,:] = 1
@. I[51:100,:,:] = 2

function bar1(foos, I)
    @inbounds for i in eachindex(I)
        foo = foos[I[i]]
    end
    return nothing
end

function bar2(foos, I)
    @inbounds for i in eachindex(I)
        foo = I[i] == 1 ? foos[1] : foos[2]
    end
    return nothing
end

@btime bar1($foos, $I)
@btime bar2($foos, $I)

I think the compiler may have removed the whole loop in the “fast” one. The compiler is smart enough to optimize an if/else where you spell out both values you expect, and since foo is never used, the loop body does nothing. In the other function, the integer you index with could be anything, and the compiler apparently does not work out that only the values 1 and 2 can index into the tuple, so it cannot transform the code the same way. That loop is type unstable, so it cannot be compiled away.

By the way, if you did anything with your foo value in the fast one, it would probably also be much slower, because that value is not type stable either. Depending on the code you run with foo, this might still be handled by small-union splitting.
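
The timings themselves point the same way. A rough check using only the numbers reported in the question (the variable names are illustrative):

```julia
# Numbers taken from the benchmarks in the question
n_points    = 100 * 200 * 300   # loop iterations in bar1 and bar2
t_bar2      = 1.497e-9          # reported time for bar2, in seconds
allocs_bar1 = 12_000_000        # reported allocations for bar1

# If bar2 really ran the loop, each iteration would take ~2.5e-16 s,
# far less than one clock cycle (~3e-10 s at ~3 GHz), so the loop is
# almost certainly optimized away.
t_per_iter = t_bar2 / n_points

# bar1 allocates about twice per iteration on average, which suggests
# that each foos[I[i]] goes through a dynamic (boxing) code path.
allocs_per_iter = allocs_bar1 ÷ n_points
```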

Does this mean I cannot do anything in this situation?

By the way, why is not knowing the contents of I a problem? Shouldn’t the @inbounds macro switch off all the checks?

Also, @code_warntype does not seem to show any problems:

@code_warntype bar1(foos, I)
MethodInstance for bar1(::Tuple{Foo{Int64}, Foo{Float64}}, ::Array{Int64, 3})
  from bar1(foos, I) @ Main /media/storage/projects/julia/dev/Maxwell/examples/test.jl:15
Arguments
  #self#::Core.Const(bar1)
  foos::Tuple{Foo{Int64}, Foo{Float64}}
  I::Array{Int64, 3}
Locals
  @_4::Union{Nothing, Tuple{Int64, Int64}}
  val::Nothing
  i::Int64
  foo::Union{Foo{Float64}, Foo{Int64}}
Body::Nothing
1 ─       Core.NewvarNode(:(val))
│         nothing
│   %3  = Main.eachindex(I)::Base.OneTo{Int64}
│         (@_4 = Base.iterate(%3))
│   %5  = (@_4 === nothing)::Bool
│   %6  = Base.not_int(%5)::Bool
└──       goto #4 if not %6
2 ┄ %8  = @_4::Tuple{Int64, Int64}
│         (i = Core.getfield(%8, 1))
│   %10 = Core.getfield(%8, 2)::Int64
│   %11 = Base.getindex(I, i)::Int64
│         (foo = Base.getindex(foos, %11))
│         (@_4 = Base.iterate(%3, %10))
│   %14 = (@_4 === nothing)::Bool
│   %15 = Base.not_int(%14)::Bool
└──       goto #4 if not %15
3 ─       goto #2
4 ┄       (val = nothing)
│         nothing
│         val
└──       return Main.nothing

Also, I checked that if I use foo in the loop (e.g. by adding something like c = sin(foo.a) * exp(foo.b)), the ratio of the timings stays the same.

Does anyone have any ideas what I can do here?

The @inbounds macro only removes the checks that the indices are within the bounds of the arrays. I don’t think it helps with type stability or other optimizations.
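
One way to confirm that @inbounds leaves type inference untouched is to ask the compiler for the inferred return types. A sketch assuming the Foo and foos definitions from the question; get_dynamic, get_inbounds, get_ternary, get_const and rt are hypothetical helpers, and Base.return_types is a reflection utility whose exact output can differ between Julia versions.

```julia
struct Foo{T}
    a::T
    b::T
end

foos = (Foo(1, 2), Foo(3.0, 4.0))

get_dynamic(t, k)  = t[k]                   # index known only at run time
get_inbounds(t, k) = @inbounds t[k]         # same, with bounds checks off
get_ternary(t, k)  = k == 1 ? t[1] : t[2]   # the selection used in bar2
get_const(t)       = t[2]                   # index is a compile-time constant

# Inferred return type of f for the given argument types
rt(f, types) = only(Base.return_types(f, types))

rt_dynamic  = rt(get_dynamic,  (typeof(foos), Int))
rt_inbounds = rt(get_inbounds, (typeof(foos), Int))
rt_ternary  = rt(get_ternary,  (typeof(foos), Int))
rt_const    = rt(get_const,    (typeof(foos),))
```

I’d expect the first three to come out as a small Union of Foo{Int64} and Foo{Float64}, matching the foo::Union{...} line in the @code_warntype output above, with only the constant-index version being concrete. So bar2’s ternary is no more type stable than bar1’s indexing; it only looked fast because its result was unused.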

The first thing to do is to make bar1 and bar2 comparable since, as @jules said, your bar2 is not actually doing anything. Here are versions that do real work:

function bar1(foos, I)
    s = 0.0
    @inbounds for i in eachindex(I)
        s += foos[I[i]].a
    end
    return s
end

function bar2(foos, I)
    s = 0.0
    @inbounds for i in eachindex(I)
        s += I[i] == 1 ? foos[1].a : foos[2].a
    end
    return s
end

julia> @btime bar1($foos, $I)
  142.794 ms (12000000 allocations: 457.76 MiB)

julia> @btime bar2($foos, $I)
  6.462 ms (0 allocations: 0 bytes)

Now, to improve the performance of bar1, one thing I can think of is to move the indices into the type domain:

MyVals = Union{Val{0},Val{1},Val{2}}
Ib = MyVals[Val(0) for i in 1:100, j in 1:200, k in 1:300]
@. Ib[1:50,:,:] = Val(1);
@. Ib[51:100,:,:] = Val(2);

getfoo(foos, ::Val{N}) where N = foos[N]

function bar1b(foos, I)
    s = 0.0
    @inbounds for i in eachindex(I)
        s += getfoo(foos, I[i]).a
    end
    return s
end

julia> @btime bar1b($foos, $Ib)
  5.599 ms (0 allocations: 0 bytes)

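I guess this helps the compiler notice that it can hardcode the possible cases for foos[N]: the element type of Ib is a small Union of Val types, so the compiler can split the loop body into one concretely typed branch per case.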
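
A quick way to see why the Val version helps: for each Val{N}, inference can resolve foos[N] to a single concrete type, so every branch produced by splitting on the element type of Ib is type stable. A sketch reusing getfoo from the post above; rt_val is a hypothetical helper built on the Base.return_types reflection utility.

```julia
struct Foo{T}
    a::T
    b::T
end

foos = (Foo(1, 2), Foo(3.0, 4.0))

getfoo(foos, ::Val{N}) where N = foos[N]

# Inferred return type of getfoo when the index arrives as the type V
rt_val(V) = only(Base.return_types(getfoo, (typeof(foos), V)))

rt1 = rt_val(Val{1})   # N = 1 is a compile-time constant here
rt2 = rt_val(Val{2})
```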

Thank you, but this approach is not generic enough. You have to manually redefine MyVals every time the length of foos changes. Moreover, the same speed can be achieved with

getfoo(foos, N) = N == 1 ? foos[1] : foos[2]

without any Val magic.

I am thinking about a more general approach, where the length of foos and the contents of I are not known at compile time. Is it possible to do something in this case?

I guess a possible solution is to use a @generated function to do the branching. In that case, I think your function could retrieve the length of foos (which is part of the tuple’s type) and specialize on it, even if it is not fixed in advance. I don’t have the code for that, though.
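
A different technique that is sometimes used for this kind of dispatch, and that is not proposed in this thread, is to recurse over the tuple with Base.tail so that each level sees one concretely typed element. Treat it as an untested sketch: apply_at and bar3 are hypothetical names. For short tuples I’d expect the compiler to unroll it into a chain of comparisons similar to a generated version, but I haven’t benchmarked it, and very long tuples may hit inference limits.

```julia
struct Foo{T}
    a::T
    b::T
end

# Peel elements off the tuple with Base.tail. At every level `first(t)` has a
# concrete type, so `f` is always called on a concretely typed value; the
# run-time index `k` only decides how many levels to descend.
@inline function apply_at(f, t::Tuple, k::Integer)
    k == 1 && return f(first(t))
    return apply_at(f, Base.tail(t), k - 1)
end

# Ran out of elements: k was < 1 or larger than the original tuple length
apply_at(f, ::Tuple{}, k::Integer) = throw(BoundsError())

function bar3(foos, I)
    s = 0.0
    @inbounds for i in eachindex(I)
        s += apply_at(foo -> foo.a, foos, I[i])
    end
    return s
end

foos = (Foo(1, 2), Foo(3.0, 4.0))
```

The design choice here is to pass the work in as a function f rather than returning the element, so that the caller never holds a Union-typed value.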

@fedoroff, indeed I assumed you had a concrete problem with a known tuple size. You can still make it work for unknown or dynamic tuple sizes as long as there’s a reasonable maximum size. For example, the following performs just as well for me:

MyVals = Union{(Val{i} for i in 0:20)...}
Ib = MyVals[Val(0) for i in 1:100, j in 1:200, k in 1:300];
@. Ib[1:50,:,:] = Val(1);
@. Ib[51:100,:,:] = Val(2);

julia> @btime bar1b($foos, $Ib)
  5.203 ms (0 allocations: 0 bytes)

This should work for any tuple of length ≤ 20. Of course, increasing the maximum size further will produce annoying compile times at some point…

(If you just don’t like having to change MyVals when you change your particular tuple size, you can also do something like MyVals = Union{(Val{i} for i in 0:length(foos))...}.)
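
If the integer array I already exists, the Val array can be built from it and from length(foos) instead of being filled by hand. A minimal sketch on a smaller grid; the dimensions are arbitrary.

```julia
struct Foo{T}
    a::T
    b::T
end

foos = (Foo(1, 2), Foo(3.0, 4.0))

# Integer region map, as in the question but smaller
I = zeros(Int, (4, 2, 3))
@. I[1:2, :, :] = 1
@. I[3:4, :, :] = 2

# Small Union of Val types sized from the tuple (Val{0} covers unassigned cells)
MyVals = Union{(Val{i} for i in 0:length(foos))...}

# Typed comprehension: keeps the shape of I and forces the small-Union eltype.
# A plain map(Val, I) would widen the element type to the abstract type Val.
Ib = MyVals[Val(k) for k in I]
```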

Great! Thank you @sijo. I think, this is exactly what I need.

Good idea. I have to read a bit more about generated functions.

And here’s a simpler solution based on @Tortar’s idea:

struct Foo{T}
    a :: T
    b :: T
end
foos = (Foo(1,2), Foo(3.0,4.0))
I = zeros(Int, (100,200,300));
@. I[1:50,:,:] = 1;
@. I[51:100,:,:] = 2;
case_expr(idx, n) = idx == n ?
    :(foos[$idx]) :
    :(i == $idx ? foos[$idx] : $(case_expr(idx+1, n)))
@generated getfoo(foos, i) = case_expr(1, fieldcount(foos))

function bar(foos, I)
    s = 0.0
    @inbounds for i in eachindex(I)
        s += getfoo(foos, I[i]).a
    end
    return s
end

julia> @btime bar($foos, $I)
  5.696 ms (0 allocations: 0 bytes)

(I’m wary of generated functions because of the possible undefined behavior, but it seems safe here.)
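
One property of case_expr worth knowing: because the final branch is an unconditional foos[n], any i outside 1:n-1 (including 0 or i > n) returns the last element without an error. A hedged variant, with the hypothetical names checked_case_expr and getfoo_checked, that makes the final branch throw instead:

```julia
struct Foo{T}
    a::T
    b::T
end

# Same recursion as case_expr above, but past the last element it throws,
# so an index outside 1:n cannot silently pick foos[n].
# The symbols `foos` and `i` refer to the arguments of getfoo_checked.
checked_case_expr(idx, n) = idx > n ?
    :(throw(BoundsError(foos, i))) :
    :(i == $idx ? foos[$idx] : $(checked_case_expr(idx + 1, n)))

# Inside the generator `foos` is the tuple *type*, so fieldcount gives its length
@generated getfoo_checked(foos::Tuple, i::Integer) = checked_case_expr(1, fieldcount(foos))

foos = (Foo(1, 2), Foo(3.0, 4.0))
```

As with the original, the helper is defined before the generated function and is a pure function of its arguments.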

This is beautiful! A generated function with recursion! Bravo!
I am not sure I understand 100% how exactly it works, but it works.

One small comment: would it be better to put case_expr inside the getfoo definition, as follows,

@generated function getfoo2(foos, i)
    case_expr(idx, n) = idx == n ?
        :(foos[$idx]) :
        :(i == $idx ? foos[$idx] : $(case_expr(idx+1, n)))
    return case_expr(1, fieldcount(foos))
end

in order to avoid problems with the global variable foos used in case_expr?

This is why generated functions spook me 🙂 It’s very easy to get into maybe-undefined-behavior territory without noticing. From the manual:

  1. Generated functions must not mutate or observe any non-constant global state (including, for example, IO, locks, non-local dictionaries, or using hasmethod). This means they can only read global constants, and cannot have any side effects. In other words, they must be completely pure. Due to an implementation limitation, this also means that they currently cannot define a closure or generator.

I think your idea is making a closure, so it’s illegal, although the compiler won’t tell you.

Edit: not sure why I thought it makes a closure, but this idea might mutate some global state or violate other rules like “Generated functions are only permitted to call functions that were defined before the definition of the generated function”. It seems at least risky to me. Anyway, don’t take me as an authority: I know very little about generated functions, so I just try to be careful.

You can run a small piece to see what it does:

case_expr(idx, n) = idx == n ?
    :(foos[$idx]) :
    :(i == $idx ? foos[$idx] : $(case_expr(idx+1, n)))

julia> case_expr(1, 4)
:(if i == 1
      foos[1]
  else
      if i == 2
          foos[2]
      else
          if i == 3
              foos[3]
          else
              foos[4]
          end
      end
  end)

It’s building the same expression as if I wrote the following:

julia> :(i == 1 ? foos[1] :
         i == 2 ? foos[2] :
         i == 3 ? foos[3] :
                  foos[4])
:(if i == 1
      foos[1]
  else
      if i == 2
          foos[2]
      else
          if i == 3
              foos[3]
          else
              foos[4]
          end
      end
  end)

Hmm… Interesting. Although I have not seen any issues when using the above code, I guess a better solution would be to pass foos as a parameter:

case_expr(foos, idx, n) = idx == n ?
    :(foos[$idx]) :
    :(i == $idx ? foos[$idx] : $(case_expr(foos, idx+1, n)))

And thank you for the detailed explanation.

This doesn’t make a difference: when you write :(foos[$idx]), foos is a literal symbol in the expression; it doesn’t use the foos argument.
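
This is easy to check by calling the three-argument case_expr with two unrelated first arguments and comparing the resulting expressions; e1 and e2 are just illustrative names.

```julia
# The three-argument version from the previous post
case_expr(foos, idx, n) = idx == n ?
    :(foos[$idx]) :
    :(i == $idx ? foos[$idx] : $(case_expr(foos, idx + 1, n)))

# Whatever is passed as the first argument, the output is identical,
# because `foos` inside :( ... ) is a quoted symbol, not the argument's value.
e1 = case_expr(:whatever, 1, 2)
e2 = case_expr((10, 20), 1, 2)
```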

You could define something like this:

case_expr(name::Val{Name}, idx, n) where Name = idx == n ?
           :($Name[$idx]) :
           :(i == $idx ? $Name[$idx] : $(case_expr(name, idx+1, n)))

julia> case_expr(Val(:XX), 1, 2)
:(if i == 1
      XX[1]
  else
      XX[2]
  end)

I’m using Val because at the time when the expression is calculated by the generated function, only the types of arguments are known.
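
For completeness, this is how the Val-based case_expr might be wired into a generated function whose tuple argument has a different name. pick and tup are hypothetical names chosen for the sketch; the index argument still has to be called i, because case_expr hardcodes it.

```julia
struct Foo{T}
    a::T
    b::T
end

# Val-based builder from the post above: the tuple's argument name is a parameter
case_expr(name::Val{Name}, idx, n) where Name = idx == n ?
    :($Name[$idx]) :
    :(i == $idx ? $Name[$idx] : $(case_expr(name, idx + 1, n)))

# The generator passes the name of its own tuple argument (`tup` here)
@generated pick(tup, i) = case_expr(Val(:tup), 1, fieldcount(tup))

foos = (Foo(1, 2), Foo(3.0, 4.0))
```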

But if I don’t need this flexibility, I would rather keep the first version with foos hardcoded.

Ok. Got it.