How to vectorize only one argument in a function call

Say I have a function f(a,b,c) that accepts several vectors, but I have a vector of inputs for just one of the arguments, a. Can I call f like f(.a,b,c) so that f vectorizes only over a?
If I call f.(a,b,c), f vectorizes over a, b, and c together and returns nonsense.

I’ve seen this pattern:

f.(a,Ref(b),Ref(c))

You can protect arguments against broadcasting by wrapping them in a container. You should probably use Ref or a tuple, like this:

f.(a, Ref(b), Ref(c))
f.(a, (b,), (c,))
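
A minimal runnable sketch of both forms, assuming a hypothetical f in which a is a single scalar input while b and c are whole vectors of different lengths (the function body and the data are made up for illustration):

```julia
# Hypothetical f: `a` is one scalar input, `b` and `c` are whole vectors.
f(a, b, c) = a * sum(b) + length(c)

a = [1, 2, 3]          # the inputs we want to vectorize over
b = [10, 20]           # should be passed to every call unchanged
c = [0, 0, 0, 0]       # likewise

# Plain broadcasting tries to zip a, b and c together and fails on the lengths (3, 2, 4).
try
    f.(a, b, c)
catch err
    println(typeof(err))   # DimensionMismatch
end

# Ref and a 1-tuple both shield b and c, so each call is f(a[i], b, c).
via_ref   = f.(a, Ref(b), Ref(c))
via_tuple = f.(a, (b,), (c,))
println(via_ref)       # [34, 64, 94]
```

Here sum(b) is 30 and length(c) is 4, so the calls compute 1*30 + 4, 2*30 + 4 and 3*30 + 4.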

You could even use a vector:

f.(a, [b], [c]) 

though that would be less efficient.

When you wrap an argument in an outer container, broadcasting iterates over that one-element outer layer and passes its contents to every call as is.
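
To make "broadcasting over the outer layer" concrete, here is a small sketch (variable names are illustrative only) that broadcasts the built-in tuple function, which simply records the arguments each call receives:

```julia
b = [10, 20]

# A Ref is a zero-dimensional container: it has no axes, so broadcast treats it as a scalar.
println(axes(Ref(b)))          # ()

# Broadcasting `tuple` records the arguments each call receives. With every wrapper
# (Ref, 1-tuple, 1-element vector) the single outer slot is reused for all three calls,
# and what it holds is the original vector b itself, not a copy and not its elements.
results = [tuple.([1, 2, 3], w) for w in (Ref(b), (b,), [b])]
for r in results
    println(r)
end
```

The 1-tuple and the 1-element vector each have one axis of length 1, which broadcasting stretches to match the length-3 argument; the Ref has no axes at all.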

If b and c are scalars, broadcast works the way you hope:

julia> f(a, b, c) = b*a .+ c
f (generic function with 1 method)

julia> f.(1:3, 1, 0)
3-element Vector{Int64}:
 1
 2
 3

If one is a vector and you need to explicitly treat it as a scalar, you can wrap it in a Ref:

julia> f.(1:3, 1, [0, 1])
ERROR: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 2
<... omitted stack trace here ...>

julia> f.(1:3, 1, Ref([0, 1]))
3-element Vector{Vector{Int64}}:
 [1, 2]
 [2, 3]
 [3, 4]
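
A different technique that avoids the question altogether is an explicit comprehension or map over a, since only a is iterated and the other arguments are passed whole. A sketch reusing the f from the session above (the variable names are illustrative):

```julia
f(a, b, c) = b*a .+ c    # same f as in the REPL session

b, c = 1, [0, 1]         # fixed arguments, passed whole to every call

# Comprehension: loop over the inputs explicitly.
r1 = [f(x, b, c) for x in 1:3]

# map with an anonymous function that closes over b and c does the same thing.
r2 = map(x -> f(x, b, c), 1:3)
```

Both produce the same 3-element Vector{Vector{Int64}} as the Ref version, [[1, 2], [2, 3], [3, 4]].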

Why is it less efficient to wrap it in a vector?

Creating a vector allocates memory on the heap (it holds a pointer to your array), and that memory must eventually be garbage-collected. It’s not much, but it’s unnecessary overhead that can grow into a significant cost if it happens often enough, and it may prevent some compiler optimizations.

Wrapping it in, say, a tuple, is essentially zero-cost.
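
One rough way to see the difference on your own machine is Base's @allocated. In this sketch the wrapper functions and f are hypothetical, and the exact byte counts depend on the Julia version, so no numbers are shown; all three versions allocate the output array, and the vector version additionally allocates the two 1-element wrappers.

```julia
# Hypothetical f taking a scalar and two whole vectors.
f(a, b, c) = a * sum(b) + length(c)

with_ref(a, b, c)   = f.(a, Ref(b), Ref(c))
with_tuple(a, b, c) = f.(a, (b,), (c,))
with_vec(a, b, c)   = f.(a, [b], [c])     # builds two 1-element vectors on every call

# Bytes allocated by one call, after a warm-up call so compilation is not counted.
function bytes_allocated(wrapper, a, b, c)
    wrapper(a, b, c)
    return @allocated wrapper(a, b, c)
end

a, b, c = collect(1:1000), [10, 20], [0, 0, 0, 0]
for w in (with_ref, with_tuple, with_vec)
    println(rpad(string(w), 12), bytes_allocated(w, a, b, c), " bytes")
end
```

For careful timing, a dedicated benchmarking package is more reliable than a single @allocated call.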

This is one area where I wish I could use JAX’s vmap notation, i.e.
vmap(f)(a, b, c), which by default maps f over the first axis of every input, or, explicitly,
vmap(f, (0, 1, None))(a, b, c), which maps f over the first axis of a and the second axis of b, and doesn’t map over c at all.
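
For comparison, a rough Julia analogue of vmap(f, (0, 1, None)) can be sketched with eachslice plus Ref. This is an assumption-laden sketch, not a drop-in vmap: g and the data are hypothetical, and Julia numbers axes from 1, so JAX's axis 0 corresponds to dims=1 here.

```julia
# Map over the rows of A (its first axis) and the columns of B (its second axis),
# and pass c whole to every call.
A = [1 2; 3 4; 5 6]          # 3×2: three rows of length 2
B = [10 20 30; 40 50 60]     # 2×3: three columns of length 2
c = 100

g(row, col, c) = sum(row .* col) + c   # hypothetical: dot product plus an offset

# eachslice(A; dims=1) yields A[i, :]; eachslice(B; dims=2) yields B[:, j].
out = g.(eachslice(A; dims=1), eachslice(B; dims=2), Ref(c))
```

This gives [190, 360, 610]. A plain number such as c already broadcasts as a scalar, so Ref(c) is not strictly needed; it is there to mirror None.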

How do you designate broadcasting over all axes of an argument?

That’s actually a restriction of JAX: you need to write one vmap per array dimension you want to map over.
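
Julia’s broadcasting sits at the other end: an array argument is iterated over all of its axes at once, so no per-dimension annotation is needed. A small sketch with a hypothetical h:

```julia
A = [1 2 3; 4 5 6]     # 2×3 matrix
b = [10, 20]           # passed whole to every call

h(x, v) = x * sum(v)   # hypothetical: scale one element by the sum of v

# One call per element of A, across both axes; the result has A's shape.
H = h.(A, Ref(b))
```

Since sum(b) is 30, H is [30 60 90; 120 150 180].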

For my use case, I actually wrote a helper function called antivmap, which acts as a vmap over all but the specified axes, and I’ve found it super useful.
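
A similar idea can be sketched in Julia with eachslice. The helper antislices below is hypothetical and is not the antivmap described above (that one is JAX code not shown here). It iterates over every axis except those listed in keep, so each call receives a slice spanning the kept axes. It assumes Julia 1.9 or later, where eachslice accepts several dims.

```julia
# Hypothetical helper: slices of A that span the axes in `keep`,
# indexed by all the remaining axes.
function antislices(A::AbstractArray, keep)
    loop_dims = Tuple(d for d in 1:ndims(A) if d ∉ keep)
    return eachslice(A; dims=loop_dims)
end

A = reshape(1:24, 2, 3, 4)   # A[i, j, k] == i + 2(j-1) + 6(k-1)

# Keep axis 2: each call sees the length-3 vector A[i, :, k]; loop over axes 1 and 3.
slices = antislices(A, (2,))
sums = sum.(slices)          # 2×4 matrix of per-slice sums
```

For example, sums[1, 1] is the sum of A[1, :, 1] = [1, 3, 5], which is 9.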