Only Specializing Broadcast on Likewise Types


#1

Okay, so I have this broadcast overload to do some crazy stuff on my array wrapper type A:

struct A <: AbstractVector{Float64}
    a::Vector{Float64}
end
Base.getindex(x::A, i...) = x.a[i...]
Base.size(x::A) = size(x.a)
Base.size(x::A, i::Int) = size(x.a, i)
Base.similar(x::A) = A(similar(x.a))

Base.BroadcastStyle(::Type{A}) = Broadcast.ArrayStyle{A}()
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{A}})
    first_a = find_a(bc)
    out = similar(first_a)
    copyto!(out,bc)
    out.a .*= 2
    out
end

@inline function Base.copyto!(dest::A, bc::Broadcast.Broadcasted{Nothing})
    copyto!(dest.a,unpack(bc))
    dest.a .*= 2
    dest
end

@inline unpack(bc::Broadcast.Broadcasted) = Broadcast.Broadcasted(bc.f, unpack_args(bc.args))
unpack(x,::Any) = x
unpack(x::A) = x.a
@inline unpack_args(args::Tuple) = (unpack(args[1]), unpack_args(Base.tail(args))...)
unpack_args(args::Tuple{Any}) = (unpack(args[1]),)
unpack_args(::Any, args::Tuple{}) = ()

find_a(bc::Base.Broadcast.Broadcasted) = find_a(bc.args)
find_a(args::Tuple) = find_a(find_a(args[1]), Base.tail(args))
find_a(x) = x
find_a(a::A, rest) = a
find_a(::Any, rest) = find_a(rest)

And we can see it works:

# Verify it works
a1 = A(ones(3))
a2 = A([1.0,2.0,3.0])
a1 .= a2
a1 .* a2

Now I want to be able to mix other standard broadcast objects in there. If I define

Base.BroadcastStyle(a::Broadcast.ArrayStyle{A}, b::Base.Broadcast.DefaultArrayStyle) = b

Then yes, standard arrays will work.

a1 .* [3.0,4.0,6.0]
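To see why the standard-array case works: with that rule, an A combined with a Vector resolves to `DefaultArrayStyle`, so Base's generic broadcast handles the whole expression through `getindex` and the custom `copy` is never called; the result is a plain `Vector`. A minimal sketch, using a hypothetical stand-in type named `Wrap` in place of A:

```julia
# Hypothetical stand-in for the post's A, with just enough of the array interface.
struct Wrap <: AbstractVector{Float64}
    a::Vector{Float64}
end
Base.size(w::Wrap) = size(w.a)
Base.getindex(w::Wrap, i::Int) = w.a[i]
Base.BroadcastStyle(::Type{Wrap}) = Broadcast.ArrayStyle{Wrap}()

# The rule from the post: a DefaultArrayStyle operand wins over Wrap's style.
Base.BroadcastStyle(::Broadcast.ArrayStyle{Wrap}, b::Broadcast.DefaultArrayStyle) = b

w = Wrap([1.0, 2.0, 3.0])
v = [3.0, 4.0, 6.0]
style = Broadcast.combine_styles(w, v)  # resolves to DefaultArrayStyle{1}
result = w .* v                         # Base's generic copy: a plain Vector{Float64}
```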

But not the cool stuff. For example, with this weird broadcasting object, plain arrays work:

broadcasting_randn = Base.Broadcast.Broadcasted(randn,())
a3 = zeros(3)
a3 .= broadcasting_randn # all different random numbers!
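The reason each element gets a different number: a zero-argument `Broadcasted` node contributes no axes of its own, so when it is nested inside `a3 .= ...` its function is called once per destination element rather than once overall. A small sketch, assuming a hypothetical counting function `next_value` in place of `randn` so the calls can be counted:

```julia
# Hypothetical stand-in for randn: returns 1.0, 2.0, ... on successive calls.
const calls = Ref(0)
next_value() = Float64(calls[] += 1)

lazy_counter = Broadcast.Broadcasted(next_value, ())  # same shape as broadcasting_randn
v = zeros(3)
v .= lazy_counter  # next_value is evaluated separately for each of the 3 elements
```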

but now on my A it doesn’t broadcast correctly:

a1 .= broadcasting_randn # BoundsError
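One detail in the posted helpers looks like a likely source of this particular BoundsError, independent of the broader fallback question. The empty-tuple base case `unpack_args(::Any, args::Tuple{})` takes two arguments, so the one-argument call `unpack_args(())` that `unpack` issues for the empty argument tuple of `Broadcasted(randn, ())` falls through to the generic `unpack_args(args::Tuple)` method, which evaluates `args[1]` on `()`. The scalar fallback `unpack(x,::Any) = x` has the same two-argument shape, so a one-argument `unpack` on a scalar would be a MethodError. A minimal reproduction, with `unpack` reduced to a one-argument pass-through stub (an assumption; the three `unpack_args` methods are as posted):

```julia
unpack(x) = x  # stub: one-argument pass-through
@inline unpack_args(args::Tuple) = (unpack(args[1]), unpack_args(Base.tail(args))...)
unpack_args(args::Tuple{Any}) = (unpack(args[1]),)
unpack_args(::Any, args::Tuple{}) = ()  # two arguments: never matched by unpack_args(())

# An empty argument tuple lands in the generic Tuple method, which indexes ()[1].
err = try
    unpack_args(())
catch e
    e
end
err isa BoundsError  # true

# A one-argument base case is more specific than args::Tuple for (), so the recursion stops.
unpack_args(::Tuple{}) = ()
unpack_args(())
```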

So how do I better say “just go back to using indexing on non-likewise types” (but still include scalars)?

Note that this issue is a simplified example from https://github.com/JuliaDiffEq/MultiScaleArrays.jl/pull/36


#2

Perhaps the easiest way would be to pre-walk the broadcast tree, check for any unsupported types, and in such cases punt to the DefaultArrayStyle implementation. For example, see how SparseArrays supports only some specific cases, even though its array style is slightly greedier than what it can actually handle:
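A minimal sketch of that idea applied to the A from the question, under a few assumptions: `supported` is a hypothetical helper name; "supported" is taken to mean every leaf of the tree is an A or a `Number`; the doubling from the question is dropped so that both paths compute the same values; and the fallback reuses Base's generic `copyto!(::AbstractArray, ::Broadcasted{Nothing})` via `invoke`, which needs a `setindex!` on A. This is not how SparseArrays implements it, just the same pre-walk-then-punt pattern.

```julia
struct A <: AbstractVector{Float64}
    a::Vector{Float64}
end
Base.size(x::A) = size(x.a)
Base.getindex(x::A, i::Int) = x.a[i]
Base.setindex!(x::A, v, i::Int) = (x.a[i] = v)  # used by the generic fallback
Base.BroadcastStyle(::Type{A}) = Broadcast.ArrayStyle{A}()

# Hypothetical helper: pre-walk the tree; true only if every leaf is an A or a scalar.
supported(bc::Broadcast.Broadcasted) = all(supported, bc.args)
supported(::A) = true
supported(::Number) = true
supported(::Any) = false

# Swap each A leaf for its wrapped Vector; map also handles empty argument tuples.
unpack(bc::Broadcast.Broadcasted) = Broadcast.Broadcasted(bc.f, map(unpack, bc.args), bc.axes)
unpack(x::A) = x.a
unpack(x) = x

function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{A}})
    out = A(zeros(length(axes(bc)[1])))
    return copyto!(out, bc)  # Base converts bc to Broadcasted{Nothing} and calls the method below
end

function Base.copyto!(dest::A, bc::Broadcast.Broadcasted{Nothing})
    if supported(bc)
        copyto!(dest.a, unpack(bc))  # specialized path on the wrapped storage
    else
        # Punt to Base's generic, indexing-based implementation.
        invoke(copyto!, Tuple{AbstractArray, Broadcast.Broadcasted{Nothing}}, dest, bc)
    end
    return dest
end

x = A([1.0, 2.0, 3.0])
y = A([10.0, 20.0, 30.0])
x .= x .+ 2 .* y                       # leaves: A, Number, A -> specialized path
x .= x .+ [1.0, 1.0, 1.0]              # a plain Vector leaf -> generic fallback
z = x .* 2                             # out-of-place, specialized path, returns an A
x .= Broadcast.Broadcasted(randn, ())  # zero-argument leaf -> specialized path
```

Where exactly to draw the "supported" line is the real design decision: in MultiScaleArrays, where unpacking splits into nodes rather than one wrapped vector, plain arrays would likely need to stay on the fallback side.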