# How to vectorize only one argument in function call

Say I have a function `f(a, b, c)` that accepts several vectors, but I have a vector of inputs for only one of the arguments, `a`. Can I call `f` with something like `f(.a, b, c)` so that it vectorizes only over `a`?
If I call `f.(a, b, c)`, `f` vectorizes over `a`, `b`, and `c` and returns nonsense.

I’ve seen `f.(a, Ref(b), Ref(c))`.


You can protect arguments against broadcasting by wrapping them in a container. You should probably use `Ref` or a tuple, like this:

```julia
f.(a, Ref(b), Ref(c))
f.(a, (b,), (c,))
```
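A minimal sketch of both spellings, using a made-up `f` (a number `x` combined with two whole vectors, purely for illustration). Both wrapped calls broadcast only over `a`, while the unwrapped call fails because `a` and `b` have different lengths:

```julia
# Hypothetical f: a scalar x combined with two whole vectors b and c
f(x, b, c) = x * sum(b) + sum(c)

a = [1, 2, 3]
b = [1, 2]
c = [10, 20]

r1 = f.(a, Ref(b), Ref(c))   # broadcast only over a
r2 = f.(a, (b,), (c,))       # same result, using 1-tuples as wrappers
println(r1)                  # [33, 36, 39]

# Without wrapping, broadcast tries to pair a, b and c elementwise
try
    f.(a, b, c)
catch err
    println(typeof(err))     # DimensionMismatch
end
```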

You could even use a vector:

```julia
f.(a, [b], [c])
```

though that would be less efficient.

When you wrap it in an outer container, broadcasting happens over the outer layer, leaving the contents as is.
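A small illustrative check (not from the thread) of what broadcast sees: `Ref(b)` acts as a zero-dimensional container and `(b,)` as a one-dimensional container of length 1, so both are expanded to match `a`, and the single element they hold is `b` itself rather than a copy:

```julia
b = [1, 2]

# Shapes that broadcasting sees for each wrapper
println(axes(Ref(b)))   # () : zero-dimensional, extends to any shape
println(length((b,)))   # 1  : a length-1 dimension is expanded to match
println(size([b]))      # (1,)

# The wrapped element is the original array, not a copy
println(Ref(b)[] === b)   # true
println((b,)[1] === b)    # true
println([b][1] === b)     # true
```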


If `b` and `c` are scalars, broadcast works the way you hope:

```julia
julia> f(a, b, c) = b*a .+ c
f (generic function with 1 method)

julia> f.(1:3, 1, 0)
3-element Vector{Int64}:
 1
 2
 3
```

If one is a vector and you need to explicitly treat it as a scalar, you can wrap it in a `Ref`:

```julia
julia> f.(1:3, 1, [0, 1])
ERROR: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 2
<... omitted stack trace here ...>

julia> f.(1:3, 1, Ref([0, 1]))
3-element Vector{Vector{Int64}}:
 [1, 2]
 [2, 3]
 [3, 4]
```

Why is it less efficient to wrap it in a vector?

Creating a vector allocates memory on the heap to hold a pointer to your array, and that memory must eventually be garbage collected. It’s not much, but it’s an unnecessary cost that can become significant if it happens often enough, and it may also prevent some compiler optimizations.

Wrapping it in, say, a tuple, is essentially zero-cost.


This is one area where I wish I could use JAX’s notation for `vmap`, i.e.
`vmap(f)(a, b, c)`, which by default maps `f` over the first axis of every input, or, explicitly,
`vmap(f, (0, 1, None))(a, b, c)`, which maps `f` over the first axis of `a` and the second axis of `b`, and doesn’t map over `c` at all.
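For comparison, a rough Julia analogue of `vmap(f, (0, 1, None))(a, b, c)` for matrices can be built from `eachrow`, `eachcol` and `Ref` (JAX counts axes from 0, Julia dimensions from 1). The function `g` and the data are made up for illustration, and this is only a sketch for the two-dimensional case, not a general `vmap` replacement:

```julia
# Hypothetical g: dot product of two slices plus the sum of an unmapped vector
g(x, y, c) = sum(x .* y) + sum(c)

a = [1 2; 3 4]      # map over the first axis (rows) of a
b = [10 30; 20 40]  # map over the second axis (columns) of b
c = [1, 2]          # not mapped: passed whole to every call

# Pairs row i of a with column i of b; c is shielded by Ref
result = g.(eachrow(a), eachcol(b), Ref(c))
println(result)     # [53, 253]
```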


How do you designate broadcasting over all axes of an argument?

That’s actually a restriction of JAX: you need one `vmap` per array dimension you want to map over.

For my use case, I wrote a helper function called `antivmap` that acts as a `vmap` over all but the specified axes, and I’ve found it super useful.
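The `antivmap` helper itself isn’t shown here. For the Julia side of the comparison, a sketch with made-up data: dot broadcasting already visits every element of an array regardless of how many dimensions it has, and Base’s `mapslices` is a related built-in (not the poster’s helper) that applies a function to slices along chosen dimensions while mapping over the remaining ones:

```julia
A = [1 2 3; 4 5 6]
w = [10, 20]

# Dot broadcasting visits every element of A, whatever its number of dimensions;
# w is shielded by Ref and passed whole (h is a made-up function)
h(x, w) = x * sum(w)
B = h.(A, Ref(w))
println(B)                # [30 60 90; 120 150 180]

# mapslices applies sum to each slice along dims=2 (each row)
# and maps over the remaining dimension 1
row_sums = mapslices(sum, A; dims=2)
println(size(row_sums))   # (2, 1): the reduced dimension is kept with length 1
println(vec(row_sums))    # [6, 15]
```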