Custom broadcasting for a static, immutable type

Consider the following type with two fields:

using StaticArrays

struct Foo{T,N}
    x::T
    y::SVector{N,T}
end
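Since both fields share the parameter T, the default constructor infers T and N from its arguments and has no method for mismatched element types. Any broadcast result rebuilt with Foo(...) has to satisfy the same constraint. A quick check, assuming StaticArrays is installed (mixed_fails is only an illustrative name):

```julia
using StaticArrays

struct Foo{T,N}
    x::T
    y::SVector{N,T}
end

# The default constructor infers T = Int and N = 2 from the arguments.
x = Foo(1, SVector(2, 3))
println(x isa Foo{Int,2})   # true

# A Float64 scalar with an Int SVector has no matching constructor method,
# because T would have to be both Float64 and Int.
mixed_fails = try
    Foo(1.0, SVector(2, 3))
    false
catch e
    e isa MethodError
end
println(mixed_fails)        # true
```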

The objective is to

  1. make broadcasting work between Foos of matching dimension N, and between a Foo and scalars;
  2. error for everything else.

E.g., the following should work:

x = Foo(1, SVector(2, 3))
x .+ 3
(x .+ x) ./ 2

but, e.g., x .+ (1:3) should not.
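A side note on the failing example: the range operator (the colon) binds more loosely than .+, so without parentheses x .+ 1:3 is parsed as (x .+ 1):3, a range whose start is the broadcast result, rather than a broadcast against 1:3. This can be checked with Meta.parse without defining anything (the variable names are illustrative):

```julia
# Parse only: nothing is evaluated, so x and Foo need not be defined.
unparenthesized = Meta.parse("x .+ 1:3")
parenthesized   = Meta.parse("x .+ (1:3)")

# In the unparenthesized form the outermost call is the range operator,
# whose first operand is the whole broadcast x .+ 1, i.e. (x .+ 1):3.
outer_op    = unparenthesized.args[1]
first_arg   = unparenthesized.args[2]
range_first = outer_op === Symbol(":") && first_arg == :(x .+ 1)

println(range_first)                       # true
println(unparenthesized == parenthesized)  # false
```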

Based on the manual, I implemented this as follows:

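# A Foo is its own broadcastable object (no conversion needed).
Broadcast.broadcastable(x::Foo) = x

# The style carries N, so Foos of different lengths get no combined style.
struct FooStyle{N} <: Broadcast.BroadcastStyle end

Broadcast.BroadcastStyle(::Type{Foo{T,N}}) where {T,N} = FooStyle{N}()

# Scalars (DefaultArrayStyle{0}) are absorbed into FooStyle{N}; no rule is
# declared for anything else, so those combinations error.
Broadcast.BroadcastStyle(::FooStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where N = FooStyle{N}()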
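To see what these rules do, the style-combination step can be queried directly with Broadcast.result_style. It tries both argument orders, so the single scalar rule also covers calls like 3 .+ x, while pairs without a rule come back as Broadcast.Unknown() and are therefore never routed to a FooStyle method. A sketch using only Base (Foo itself is not needed for this; the variable names are illustrative):

```julia
# Same style type and scalar rule as above; Foo itself is not needed here.
struct FooStyle{N} <: Broadcast.BroadcastStyle end
Broadcast.BroadcastStyle(::FooStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where N = FooStyle{N}()

# result_style tries both argument orders, so the one rule covers 3 .+ x too.
scalar_first = Broadcast.result_style(Broadcast.DefaultArrayStyle{0}(), FooStyle{2}())
# Identical styles combine to themselves.
same_n = Broadcast.result_style(FooStyle{2}(), FooStyle{2}())

# Pairs with no rule come back as Unknown.
vs_vector  = Broadcast.result_style(FooStyle{2}(), Broadcast.DefaultArrayStyle{1}())
vs_other_n = Broadcast.result_style(FooStyle{2}(), FooStyle{3}())

println(scalar_first === FooStyle{2}())        # true
println(same_n === FooStyle{2}())              # true
println(vs_vector === Broadcast.Unknown())     # true
println(vs_other_n === Broadcast.Unknown())    # true
```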

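function Broadcast.materialize(B::Broadcast.Broadcasted{FooStyle{N}}) where N
    flat = Broadcast.flatten(B)  # fuse nested broadcasts into a single function call
    args = flat.args
    f = flat.f
    # broadcast f over the scalar fields and the SVector fields separately,
    # wrapping every non-Foo argument in Ref
    xs = map(a -> a isa Foo ? a.x : Ref(a), args)
    ys = map(a -> a isa Foo ? a.y : Ref(a), args)
    Foo(f.(xs...), f.(ys...))
end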
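For reference, here is the code above assembled into one self-contained script, together with the checks from the question. It assumes StaticArrays is installed; threw is a hypothetical helper added only for the demo, and the Ref wrapping of the original is kept:

```julia
using StaticArrays

struct Foo{T,N}
    x::T
    y::SVector{N,T}
end

# A Foo is its own broadcastable object.
Broadcast.broadcastable(x::Foo) = x

struct FooStyle{N} <: Broadcast.BroadcastStyle end
Broadcast.BroadcastStyle(::Type{Foo{T,N}}) where {T,N} = FooStyle{N}()
# Only the pairing with scalars is declared; other mixes (a Vector, or a
# FooStyle with a different N) get no rule.
Broadcast.BroadcastStyle(::FooStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where N = FooStyle{N}()

function Broadcast.materialize(B::Broadcast.Broadcasted{FooStyle{N}}) where N
    flat = Broadcast.flatten(B)  # fuse nested broadcasts into one function
    f = flat.f
    xs = map(a -> a isa Foo ? a.x : Ref(a), flat.args)
    ys = map(a -> a isa Foo ? a.y : Ref(a), flat.args)
    Foo(f.(xs...), f.(ys...))
end

# Hypothetical helper for the demo: true if calling g throws.
function threw(g)
    try
        g()
        return false
    catch
        return true
    end
end

x = Foo(1, SVector(2, 3))

shifted  = x .+ 3           # fields 4 and [5, 6]
averaged = (x .+ x) ./ 2    # fields 1.0 and [2.0, 3.0]

println(shifted.x == 4 && shifted.y == SVector(5, 6))           # true
println(averaged.x == 1.0 && averaged.y == SVector(2.0, 3.0))   # true
println(threw(() -> x .+ (1:3)))                     # true: vector argument
println(threw(() -> x .+ Foo(1, SVector(1, 2, 3))))  # true: N = 2 vs N = 3
```

Both failing calls end up with no combined style, so they never reach the FooStyle materialize method.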

It seems to work, and I am kind of surprised how easy it was. Is this all I need, or did I miss anything?


Thanks a lot! This allowed me to set up broadcasting for a wrapper around an SVector as well.

materialize has disappeared from the documentation, which seems focused on mutable types. But after looking into broadcast.jl, it looks like the Ref(a) wrappers are not needed (a alone is enough):

function Broadcast.materialize(B::Broadcast.Broadcasted{FooStyle{N}}) where N
    flat = Broadcast.flatten(B)
    args = flat.args
    f = flat.f
    # non-Foo arguments are passed through unchanged
    xs = map(a -> a isa Foo ? a.x : a, args)
    ys = map(a -> a isa Foo ? a.y : a, args)
    Foo(f.(xs...), f.(ys...))
end

Otherwise, some operations, such as convert.(Int32, x), fail.
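One way to see why the extra Ref breaks convert: Broadcast.broadcasted has already run every argument through Broadcast.broadcastable before materialize sees it, and broadcastable stores a type such as Int32 as Ref{Type{Int32}}(Int32). Wrapping that in another Ref hands convert the inner Ref instead of the type. A minimal check with Base only (no Foo involved; the variable names are illustrative):

```julia
# Build the lazy broadcast object for convert.(Int32, 1) without evaluating it.
bc = Broadcast.broadcasted(convert, Int32, 1)

# broadcasted has already applied Broadcast.broadcastable to each argument,
# which stores the type Int32 as Ref{Type{Int32}}(Int32).
stored = bc.args[1]
println(stored isa Base.RefValue{Type{Int32}})   # true

# Passing the stored argument through unchanged: convert receives Int32 itself.
ok = convert.(stored, 1)
println(ok === Int32(1))                         # true

# Wrapping it in a second Ref (as in the first version): convert now receives
# the inner Ref instead of a type, and there is no matching method.
double_wrapped_fails = try
    convert.(Ref(stored), 1)
    false
catch e
    e isa MethodError
end
println(double_wrapped_fails)                    # true
```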

Or is there a catch?