Customizing Broadcasting Recursively for Collections of Arrays

I previously asked a related question and am aware of the section of the manual on customizing broadcasting.

However, rather than defining a new type which behaves like a single array, I would like to define broadcasting recursively for types which contain several arrays (or types which are themselves collections of arrays). That is, imagine we have

struct A{T}  # T is any type for which broadcasting is already defined
    a::T
    b::T
end

Given x::A and y::A, I would like z = x .+ y to be computed as z.a = x.a .+ y.a, z.b = x.b .+ y.b, and so on for every field.

This would allow me to efficiently express algebra with large models I am building. Any ideas on how to do this?


Extending your previous solution, you'll now need to implement broadcasting for your style. You do this by implementing copyto!(dest::A, bc::Broadcast.Broadcasted{MyStyle}). There are two challenges here:

  1. Implementing that method isn't easy. You need to support arbitrary combinations of fused expressions and do the indexing computations yourself. You can make this a bit easier (at the cost of some performance) by using Broadcast.flatten, which transforms the nested expression tree into a single function and a flat list of arguments. It's still hard and annoying.

  2. Perhaps more importantly, you're at the edge of (or possibly beyond) the supported broadcasting use cases. Broadcast is defined as an element-wise computation, but what is an "element" of A? See Broadcasting in 0.7 and https://github.com/JuliaLang/julia/issues/27988#issuecomment-403319535 for more details.
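
For concreteness, here is a minimal sketch of the copyto! route for the two-field A from the question. Instead of doing index computations or using Broadcast.flatten, it rebuilds one ordinary broadcast per field and lets each field's own broadcasting do the work. The names AStyle and unpack are hypothetical, and the sketch assumes that all A values in an expression hold arrays of matching shape. It sidesteps, rather than answers, the question of what an "element" of A is.

```julia
# Two-field container from the question
struct A{T}
    a::T
    b::T
end

# Hypothetical broadcast style marking any broadcast that involves an A
struct AStyle <: Broadcast.BroadcastStyle end
Broadcast.BroadcastStyle(::Type{<:A}) = AStyle()
# AStyle takes precedence over scalars and ordinary arrays
Broadcast.BroadcastStyle(::AStyle, ::Broadcast.DefaultArrayStyle) = AStyle()
# Pass an A into broadcasting as-is (the fallback would try to collect it)
Broadcast.broadcastable(x::A) = x
# An A has no axes, so skip the usual shape computation
Broadcast.instantiate(bc::Broadcast.Broadcasted{AStyle}) = bc

# Hypothetical helper: replace every A in a (possibly nested) broadcast tree
# by its field `name`; scalars and plain arrays are passed through unchanged
unpack(bc::Broadcast.Broadcasted, name) =
    Broadcast.broadcasted(bc.f, map(arg -> unpack(arg, name), bc.args)...)
unpack(x::A, name) = getfield(x, name)
unpack(x, name) = x

# Out-of-place, e.g. z = x .+ 2 .* y
Base.copy(bc::Broadcast.Broadcasted{AStyle}) =
    A(Broadcast.materialize(unpack(bc, :a)), Broadcast.materialize(unpack(bc, :b)))

# In-place, e.g. z .= x .+ y (bypasses the default path, which calls axes(dest))
Broadcast.materialize!(dest::A, bc::Broadcast.Broadcasted{AStyle}) = copyto!(dest, bc)
function Base.copyto!(dest::A, bc::Broadcast.Broadcasted{AStyle})
    Broadcast.materialize!(dest.a, unpack(bc, :a))
    Broadcast.materialize!(dest.b, unpack(bc, :b))
    return dest
end
```

With this, a scalar is shared by both fields (x .+ 1 adds 1 to x.a and to x.b), and a plain array argument is likewise combined with each field separately. That is one possible semantic choice, not something broadcasting itself defines.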


Thank you for your response!

My thinking and use case are similar to marius311's in the GitHub link. Essentially, I am just looking for a way to call the already implemented broadcasting rules of my type's fields using the dot syntax on my type.

What I take away from this so far:

  1. What I want to do is currently not considered broadcasting, even though it could be phrased as such (see marius311's way of supporting indexing in his Foo type). However, such indexing or flattening would incur a performance hit.

  2. If I implement broadcast and broadcast! methods for my types, I can get what I want at the expense of notational ease and readability.
    This could look like:

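# Extend Base.broadcast! for A; each field reuses its own broadcasting
function Base.broadcast!(f, dest::A, x::A, y::A)
    @. dest.a = f(x.a, y.a)  # dest.a .= f.(x.a, y.a)
    @. dest.b = f(x.b, y.b)
    return dest
end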
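
To round this out, the out-of-place companion could look like the sketch below, assuming the two-field A from the question. Note that these methods are only reached through explicit broadcast/broadcast! calls: dot syntax lowers to Broadcast.broadcasted and materialize instead, so x .+ y still fails for A unless broadcasting itself is customized.

```julia
# Two-field container from the question
struct A{T}
    a::T
    b::T
end

# In-place: apply f field by field, reusing each field's own broadcasting
function Base.broadcast!(f, dest::A, x::A, y::A)
    @. dest.a = f(x.a, y.a)
    @. dest.b = f(x.b, y.b)
    return dest
end

# Out-of-place companion: build a new A from field-wise broadcasts
Base.broadcast(f, x::A, y::A) = A(f.(x.a, y.a), f.(x.b, y.b))
```

Usage is then broadcast(+, x, y) or broadcast!(*, dest, x, y), which is exactly the loss of notational ease mentioned above.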

Since this has come up before, I am wondering whether this use case could ever be supported by broadcasting. Given how easy it is to implement the broadcast methods here, could they be translated just as easily into dot syntax for A?

In any case, thank you for your help!

The trouble is that it's at odds with the definition of broadcast. How should we define broadcast if it doesn't mean an element-wise operation?

The dot syntax needs to handle arbitrary heterogeneous combinations of arguments with an arbitrary number of nested fused calls. Sure, it's nicely defined for a simple function over a pair of A arguments, but it gets complicated fast when you add in other arrays, scalars, tuples, etc., etc. And since you're defining broadcast to mean something different from what its generic fallbacks do, if your specialization ever fails to kick in you may get a completely different answer.
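
One small way to see the "different answer" problem, using a plain tuple of arrays as a stand-in for a struct of arrays (an assumption for illustration only): Base's generic tuple broadcasting treats each array as a single element, so f is applied to whole arrays rather than to their entries. That happens to agree with the field-wise meaning for +, but not for *.

```julia
# A tuple of arrays stands in for a struct whose "elements" are its fields
p = ([1.0, 2.0], [3.0, 4.0])
q = ([1.0, 1.0], [2.0, 2.0])

# Generic tuple broadcasting calls +(p[i], q[i]) on whole arrays,
# which coincides with the field-wise result here
r = p .+ q

# ...but for * it calls *(p[1], q[1]), i.e. Vector * Vector, which has no
# method, whereas a field-wise meaning would compute p[1] .* q[1]
s = try
    p .* q
catch err
    err
end
```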


I now better understand the difficulty in this. It's not clear how to deal with heterogeneous combinations of arguments in this case. Thank you for explaining it to me!


I've actually been using the method described there quite successfully. For fairly simple expressions that don't hit the bug mentioned therein, it's completely type-stable with no performance overhead; it's exactly as if I had written the broadcast over the member arrays by hand. Even in cases where you do hit that inference bug, you can just write out the flattened expression by hand. E.g. this expression hits that bug:

@. y += h*(k₁ + 2k₂ + 2k₃ + k₄)/6

but in cases like that you can just write out:

broadcast!((y,h,k₁,k₂,k₃,k₄)->(y+h*(k₁+2k₂+2k₃+k₄)/6), y, (y,h,k₁,k₂,k₃,k₄)...)

and you're back to Array performance. You could even write a macro to do that for you. It is a tiny bit more of a pain since you have to catch the cases where inference fails, so I really do hope the developers can one day fix #27988, but for now this is OK.
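
The hand-flattened call above is essentially what Broadcast.flatten (mentioned earlier in the thread) builds automatically. Here is a small sketch on plain arrays. Note that flatten is an internal, unexported function, so its exact behavior may change between Julia versions.

```julia
# Nested (fused) broadcast tree for v .+ 2 .* w, built lazily
v = [1, 2]
w = [3, 4]
bc = Broadcast.broadcasted(+, v, Broadcast.broadcasted(*, 2, w))

# flatten merges the tree into one function of a flat argument list,
# roughly (a, s, b) -> a + s * b applied to the arguments (v, 2, w)
flat = Broadcast.flatten(bc)

# The flattened object can be materialized like any other broadcast
result = Broadcast.materialize(flat)
```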


FYI, in terms of automatically broadcasting things on the right-hand side, it's fairly easy to get automatic broadcasting for arbitrary ("well-behaved") structs. To do this, you can just subtype the following BroadcastableStruct (which is just a subset of the BroadcastableCallable I described here):

abstract type BroadcastableStruct end

# Tuple of all field values of `obj`, in declaration order
fieldvalues(obj) = ntuple(i -> getfield(obj, i), fieldcount(typeof(obj)))

# Taken from `Setfield.constructor_of`: the (UnionAll) constructor of a type
@generated constructor_of(::Type{T}) where T =
    getfield(parentmodule(T), nameof(T))

# Treat the struct as a lazy broadcast that rebuilds it element-wise from its fields
Broadcast.broadcastable(obj::BroadcastableStruct) =
    Broadcast.broadcasted(
        constructor_of(typeof(obj)),
        map(Broadcast.broadcastable, fieldvalues(obj))...)

Usage:

struct A{T} <: BroadcastableStruct
    a::T
    b::T
end

f(x, y) = A(x.a * y.a, x.b * y.b)
g(x::A) = x.a + x.b

g.(f.(A(ones(3), ones(3)), A(ones(3), ones(3))))

Result:

3-element Array{Float64,1}:
 2.0
 2.0
 2.0

Of course, you need to make sure that constructing the struct is "free" (or define such a constructor and use it via constructor_of). It's just a convenient way to use broadcasting, so I guess it does not improve performance, though (I haven't checked how it interacts with inference, etc.).

I don't think it's hard to define copyto!(::A{<:AbstractArray{T}}, bc::Broadcasted) for the case where the "eltype" of bc is A{T}. But this is a bit trickier to do at the level of BroadcastableStruct.
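
A rough sketch of that copyto! for the two-field case follows. For brevity it specializes the BroadcastableStruct idea directly to A instead of going through constructor_of, and it assumes both fields are arrays with the same shape as the broadcast result. The extra materialize! method is an added convenience (not from the post) so that dest .= f.(x, y) reaches this copyto!.

```julia
# Two-field container; broadcasting an A lazily yields A's of scalars
struct A{T}
    a::T
    b::T
end
Broadcast.broadcastable(obj::A) = Broadcast.broadcasted(A, obj.a, obj.b)

# Write a broadcast whose elements are A{T} (T a scalar) into an A of arrays
function Base.copyto!(dest::A{<:AbstractArray}, bc::Broadcast.Broadcasted)
    bc = Broadcast.instantiate(bc)   # compute axes so we can iterate
    for I in eachindex(bc)
        el = bc[I]                   # one A holding scalars
        dest.a[I] = el.a
        dest.b[I] = el.b
    end
    return dest
end

# Route `dest .= ...` here (the default path would need axes(dest))
Broadcast.materialize!(dest::A{<:AbstractArray}, bc::Broadcast.Broadcasted) =
    copyto!(dest, bc)

# Element-wise function from the usage example above
f(x, y) = A(x.a * y.a, x.b * y.b)
```

Each element is built as a small A and immediately split again, which is cheap for isbits fields but is still a per-element construction.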


Thank you for your response Marius, that is helpful!

This is an implementation of recursive broadcasting compatible with scientific computing libraries like DifferentialEquations.jl.


The approach used in MultiScaleArrays.jl works well for my use case and is type stable. Thanks for sharing!