Custom broadcasting for a static, immutable type

Consider the following type with two fields:

```julia
using StaticArrays

struct Foo{T,N}
    x::T              # scalar field
    y::SVector{N,T}   # static vector of length N
end
```

The objective is to:

  1. make broadcasting work with another `Foo` of matching dimension `N` and with scalars,
  2. throw an error for everything else.

For example, the following should work:

```julia
x = Foo(1, SVector(2, 3))
x .+ 3          # expected: Foo(4, SVector(5, 6))
(x .+ x) ./ 2   # expected: Foo(1.0, SVector(2.0, 3.0))
```

but, e.g., `x .+ (1:3)` should not.
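
The parentheses around the range matter: `:` has lower precedence than `.+`, so `x .+ 1:3` is parsed as `(x .+ 1):3` rather than as a broadcast over the range `1:3`. A quick check with plain integers (no `Foo` needed):

```julia
# `:` binds more loosely than `.+`, so this is (1 .+ 1):3, i.e. the range 2:3
a = 1 .+ 1:3

# Parenthesizing the range broadcasts over it instead, giving the values 2, 3, 4
b = 1 .+ (1:3)
```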

Based on the manual, I implemented this as follows:

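```julia
# Use Foo itself as the broadcast argument (the default would try to collect it)
Broadcast.broadcastable(x::Foo) = x

# Custom style parameterized by N, so only Foos of equal length share a style
struct FooStyle{N} <: Broadcast.BroadcastStyle end

Broadcast.BroadcastStyle(::Type{Foo{T,N}}) where {T,N} = FooStyle{N}()

# Combining with a scalar (0-dimensional) argument keeps FooStyle{N}
Broadcast.BroadcastStyle(::FooStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where N = FooStyle{N}()

function Broadcast.materialize(B::Broadcast.Broadcasted{FooStyle{N}}) where N
    # Fuse nested broadcasts, e.g. (x .+ x) ./ 2, into a single call f(args...)
    flat = Broadcast.flatten(B)
    args = flat.args
    f = flat.f
    # Apply f to each field separately; non-Foo arguments are wrapped in Ref
    # so that broadcasting treats them as scalars
    xs = map(a -> a isa Foo ? a.x : Ref(a), args)
    ys = map(a -> a isa Foo ? a.y : Ref(a), args)
    Foo(f.(xs...), f.(ys...))
end
```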
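
As a sanity check of the approach, here is a minimal sketch that runs on plain Julia without StaticArrays. It assumes an `NTuple{N,T}` field in place of `SVector{N,T}` (tuples also broadcast elementwise in Base), and the names `TupleFoo`, `TupleFooStyle` and `throws` are hypothetical. Otherwise it follows the same steps as the code above, including the override of `Broadcast.materialize` and the use of the unexported `Broadcast.flatten`. The last lines check the desired cases, and that an array argument or a mismatched `N` throws.

```julia
# Sketch: same technique, but with an NTuple field (assumption) instead of an
# SVector, so it needs no packages. TupleFoo/TupleFooStyle are hypothetical names.
struct TupleFoo{T,N}
    x::T
    y::NTuple{N,T}
end

# Use TupleFoo itself as the broadcast argument
Broadcast.broadcastable(x::TupleFoo) = x

# Style parameterized by N: only equal-length TupleFoos share a style
struct TupleFooStyle{N} <: Broadcast.BroadcastStyle end
Broadcast.BroadcastStyle(::Type{TupleFoo{T,N}}) where {T,N} = TupleFooStyle{N}()

# Scalars (0-dimensional style) combine with TupleFooStyle{N}; any other
# combination falls back to Broadcast.Unknown()
Broadcast.BroadcastStyle(::TupleFooStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where {N} =
    TupleFooStyle{N}()

function Broadcast.materialize(B::Broadcast.Broadcasted{TupleFooStyle{N}}) where {N}
    flat = Broadcast.flatten(B)   # one function applied to all leaf arguments
    f = flat.f
    # Non-TupleFoo arguments are wrapped in Ref so they broadcast as scalars
    xs = map(a -> a isa TupleFoo ? a.x : Ref(a), flat.args)
    ys = map(a -> a isa TupleFoo ? a.y : Ref(a), flat.args)
    TupleFoo(f.(xs...), f.(ys...))
end

# Hypothetical helper: true if calling g() throws any exception
function throws(g)
    try
        g()
    catch
        return true
    end
    return false
end

x = TupleFoo(1, (2, 3))
@assert x .+ 3 === TupleFoo(4, (5, 6))
@assert (x .+ x) ./ 2 === TupleFoo(1.0, (2.0, 3.0))
@assert throws(() -> x .+ (1:3))                   # array argument
@assert throws(() -> x .+ TupleFoo(1, (2, 3, 4)))  # mismatched N
```

In this sketch, unsupported combinations are rejected only indirectly: the two styles combine to `Broadcast.Unknown()`, and the generic fallback then fails because `TupleFoo` defines no `axes`, so the resulting error message is not very descriptive.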

It seems to work, and I am kind of surprised how easy it was. Is this all I need, or did I miss anything?
