Say I have a function f(a,b,c) that accepts several vectors, but I have a vector of inputs for one of the arguments, a. Can I call f like f(.a,b,c) so that f vectorizes only over a?
If I call f.(a,b,c), f vectorizes over a, b, and c and returns nonsense.
I’ve seen f.(a,Ref(b),Ref(c))
You can protect arguments against broadcasting by wrapping them in a container. You should probably use Ref or a tuple, like this:
f.(a, Ref(b), Ref(c))
f.(a, (b,), (c,))
You could even use a vector:
f.(a, [b], [c])
though that would be less efficient.
When you wrap it in an outer container, broadcasting happens over the outer layer, leaving the contents as is.
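As a small sketch of the point above, the snippet below checks that Ref and a 1-tuple give the same result as looping over a alone. The function g and the sample values are made up for illustration and are not from the original question.

```julia
# Hypothetical function: scales the whole vector b by a, then adds the whole vector c.
g(a, b, c) = a .* b .+ c

a = [1, 2, 3]
b = [10, 20]
c = [1, 1]

# Ref and 1-tuples both behave like scalars under broadcasting,
# so g is called once per element of a, with b and c passed whole.
via_ref   = g.(a, Ref(b), Ref(c))
via_tuple = g.(a, (b,), (c,))

# The same computation written as an explicit comprehension over a only.
via_loop = [g(x, b, c) for x in a]

println(via_ref)
```

All three forms should produce one output vector per element of a.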
If b and c are scalars, broadcast works the way you hope:
julia> f(a, b, c) = b*a .+ c
f (generic function with 1 method)
julia> f.(1:3, 1, 0)
3-element Vector{Int64}:
1
2
3
If one is a vector and you need to explicitly treat it as a scalar, you can wrap it in a Ref:
julia> f.(1:3, 1, [0, 1])
ERROR: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 2
<... omitted stack trace here ...>
julia> f.(1:3, 1, Ref([0, 1]))
3-element Vector{Vector{Int64}}:
[1, 2]
[2, 3]
[3, 4]
Why is it less efficient to wrap it in a vector?
A vector allocates heap memory to hold a pointer to your array, and that memory must eventually be garbage collected. The cost is small, but it is unnecessary: it can become significant if it happens often enough, and it may prevent some compiler optimizations.
Wrapping it in, say, a tuple, is essentially zero-cost.
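One rough way to see the difference yourself is to compare allocations with @allocated. This is only a sketch with a made-up function g and made-up inputs. The exact byte counts depend on the Julia version and on how the call is compiled, so no numbers are claimed here.

```julia
# Hypothetical function; b is meant to be passed whole.
g(a, b) = a * sum(b)

with_ref(a, b)    = g.(a, Ref(b))
with_tuple(a, b)  = g.(a, (b,))
with_vector(a, b) = g.(a, [b])

a = collect(1:1000)
b = [1.0, 2.0, 3.0]

# Call each once first so that compilation is not counted.
with_ref(a, b); with_tuple(a, b); with_vector(a, b)

# Bytes allocated by each call; the output array is counted in all three.
bytes_ref    = @allocated with_ref(a, b)
bytes_tuple  = @allocated with_tuple(a, b)
bytes_vector = @allocated with_vector(a, b)

println((ref = bytes_ref, tuple = bytes_tuple, vector = bytes_vector))
```

Any gap between the vector version and the other two would be the extra one-element wrapper array.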
This is one area in which I wish I could use JAX’s notation for vmap, i.e.
vmap(f)(a, b, c)
which by default maps f over the first axis of every input, or explicitly I could do
vmap(f, (0, 1, None))(a, b, c)
which broadcasts f over the first axis of a, the second axis of b, and doesn’t broadcast over c at all.
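For comparison, here is one way the same pattern might be written in Julia today, using eachrow and eachcol to pick the axis and Ref to exclude an argument. The function h and the arrays are hypothetical, and this is a sketch rather than a general replacement for vmap.

```julia
# Hypothetical function of two vectors plus one argument that is passed whole.
h(a, b, c) = sum(a .* b) + sum(c)

A = [1 2; 3 4; 5 6]   # 3×2: map over rows (first axis)
B = [1 2 3; 4 5 6]    # 2×3: map over columns (second axis)
c = [10, 20]          # not broadcast over

# Row i of A is paired with column i of B; c is passed unchanged each time.
result = h.(eachrow(A), eachcol(B), Ref(c))

println(result)
```

The slice iterators choose the axis per argument, which is roughly the role the in_axes tuple plays in the JAX call above.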
How do you designate broadcasting over all axes of an argument?
That’s actually a restriction of JAX: you need to write one vmap per array dimension you want to map over.
For my use case, I actually wrote a helper function called antivmap, which acts as a vmap over all but the specified axes, and I’ve found it super useful.