Multiply vector elementwise by [1, v, v, ...]

I have a calculation where I need to multiply various vectors elementwise by [1, v, v, ...].

The inputs may be any kind of vector, and I want the code to adapt accordingly: i.e. if the input x is an SVector, the output should also be an SVector, without allocations, but if it is a Vector I want a Vector, etc. I also want to spend minimal effort on this and not handle a ton of special cases.

Basically, I want x .* y to yield what it would yield if y were a vector “like” x, filled according to the pattern above.
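
To pin down the intended semantics for the plain Vector case, here is a minimal eager reference sketch. `afterone_ref` is a hypothetical helper name, not part of any package. It materializes the pattern as a Vector, which is exactly the allocation a lazy array type is meant to avoid.

```julia
# Hypothetical reference helper (not from the thread): builds the pattern
# [1, v, v, ...] as a plain Vector of length n, using one(v) for the leading 1.
afterone_ref(v, n::Integer) = [i == 1 ? one(v) : v for i in 1:n]

x = [1.0, 3.0, 5.0]
y = afterone_ref(3.0, length(x))   # [1.0, 3.0, 3.0]
z = x .* y                         # [1.0, 9.0, 15.0]
```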

Inspired by FillArrays.jl, I came up with something like

struct AfterOne{T,A<:AbstractUnitRange} <: AbstractVector{T}
    v::T
    axis::A
end
AfterOne(v, n::Int) = AfterOne(v, Base.OneTo(n))
Base.axes(a::AfterOne) = (a.axis,)
Base.size(a::AfterOne) = (length(a.axis),)
Base.getindex(a::AfterOne{T}, i) where T = i == firstindex(a.axis) ? one(T) : a.v

using StaticArrays
s = SVector(1, 3, 5.0)
a = AfterOne(3.0, axes(s, 1))
z = a .* s 

but z is still not an SVector (it is a SizedVector). How can I fix that? Note that

map(*, s, a)

works fine, but I want to hook into the broadcasting machinery.

The way to do this is to add BroadcastStyle and Size trait methods that detect the presence of the SOneTo axis that comes from axes(::SVector):

StaticArrays.Size(::AfterOne{T, SOneTo{N}}) where {T, N} = Size{(N,)}()
Base.Broadcast.BroadcastStyle(::Type{AfterOne{T, SOneTo{N}}}) where {T, N} = StaticArrays.StaticArrayStyle{1}()

and then, voilà:

julia> z = a .* s
3-element SVector{3, Float64} with indices SOneTo(3):
  1.0
  9.0
 15.0
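
Assembled into one self-contained script, the question's struct plus the two trait methods from this answer might look like the following sketch. It assumes a StaticArrays version in which `StaticArrays.StaticArrayStyle` still exists, since that type is not documented API. The `getindex` method is tightened to `i::Int` with a bounds check, a small deviation from the question's version; the names `z_static` and `z_dynamic` are illustrative.

```julia
using StaticArrays

# Lazy vector [1, v, v, ...] over an arbitrary unit-range axis (from the question).
struct AfterOne{T,A<:AbstractUnitRange} <: AbstractVector{T}
    v::T
    axis::A
end
AfterOne(v, n::Int) = AfterOne(v, Base.OneTo(n))
Base.axes(a::AfterOne) = (a.axis,)
Base.size(a::AfterOne) = (length(a.axis),)
function Base.getindex(a::AfterOne{T}, i::Int) where {T}
    @boundscheck checkbounds(a, i)
    return i == firstindex(a.axis) ? one(T) : a.v   # 1 at the first index, v elsewhere
end

# Trait methods from the answer: when the axis is a static SOneTo{N},
# report a static size and opt into StaticArrays' broadcast style.
StaticArrays.Size(::AfterOne{T,StaticArrays.SOneTo{N}}) where {T,N} = StaticArrays.Size{(N,)}()
Base.Broadcast.BroadcastStyle(::Type{AfterOne{T,StaticArrays.SOneTo{N}}}) where {T,N} =
    StaticArrays.StaticArrayStyle{1}()

s = SVector(1, 3, 5.0)
z_static = AfterOne(3.0, axes(s, 1)) .* s    # SVector{3, Float64}: [1.0, 9.0, 15.0]

x = [1, 3, 5.0]
z_dynamic = AfterOne(3.0, axes(x, 1)) .* x   # Vector{Float64}: [1.0, 9.0, 15.0]
```

Because the trait methods only match `AfterOne{T,SOneTo{N}}`, an `AfterOne` built on a `Base.OneTo` axis keeps Base's default broadcast style and still produces an ordinary `Vector`.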

Overloading Size is documented, but StaticArrayStyle unfortunately isn’t, so maybe we should open an issue for that.

Note that SA[1,2,3] .* fill!(similar(SA[1,2,3]), 1) creates an intermediate MVector (from similar) but returns an SVector, which suggests doing this:

julia> function mulafter(x::AbstractVector, v::Number)
         y = fill!(similar(x), v)
         y[begin:begin] .= 1
         x .* y
       end;

julia> mulafter([1,2,3], 4)
3-element Vector{Int64}:
  1
  8
 12

julia> @btime mulafter($(SA[1,2,3]), 4)  # MVector does not escape
  2.416 ns (0 allocations: 0 bytes)
3-element SVector{3, Int64} with indices SOneTo(3):
  1
  8
 12

This could be more careful about promoting eltypes, and it allocates two vectors rather than one in the Array case. But it is much simpler than defining and testing a custom array struct.
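
One way to address the eltype caveat is to allocate the buffer with a promoted element type, since `fill!` on a `similar(x)` of an Int vector throws an InexactError for a non-integer v. This is a sketch: `mulafter_promoted` is a hypothetical name, and the `isempty` guard (which avoids indexing into an empty vector) is an addition not present in the original.

```julia
# Variant of mulafter (hypothetical name) that promotes the buffer's eltype,
# so an Int vector with a Float64 v does not hit an InexactError in fill!.
function mulafter_promoted(x::AbstractVector, v::Number)
    T = promote_type(eltype(x), typeof(v))
    y = fill!(similar(x, T), v)         # buffer "like" x, every entry set to v
    isempty(y) || (y[begin] = one(T))   # first entry becomes 1 (skipped for empty input)
    return x .* y
end

mulafter_promoted([1, 2, 3], 0.5)   # [1.0, 1.0, 1.5]
mulafter_promoted([1, 2, 3], 4)     # [1, 8, 12]
```

For an SVector input, `similar(x, T)` should still give an MVector, so the result is expected to be an SVector as with the original; whether the MVector is still elided is worth confirming with `@btime` as above.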
