What's the "right" way to broadcast vector-valued functions?

I’m trying to figure out the “right” way to combine broadcasting with functions for vector-valued inputs/outputs. For example, the following function defines foo through broadcasted operations.

using LinearAlgebra

function foo(a,b)
    return a.*b, a.+b
end

c,d = foo(collect(1:3),collect(0:2))
  • Is there a way to perform the same operation using the non-broadcast version of foo? Using foo.(a,b) with the scalar version of foo returns an array of tuples, instead of the tuple of arrays in the broadcast version.
  • Is it OK to just define all functions using broadcast operations (so that they work for both scalar and vectorized inputs), or will there be some performance hit?
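To make the first point concrete, here is a minimal sketch of the two result shapes. The name foo_scalar is introduced only for this illustration and is not part of the code above.

```julia
# Scalar version of foo; the name foo_scalar is hypothetical (illustration only).
foo_scalar(a, b) = (a * b, a + b)

a = collect(1:3)
b = collect(0:2)

# Broadcasting the scalar function gives one tuple per element: an array of tuples.
tuples = foo_scalar.(a, b)

# Writing the broadcasts inside the function body gives one array per output:
# a tuple of arrays.
arrays = (a .* b, a .+ b)
```

Here `tuples` is a `Vector` of 2-tuples, while `arrays` is a 2-tuple of `Vector`s.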

Thanks!

julia> function foo(a, b)
           a .* b, a .+ b
       end
foo (generic function with 1 method)

julia> bar(a, b) = a * b, a + b
bar (generic function with 1 method)

julia> foo(2.3, 4.5)
(10.35, 6.8)

julia> bar(2.3, 4.5)
(10.35, 6.8)

julia> @code_llvm optimize=true debuginfo=:none foo(2.3, 4.5)

define void @julia_foo_17899([2 x double]* noalias nocapture sret, double, double) {
top:
  %3 = fmul double %1, %2
  %4 = fadd double %1, %2
  %.sroa.0.0..sroa_idx = getelementptr inbounds [2 x double], [2 x double]* %0, i64 0, i64 0
  store double %3, double* %.sroa.0.0..sroa_idx, align 8
  %.sroa.2.0..sroa_idx1 = getelementptr inbounds [2 x double], [2 x double]* %0, i64 0, i64 1
  store double %4, double* %.sroa.2.0..sroa_idx1, align 8
  ret void
}

julia> @code_llvm optimize=true debuginfo=:none bar(2.3, 4.5)

define void @julia_bar_17908([2 x double]* noalias nocapture sret, double, double) {
top:
  %3 = fmul double %1, %2
  %4 = fadd double %1, %2
  %.sroa.0.0..sroa_idx = getelementptr inbounds [2 x double], [2 x double]* %0, i64 0, i64 0
  store double %3, double* %.sroa.0.0..sroa_idx, align 8
  %.sroa.2.0..sroa_idx1 = getelementptr inbounds [2 x double], [2 x double]* %0, i64 0, i64 1
  store double %4, double* %.sroa.2.0..sroa_idx1, align 8
  ret void
}

The generated LLVM IR is identical for both, so with scalar arguments they’ll have the same runtime performance.
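As a small sketch of what this means in practice, a function written with dotted operators accepts scalars, arrays, and a mix of the two. The inputs below are arbitrary examples.

```julia
# Same definition as foo above, written with broadcasted operators.
foo(a, b) = (a .* b, a .+ b)

scalar_result = foo(2.3, 4.5)               # numbers in, numbers out
vector_result = foo([1, 2, 3], [0, 1, 2])   # elementwise over both vectors
mixed_result  = foo([1, 2, 3], 2)           # the scalar 2 is expanded over the vector
```

Broadcasting over plain numbers returns plain numbers, which is why the scalar call does not allocate any arrays.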


If you want to use only the scalar version of foo, you have to destructure the Array{Tuple{Int64,Int64},1} returned by foo_sc.(a, b), e.g. like this:

julia> function foo_sc(a, b)
           a*b, a+b
       end
foo_sc (generic function with 1 method)

julia> function foo_bc(a, b)
           a.*b, a.+b
       end
foo_bc (generic function with 1 method)

julia> a, b = collect(1:3), collect(0:2)
([1, 2, 3], [0, 1, 2])

julia> function foo_destruct(a, b)
           tmp = foo_sc.(a, b)
           first.(tmp), last.(tmp)
       end
foo_destruct (generic function with 1 method)

julia> foo_bc(a, b)
([0, 2, 6], [1, 3, 5])

julia> foo_destruct(a, b)
([0, 2, 6], [1, 3, 5])

Since foo_destruct allocates an intermediate array of tuples and then traverses it once per output, you pay some performance penalty:

julia> using BenchmarkTools

julia> @benchmark foo_bc($a, $b)
BenchmarkTools.Trial: 
  memory estimate:  256 bytes
  allocs estimate:  3
  --------------
  minimum time:     57.775 ns (0.00% GC)
  median time:      61.215 ns (0.00% GC)
  mean time:        74.469 ns (13.24% GC)
  maximum time:     38.152 μs (99.73% GC)
  --------------
  samples:          10000
  evals/sample:     983

julia> @benchmark foo_destruct($a, $b)
BenchmarkTools.Trial: 
  memory estimate:  384 bytes
  allocs estimate:  4
  --------------
  minimum time:     79.169 ns (0.00% GC)
  median time:      81.857 ns (0.00% GC)
  mean time:        97.842 ns (12.35% GC)
  maximum time:     38.397 μs (99.64% GC)
  --------------
  samples:          10000
  evals/sample:     969
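One way to avoid the intermediate array of tuples is to fill both outputs in a single loop. The sketch below is only an illustration: foo_onepass is a hypothetical name, and it assumes nonempty vectors with matching axes and a scalar function whose return types do not change between elements. Whether it is actually faster than foo_bc or foo_destruct would need to be benchmarked.

```julia
# Scalar kernel, same definition as foo_sc above.
foo_sc(a, b) = (a * b, a + b)

# Hypothetical helper: applies foo_sc elementwise and writes both results
# into preallocated arrays in one pass over the data.
function foo_onepass(a::AbstractVector, b::AbstractVector)
    axes(a) == axes(b) || throw(DimensionMismatch("a and b must have matching axes"))
    isempty(a) && throw(ArgumentError("this sketch assumes nonempty inputs"))
    # Use the first result to choose the output element types
    # (assumes foo_sc returns the same types for every element).
    p1, s1 = foo_sc(first(a), first(b))
    prod_out = similar(a, typeof(p1))
    sum_out = similar(a, typeof(s1))
    for i in eachindex(a, b)
        prod_out[i], sum_out[i] = foo_sc(a[i], b[i])
    end
    return prod_out, sum_out
end

c, d = foo_onepass(collect(1:3), collect(0:2))
```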

As shown by @Elrod, you won’t have a performance penalty if you use the broadcasted function for simple scalar arguments.
If I remember correctly, there has been a performance penalty for using broadcasting with StaticArrays in OrdinaryDiffEq. That’s why the out-of-place versions don’t use broadcasting while the in-place versions do.


Thanks @Elrod and Hendrik!

Found a related question from 2018: Destructuring and broadcast. The motivation was similar (avoid storage and allocation of intermediate terms).

Support in Base for broadcasting functions that return tuples is still under development. In the meantime, I’ve found that Destruct.jl is fast and achieves what I’m looking for. The “unzip” function from this StackExchange post

unzip(a) = map(x->getfield.(a, x), fieldnames(eltype(a)))
x = [1 2 3]
y = [2 3 4]
bar = (x,y)->(y,x)
f,g = unzip(bar.(x,y)) # usage

also seems pretty efficient.
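To tie this back to the original question, here is a small sketch that applies the quoted unzip to the scalar foo_sc from earlier in the thread. It relies on the array having a concrete tuple element type, so that fieldnames returns the field indices (1, 2).

```julia
# unzip as quoted above: one getfield broadcast per tuple field.
unzip(a) = map(x -> getfield.(a, x), fieldnames(eltype(a)))

# Scalar function from earlier in the thread.
foo_sc(a, b) = (a * b, a + b)

a, b = collect(1:3), collect(0:2)

# foo_sc.(a, b) is an array of tuples; unzip turns it into a tuple of arrays.
c, d = unzip(foo_sc.(a, b))   # c holds the products, d holds the sums
```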
