Implementing broadcasting for a nested `Vector` with a scalar

Say I have a struct, one of whose fields has the type

RaggedVector{T} = Vector{Vector{T}}

If we construct two conformal instances of this type

xs = [[1, 2, 3],
      [4],
      [5, 6]]
ys = [[6, 5, 4],
      [3],
      [2, 1]]

then the two can be broadcast together without problems:

julia> zs = xs .+ ys
3-element Vector{Vector{Int64}}:
 [7, 7, 7]
 [7]
 [7, 7]

But when one of the operands is a scalar, materializing the broadcast calls `+` on an inner `Vector` and a `Number`, i.e. `+(::Vector{Int64}, ::Float64)`, which has no method:

julia> zs = xs .+ 2.0
ERROR: MethodError: no method matching +(::Vector{Int64}, ::Float64)
...
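The reason is that broadcasting only descends one level: `f.(xs, y)` calls the plain, non-broadcast `f` on each outer element. A small sketch of this, reusing the `xs` and `ys` above:

```julia
xs = [[1, 2, 3], [4], [5, 6]]
ys = [[6, 5, 4], [3], [2, 1]]

# The outer broadcast pairs up inner vectors and calls plain `+` on each pair;
# `+(::Vector{Int}, ::Vector{Int})` exists, so this succeeds.
@assert xs .+ ys == [x + y for (x, y) in zip(xs, ys)]

# With a scalar, the same rule needs `+(::Vector{Int}, ::Float64)`, which has no method.
err = try
    xs .+ 2.0
    nothing
catch e
    e
end
@assert err isa MethodError
```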

So, I need to make nested vector-scalar broadcast work.

I went through the Customizing broadcasting section of the docs and tried to follow the functions and types in the call graph for broadcasting non-standard types such as sparse arrays. It seems that I need to create a new broadcast style that results from combining the two operands' broadcast styles, along with a method that dispatches on it in `materialize`, but I am struggling to put it all together.
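For reference, here is a minimal sketch of that route. It assumes the nested data is wrapped in a type you own (`Ragged`, `RaggedStyle` and the helper `row` are hypothetical names, not from any package), which is what avoids type piracy. Following the manual's Customizing broadcasting section, the style is selected through `Base.BroadcastStyle`, and the result is built by overloading `copy` on `Broadcasted{RaggedStyle}`, which `materialize` calls. Only `Ragged` and scalar arguments are handled; a plain `Vector` argument would be passed whole to every row.

```julia
# Hypothetical wrapper type that we own, so the methods below are not piracy.
struct Ragged{T} <: AbstractVector{Vector{T}}
    data::Vector{Vector{T}}
end
Base.size(r::Ragged) = size(r.data)
Base.getindex(r::Ragged, i::Int) = r.data[i]

# New style for Ragged; it wins over the default styles used by arrays and scalars.
struct RaggedStyle <: Broadcast.BroadcastStyle end
Base.BroadcastStyle(::Type{<:Ragged}) = RaggedStyle()
Base.BroadcastStyle(s::RaggedStyle, ::Broadcast.DefaultArrayStyle) = s

# Evaluate outer "row" i of a (possibly fused) broadcast expression:
# Ragged arguments contribute their i-th inner vector, nested Broadcasted
# objects are evaluated recursively, anything else (scalars, Refs) passes through.
row(x::Ragged, i) = x.data[i]
row(bc::Broadcast.Broadcasted, i) = broadcast(bc.f, map(a -> row(a, i), bc.args)...)
row(x, i) = x

# materialize calls copy on the instantiated Broadcasted; build the result row by row.
function Base.copy(bc::Broadcast.Broadcasted{RaggedStyle})
    return Ragged([row(bc, i) for i in axes(bc)[1]])
end

xs = Ragged([[1, 2, 3], [4], [5, 6]])
ys = Ragged([[6, 5, 4], [3], [2, 1]])
zs = xs .+ 2.0       # scalar operand now works
ws = xs .* 2 .+ ys   # fused expression, evaluated row by row
```

Fused expressions such as `xs .* 2 .+ ys` work in this sketch because `row` evaluates nested `Broadcasted` objects recursively for each outer index; in-place forms like `xs .+= 1` are not covered.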


Edit: I am still working on making this work using the suggestions here, and will post the final working code. It might not be “production worthy” though because most likely I will have to settle for type piracy.


You can add a few lines at the beginning of your file that extend the `+` operator for your needs:

import Base: +
# Adds a method to Base's `+` for Base types (Vector and Real), i.e. type piracy.
+(xs::Vector{Q}, fl::T) where {Q<:Real, T<:Real} = xs .+ fl

Note that you need to import the operator (or write `Base.:+`) in order to extend it.

You could define

dotplus(x, y) = x .+ y

Broadcasted dotplus will then work with both scalars and nested vectors.

julia> dotplus.(xs, ys)
3-element Vector{Vector{Int64}}:
 [7, 7, 7]
 [7]
 [7, 7]

julia> dotplus.(xs, 2.0)
3-element Vector{Vector{Float64}}:
 [3.0, 4.0, 5.0]
 [6.0]
 [7.0, 8.0]

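Note that adding a `+(::Vector, ::Real)` method, as suggested above, is technically type piracy: neither the function nor the argument types belong to your code.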


I just realised: these days you can also do `(.+).(xs, 2.0)` without defining `dotplus`.
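A quick check of that, reusing `xs` from the question. In recent Julia versions (1.6 and later, as far as I know) a bare dotted operator like `(.+)` is a function object of type `Base.BroadcastFunction`, so broadcasting it over the outer vector applies `.+` to each inner vector:

```julia
xs = [[1, 2, 3], [4], [5, 6]]

# `(.+)` is a function object that broadcasts `+`; broadcasting it over the
# outer vector therefore computes `inner .+ 2.0` for every inner vector.
zs = (.+).(xs, 2.0)

# The same thing spelled with the explicit wrapper type from Base.
bplus = Base.BroadcastFunction(+)
@assert bplus.(xs, 2.0) == zs

# Other dotted operators work the same way.
doubled = (.*).(xs, 2)
```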


The previous solutions are valid if `T <: Real` (in practice, with only two levels of nesting).
In the more general case where `T = Any`, something like this might work:

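pnest!(x::Int, n) = x + n                # scalars are immutable: return the new value
pnest!(xs::Vector{Int}, n) = (xs .+= n)  # flat Int vector: add n in place
function pnest!(xs, n)                   # any other container: recurse, store results back
    for i in eachindex(xs); xs[i] = pnest!(xs[i], n); end
    return xs
end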

The non-mutating version, applied to a more deeply nested `xs`:

julia> xs
3-element Vector{Vector{Any}}:
 [1, 2, 3]
 [4]
 [5, Any[6, [7, 8]]]

julia> pnest(xs::Int, n) = xs + n                           # leaf: add n
       pnest(xs::Vector{Int}, n) = xs .+ n                  # flat Int vector: broadcast
       pnest(xs, n) = [pnest(xsi, n) for xsi in xs]         # otherwise: recurse
pnest (generic function with 3 methods)

julia> pnest(xs,-11)
3-element Vector{Vector}:
 [-10, -9, -8]
 [-7]
 Any[-6, Any[-5, [-4, -3]]]
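The same recursion can also be written once for any function and any leaf type. A sketch, where `deepmap` is a hypothetical name (not a Base function) and every element that is not an `AbstractArray` is treated as a leaf; like `pnest`, it returns a new structure instead of mutating `xs`:

```julia
# Hypothetical helper: apply `f` to every non-array leaf of an arbitrarily
# nested structure, rebuilding the same nesting. Arrays recurse via `map`.
deepmap(f, x) = f(x)
deepmap(f, xs::AbstractArray) = map(x -> deepmap(f, x), xs)

xs = Any[[1, 2, 3], [4], Any[5, Any[6, [7, 8]]]]
ys = deepmap(x -> x - 11, xs)
```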

Thank you, this works the intended way, but I want to avoid type piracy as @fatteneder pointed out.

The solution by @Per is best here, but in my own code I often use the equivalent:

broadcast(.+, xs, 2) == (.+).(xs, 2)

Sometimes this is preferable, as you can more easily pass in an anonymous function if you’re doing anything a little more involved than plain addition; also, in my own code those extra dots start to get a bit much for my liking. YMMV, of course.
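For example, with an anonymous function that does more than a single addition (the function body here is only an illustration):

```julia
xs = [[1, 2, 3], [4], [5, 6]]

# The anonymous function receives one inner vector `v` and the scalar `a`,
# and uses ordinary broadcasting inside.
zs = broadcast((v, a) -> 2 .* v .+ a, xs, 1)

# Equivalent comprehension, for comparison.
@assert zs == [2 .* v .+ 1 for v in xs]
```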
