Outer broadcasting?

Hi,

I would like to do something like the outer product for general operations. Basically, if I have a function

f(x1, x2, [...], xn) = [...]

and I provide it with n arrays a1, …, an as input, then I would like to apply f to all combinations of input elements, such that the resulting array has the shape:

outer(f, a1, a2, [...], an) |> size == (size(a1)..., size(a2)..., [...], size(an)...)

How can I do that in a clean and performant way?

Example:

julia> f(a, b, c) = a+b+c
f (generic function with 1 method)

julia> x = 1:300; y = 1:400; z = 1:500;

# performant but super hard to read
julia> f.(x, reshape(y, (ones(Int64, ndims(x))..., size(y)...)), reshape(z, (ones(Int64, ndims(x) + ndims(y))..., size(z)...))) |> size == (size(x)..., size(y)..., size(z)...)
true

julia> @benchmark f.(x, reshape(y, (ones(Int64, ndims(x))..., size(y)...)), reshape(z, (ones(Int64, ndims(x) + ndims(y))..., size(z)...)))
BenchmarkTools.Trial: 
  memory estimate:  457.76 MiB
  allocs estimate:  13
  --------------
  minimum time:     141.622 ms (0.29% GC)
  median time:      150.833 ms (9.09% GC)
  mean time:        158.361 ms (10.34% GC)
  maximum time:     193.617 ms (22.94% GC)
  --------------
  samples:          32
  evals/sample:     1

# easy to read but slow
julia> broadcast( u -> f(u...), Base.Iterators.product(x, y, z)) |> size == (size(x)..., size(y)..., size(z)...)
true

julia> @benchmark broadcast( u -> f(u...), Base.Iterators.product(x, y, z))
BenchmarkTools.Trial: 
  memory estimate:  1.79 GiB
  allocs estimate:  6
  --------------
  minimum time:     731.471 ms (1.30% GC)
  median time:      799.173 ms (13.66% GC)
  mean time:        792.291 ms (11.79% GC)
  maximum time:     820.479 ms (11.84% GC)
  --------------
  samples:          7
  evals/sample:     1

The first version seems performant, the second one seems clean. Is there a standard function that I am overlooking?

Iterators.product(arrays...) returns an outer product iterator like you’re looking for:

julia> a = 1:4;
julia> b = [(x,y) for x=5:7, y=8:9]
3×2 Array{Tuple{Int64,Int64},2}:
 (5, 8)  (5, 9)
 (6, 8)  (6, 9)
 (7, 8)  (7, 9)

julia> collect(Iterators.product(a,b))
4×3×2 Array{Tuple{Int64,Tuple{Int64,Int64}},3}:
[:, :, 1] =
 (1, (5, 8))  (1, (6, 8))  (1, (7, 8))
 (2, (5, 8))  (2, (6, 8))  (2, (7, 8))
 (3, (5, 8))  (3, (6, 8))  (3, (7, 8))
 (4, (5, 8))  (4, (6, 8))  (4, (7, 8))

[:, :, 2] =
 (1, (5, 9))  (1, (6, 9))  (1, (7, 9))
 (2, (5, 9))  (2, (6, 9))  (2, (7, 9))
 (3, (5, 9))  (3, (6, 9))  (3, (7, 9))
 (4, (5, 9))  (4, (6, 9))  (4, (7, 9))

So you can do

outer(f,arrays...) = [f(x...) for x in Iterators.product(arrays...)]
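
With the f, x, y, and z from the earlier post, a quick check (my addition, not part of the original reply):

outer(f, x, y, z) |> size   # (300, 400, 500) == (size(x)..., size(y)..., size(z)...)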

Thanks a lot, @yha. You basically suggested the same thing as my second solution, just with a list comprehension, right? And apparently your version is way faster than mine with the broadcasting, so I am going to use yours :smile:

But do you know why the list comprehension is so much faster than the broadcasting? I didn’t expect that.

julia> f(a, b, c) = a+b+c
f (generic function with 1 method)

julia> x = 1:300; y = 1:400; z = 1:500;

julia> using BenchmarkTools

julia> @benchmark broadcast( u -> f(u...), Base.Iterators.product(x, y, z))
BenchmarkTools.Trial: 
  memory estimate:  1.79 GiB
  allocs estimate:  6
  --------------
  minimum time:     733.547 ms (7.83% GC)
  median time:      760.383 ms (12.34% GC)
  mean time:        765.740 ms (12.62% GC)
  maximum time:     797.509 ms (15.32% GC)
  --------------
  samples:          7
  evals/sample:     1

julia> @benchmark [f(q...) for q in Iterators.product(x, y, z)]
BenchmarkTools.Trial: 
  memory estimate:  457.76 MiB
  allocs estimate:  5
  --------------
  minimum time:     192.790 ms (6.51% GC)
  median time:      199.686 ms (6.35% GC)
  mean time:        202.157 ms (7.98% GC)
  maximum time:     238.016 ms (19.89% GC)
  --------------
  samples:          25
  evals/sample:     1

Another solution, which is faster than both of the options presented so far, is to use broadcasting directly, without any need for Iterators.product:

julia> @btime f.($x, reshape($y, (1, :)), reshape($z, (1, 1, :)));
  114.924 ms (3 allocations: 457.76 MiB)

Thanks @rdeits. Actually, my first version is a generalization of your suggestion. And on my computer they both run at the same speed; I assume your CPU is faster :smile:

But again, the moment one generalizes this idea, it becomes rather hard to read.

julia> f(a, b, c) = a+b+c
f (generic function with 1 method)

julia> x = 1:300; y = 1:400; z = 1:500;

julia> using BenchmarkTools

julia> @benchmark f.(x, reshape(y, (ones(Int64, ndims(x))..., size(y)...)), reshape(z, (ones(Int64, ndims(x) + ndims(y))..., size(z)...)))
BenchmarkTools.Trial: 
  memory estimate:  457.76 MiB
  allocs estimate:  13
  --------------
  minimum time:     146.591 ms (8.67% GC)
  median time:      162.362 ms (8.05% GC)
  mean time:        193.845 ms (10.20% GC)
  maximum time:     367.354 ms (17.39% GC)
  --------------
  samples:          26
  evals/sample:     1

julia> @benchmark f.(x, reshape(y, (1, :)), reshape(z, (1, 1, :)))
BenchmarkTools.Trial: 
  memory estimate:  457.76 MiB
  allocs estimate:  7
  --------------
  minimum time:     147.229 ms (8.59% GC)
  median time:      150.185 ms (8.76% GC)
  mean time:        155.161 ms (10.41% GC)
  maximum time:     182.251 ms (26.30% GC)
  --------------
  samples:          33
  evals/sample:     1

Whoops, sorry, I missed that :slight_smile:

In that case, this seems like a great time for some metaprogramming! How does this look?

julia> @generated function outer(f, args::Tuple{Vararg{Any, N}}) where {N}
         expr = Expr(:call, :broadcast, :f)
         for i in 1:N
           if i == 1
             push!(expr.args, :(args[1]))
           else
             push!(expr.args, Expr(:call, :reshape, :(args[$i]), Expr(:tuple, [1 for _ in 1:i - 1]..., :(:))))
           end
         end
         expr
       end
outer (generic function with 1 method)

julia> @btime outer($f, ($x, $y, $z));
  108.493 ms (3 allocations: 457.76 MiB)
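
For three input arrays, the expression this builds is equivalent to the hand-written reshape version from above (spelled out here for illustration):

broadcast(f, args[1], reshape(args[2], (1, :)), reshape(args[3], (1, 1, :)))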

The list comprehension over Iterators.product is faster than broadcasting because broadcasting requires indexability. Since Iterators.product is a plain iterable, broadcast makes a copy of it with collect before doing its thing.
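
A quick way to see that copy happening (at least on the Julia versions around this discussion, where ProductIterator has no special broadcastable method):

p = Iterators.product(1:3, 1:4)
Broadcast.broadcastable(p) isa Array   # true: the generic fallback collects the iterator into an Array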

This is something I do think it’d be nice to have a better solution for — we’ve been brainstorming an API for it on the back-burner. Suggestions have included things like f.(orthogonalize(x, y, z)...) or f.(x, ^y, ^z).
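
Just to illustrate the first of those suggestions: orthogonalize could be little more than the reshape trick from earlier in the thread wrapped in a function (a sketch of mine, not an existing or proposed implementation):

# reshape each array so its dimensions come after all dimensions of the previous arrays
orthogonalize(arrays...) = _orthogonalize(0, arrays...)
_orthogonalize(offset) = ()
_orthogonalize(offset, a, rest...) =
    (reshape(a, (ntuple(_ -> 1, offset)..., size(a)...)),
     _orthogonalize(offset + ndims(a), rest...)...)

# usage: f.(orthogonalize(x, y, z)...) |> size == (size(x)..., size(y)..., size(z)...)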


Yes, apparently I skimmed your post too quickly, because I missed the fact that you were already using Iterators.product in your second solution.

Can’t Iterators.product itself be made indexable when its inputs are? It already has shape when its inputs have shape – an unusual property for iterators outside the Julia world – so it doesn’t seem like a big leap.
Though maybe the fact that it’s in Iterators would make it less discoverable, since it’s somewhat surprising that “iterators” have shape and/or are indexable.


Yup, it sure could. It’d be as simple as defining axes, getindex, and Broadcast.broadcastable (to behave as itself) for Base.Iterators.ProductIterator{Tuple{Vararg{AbstractArray}}}.
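
Taking that at face value, here is a rough, untested sketch of what those definitions could look like, restricted to products of AbstractVectors (note that this is type piracy on Base types and only meant as an illustration; the BroadcastStyle method is my own addition, which the broadcast machinery seems to need in order to allocate the result):

const VecProduct = Base.Iterators.ProductIterator{<:Tuple{Vararg{AbstractVector}}}

# axes(::ProductIterator) already exists in Base, so only indexing and the broadcast hooks are added
Base.getindex(p::VecProduct, i::Int...) = map(getindex, p.iterators, i)
Base.getindex(p::VecProduct, I::CartesianIndex) = p[Tuple(I)...]
Broadcast.broadcastable(p::VecProduct) = p
Broadcast.BroadcastStyle(::Type{<:Base.Iterators.ProductIterator{T}}) where {T<:Tuple{Vararg{AbstractVector}}} =
    Broadcast.DefaultArrayStyle{length(T.parameters)}()   # one broadcast dimension per input vector

# with these, broadcast(u -> f(u...), Iterators.product(x, y, z)) would no longer need to collect first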


Then maybe I will suggest that. I generally like to use broadcasting so Julia can fuse some of the loops when chaining multiple functions.
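
For instance (my example, not from the post), nested dot calls like

sin.(f.(x, reshape(y, (1, :)), reshape(z, (1, 1, :))))

fuse into a single loop with a single output allocation.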

I also like the idea of orthogonalize or something similar. Since you said you have been brainstorming: is there an issue for that?

It’s a bit of a tangent to the actual issue, but we’ve been talking about it in What to do about nonscalar indexing? · Issue #30845 · JuliaLang/julia · GitHub.
