Consider the following type with two fields:

```julia
using StaticArrays

struct Foo{T,N}
    x::T
    y::SVector{N,T}
end
```
The objective is to

- make broadcasting work for a matching dimension `N` and for scalars,
- error for everything else.
For example, the following should work:

```julia
x = Foo(1, SVector(2, 3))
x .+ 3
(x .+ x) ./ 2
```

but e.g. `x .+ (1:3)` should not.
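A side note on the failing example: in Julia the range operator `:` binds more loosely than `.+`, so writing it without parentheses as `x .+ 1:3` parses as `(x .+ 1):3`, which builds a range from the broadcast result instead of broadcasting against a range. This can be checked with `Meta.parse` (only Base is needed):

```julia
# `:` has lower precedence than `.+`, so the unparenthesized form
# broadcasts first and builds the range afterwards.
unparenthesized = Meta.parse("x .+ 1:3")
explicit_left   = Meta.parse("(x .+ 1):3")
intended        = Meta.parse("x .+ (1:3)")

println(unparenthesized == explicit_left)  # same expression tree
println(unparenthesized == intended)       # different expression tree
```

To broadcast against the range itself, the parenthesized `x .+ (1:3)` is the form to test.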
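Based on the manual, I implemented this as

```julia
# Use a Foo as-is in broadcast expressions instead of the default fallback
Broadcast.broadcastable(x::Foo) = x

# A broadcast style that records the SVector length N
struct FooStyle{N} <: Broadcast.BroadcastStyle end
Broadcast.BroadcastStyle(::Type{Foo{T,N}}) where {T,N} = FooStyle{N}()

# FooStyle{N} takes precedence over scalars (DefaultArrayStyle{0});
# no rule is defined for any other combination
Broadcast.BroadcastStyle(::FooStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where N = FooStyle{N}()

function Broadcast.materialize(B::Broadcast.Broadcasted{FooStyle{N}}) where N
    # Collapse nested broadcasts into one function over the leaf arguments
    flat = Broadcast.flatten(B)
    args = flat.args
    f = flat.f
    # Apply f field by field; non-Foo arguments are wrapped in Ref to act as scalars
    xs = map(a -> a isa Foo ? a.x : Ref(a), args)
    ys = map(a -> a isa Foo ? a.y : Ref(a), args)
    Foo(f.(xs...), f.(ys...))
end
```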
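The two steps inside `materialize` can be looked at in isolation. The sketch below assumes StaticArrays is installed; note that `Broadcast.flatten` is, as far as I know, an internal of `Base.Broadcast` rather than part of the documented interface, so its details may change between Julia versions. It shows that `flatten` turns a nested `(a .+ b) ./ c` into one function over the three leaf arguments, and that a `Ref`-wrapped scalar acts as a zero-dimensional argument: broadcasting over only numbers and `Ref`s returns a plain scalar, while pairing it with an `SVector` keeps the `SVector` shape.

```julia
using StaticArrays

# flatten turns the nested tree (a .+ b) ./ c into a single
# function applied to the leaf arguments.
bc = Broadcast.broadcasted(/, Broadcast.broadcasted(+, [1, 2], [3, 4]), 2)
flat = Broadcast.flatten(bc)
leaves = flat.args          # ([1, 2], [3, 4], 2)
fused = flat.f(1, 3, 2)     # (1 + 3) / 2 == 2.0

# Only numbers and Refs: the broadcast result is a plain scalar.
scalar_part = (+).(1, Ref(3))

# An SVector next to a Ref: the result keeps the SVector shape.
vector_part = (+).(SVector(2, 3), Ref(3))
```

This is why the first field of the result comes out as a scalar and the second as an `SVector`.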
It seems to work, and I am kind of surprised how easy it was. Is this all I need, or did I miss anything?
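One way to check both requirements is a small self-contained script that repeats the definitions from the post and asserts the expected results. It assumes StaticArrays is installed; the helper `throws` and the second value `z` (same element type, different `N`) are hypothetical additions for testing only.

```julia
using StaticArrays

struct Foo{T,N}
    x::T
    y::SVector{N,T}
end

# Definitions from the post, unchanged apart from layout
Broadcast.broadcastable(x::Foo) = x
struct FooStyle{N} <: Broadcast.BroadcastStyle end
Broadcast.BroadcastStyle(::Type{Foo{T,N}}) where {T,N} = FooStyle{N}()
Broadcast.BroadcastStyle(::FooStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where N = FooStyle{N}()

function Broadcast.materialize(B::Broadcast.Broadcasted{FooStyle{N}}) where N
    flat = Broadcast.flatten(B)
    args = flat.args
    f = flat.f
    xs = map(a -> a isa Foo ? a.x : Ref(a), args)
    ys = map(a -> a isa Foo ? a.y : Ref(a), args)
    Foo(f.(xs...), f.(ys...))
end

# Hypothetical helper: true if calling thunk() throws any exception
function throws(thunk)
    try
        thunk()
    catch
        return true
    end
    return false
end

x = Foo(1, SVector(2, 3))
z = Foo(1, SVector(2, 3, 4))   # hypothetical: same T, different N

a = x .+ 3          # Foo plus scalar
b = (x .+ x) ./ 2   # fused expression with two Foos and a scalar
c = 3 .+ x          # scalar on the left

range_rejected = throws(() -> x .+ (1:3))
mismatch_rejected = throws(() -> x .+ z)
```

This only checks that *some* exception is thrown for the rejected inputs; presumably it comes from Base's generic broadcast fallback rather than from a dedicated error message.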