How to customize the new broadcasting infrastructure in v0.7

I want to ask about the new broadcasting operation in Julia 0.7 and the preferred way to customize it for custom types. I read the following:

https://docs.julialang.org/en/latest/manual/interfaces/#man-interfaces-broadcasting-1

I am a bit confused by the various options to stop fusion: do I override make, copy, define a new style for my custom type etc? What is the right way to define broadcasting methods for primitive functions like +, *, sin etc. specific to these types?


Please see https://github.com/MikeInnes/TakingBroadcastSeriously.jl/issues/8

In particular, I’ve successfully stopped fusion for my TrackedArrays which should be pretty close to what you do in AutoGrad/Knet.


That doesn’t seem complete; presumably you want to take advantage of BroadcastStyles somewhere, to avoid the game of catching every signature like broadcast(+, ::Any, ::Any, ::TrackedArray).

Yes, there are a whole slew of new features here, and they aren’t totally orthogonal. I think it’s helpful to think of this in stages. Assume in all the code samples below that I’ve done import Base.Broadcast: broadcasted, Broadcasted.

Construction of the broadcast expression tree

The parser transforms an expression like (A .+ B) ./ C into broadcasted(/, broadcasted(+, A, B), C). By default, this constructs Broadcasted objects that just hold onto the function and tuple of arguments, but it’s using a lowercase function (instead of the constructor directly) to allow you to do something different. This means that, as dfdx suggested, you can simply overload broadcasted(::typeof(+), ::YourType, ::YourType) to either return an alternative array-like lazy representation or do the work immediately and return an intermediate array.
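A minimal sketch of both points, assuming Julia 1.x. The first part builds by hand the lazy tree that (A .+ B) ./ C lowers to. The second part overloads broadcasted for a hypothetical Eager wrapper so that x .+ y between two Eagers is computed immediately instead of being fused.

```julia
using Base.Broadcast: Broadcasted, materialize
import Base.Broadcast: broadcasted

A = [1.0, 2.0]; B = [3.0, 4.0]; C = [2.0, 2.0]

# The lazy tree that `(A .+ B) ./ C` lowers to; nothing is computed yet
bc = broadcasted(/, broadcasted(+, A, B), C)
# materialize(bc) is what actually runs the fused loop

struct Eager               # hypothetical wrapper type, for illustration only
    data::Vector{Float64}
end

# Opt out of fusion for `+` between two Eagers: do the work right away
broadcasted(::typeof(+), x::Eager, y::Eager) = Eager(x.data .+ y.data)
```

Because materialize is the identity on anything that is not a Broadcasted, the Eager returned by the overload passes through the dot syntax unchanged.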

Sometimes, like in the case Mike mentions, you also have a BroadcastStyle promotion system set up and want to dispatch on the combined style of the arguments, often to catch that "at least one argument is a TrackedArray" dispatch problem. You can instead overload broadcasted(::YourStyle, ::typeof(+), ...). Note, though, that this is just the style of the arguments passed to +, not the overall style of the entire fused expression. This is how ranges can now return ranges from broadcasting: they opt out of fusion to compute things in O(1) when possible (base/broadcast.jl#L961-L1013).
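A sketch of style-based dispatch, assuming Julia 1.x and a hypothetical Tracked vector type. Giving Tracked its own ArrayStyle (which wins promotion against DefaultArrayStyle) lets a single broadcasted(::ArrayStyle{Tracked}, ::typeof(+), args...) method catch + with a Tracked in any argument position. The helper untrack is also hypothetical.

```julia
using Base.Broadcast: ArrayStyle
import Base.Broadcast: broadcasted

struct Tracked <: AbstractVector{Float64}   # hypothetical tracked array
    data::Vector{Float64}
end
Base.size(t::Tracked) = size(t.data)
Base.getindex(t::Tracked, i::Int) = t.data[i]

# Any broadcast involving a Tracked gets this style; it beats DefaultArrayStyle
Base.BroadcastStyle(::Type{Tracked}) = ArrayStyle{Tracked}()

untrack(x) = x
untrack(x::Tracked) = x.data

# One method catches `+` whenever at least one argument is a Tracked,
# computing eagerly (no fusion with surrounding dot calls)
broadcasted(::ArrayStyle{Tracked}, ::typeof(+), args...) =
    Tracked(broadcast(+, map(untrack, args)...))
```

Ranges in Base use the same hook (dispatching on DefaultArrayStyle{1}), which is why (1:3) .+ 1 comes back as a range rather than a Vector.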

Note that one of the arguments you get in such a broadcasted implementation could be a lazy Broadcasted expression, too! For example, if C were a TrackedArray and A and B weren’t, you’d end up with a division between a Broadcasted(+, A, B) and your C. You’ll have to decide whether you want to manually materialize that nested Broadcasted into a temporary before executing the division by C, or whether you’re able to fuse it with the division.
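A sketch of the "materialize into a temporary" option, again with a hypothetical Tracked type. The overload for / accepts any numerator. If that numerator is a lazy Broadcasted (as A .+ B is here), it is materialized first and the division is then done eagerly.

```julia
using Base.Broadcast: Broadcasted, materialize
import Base.Broadcast: broadcasted

struct Tracked <: AbstractVector{Float64}   # hypothetical tracked array
    data::Vector{Float64}
end
Base.size(t::Tracked) = size(t.data)
Base.getindex(t::Tracked, i::Int) = t.data[i]

# Eager division when the divisor is Tracked; the numerator may still be lazy
function broadcasted(::typeof(/), num, den::Tracked)
    # Turn a nested lazy node (e.g. the A .+ B part) into a plain array first
    numval = num isa Broadcasted ? materialize(num) : num
    return Tracked(numval ./ den.data)
end
```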

Execution of a broadcasted expression

If you don’t override broadcasted, Julia will create a Broadcasted representation of the expression. It computes the overall broadcast style of the entire expression by walking through all the nested Broadcasted nodes and combining their styles, and then it stores this as the first type parameter of the Broadcasted object. This object is then passed to copy (or to copyto! in the case of .=), so you can customize copy(bc::Broadcasted{YourStyle}) or copyto!(::AbstractArray, bc::Broadcasted{YourStyle}). In these function bodies you have the entire Broadcasted expression tree available for introspection: you can walk through bc.args and manually decide how you want to execute the functions. Of course, that can be a lot of work and hard to get right, so you may just want to inspect the broadcast tree to see if it’s an “easy” case (like a map-equivalent without broadcast expansion) and defer to a simpler optimized implementation. Or, if at all possible, you can walk through the passed bc object and transform it into an equivalent Broadcasted representation, which lets you use its simple outer API:

for I in eachindex(bc)
    result[I] = bc[I]
end
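A minimal sketch of a copy override, assuming Julia 1.x and a hypothetical Tracked type with its own ArrayStyle. The whole fused tree arrives in a single call, is instantiated so that its axes are known, and is then evaluated with the eachindex loop above. The Float64 result type is an assumption made to keep the sketch short.

```julia
using Base.Broadcast: Broadcasted, ArrayStyle, instantiate

struct Tracked <: AbstractVector{Float64}   # hypothetical tracked array
    data::Vector{Float64}
end
Base.size(t::Tracked) = size(t.data)
Base.getindex(t::Tracked, i::Int) = t.data[i]
Base.BroadcastStyle(::Type{Tracked}) = ArrayStyle{Tracked}()

const NCOPY = Ref(0)   # counts how often the custom copy below runs

# Receives the entire fused expression, e.g. for `t .* 2 .+ 1`
function Base.copy(bc::Broadcasted{ArrayStyle{Tracked}})
    NCOPY[] += 1
    bc = instantiate(bc)                        # make sure axes are computed
    result = similar(Array{Float64}, axes(bc))  # assumption: Float64 results
    for I in eachindex(bc)
        result[I] = bc[I]                       # evaluates the fused kernel at I
    end
    return Tracked(result)
end
```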

This is how BitArray hooks into the broadcast system — it uses both aspects I mention here: it reports itself as a DefaultArrayStyle, and then in the DefaultArrayStyle implementation, we first walk through the Broadcasted expression tree to see if it’s a case where we can perform a “chunked” broadcast — that is, if we can use an implementation that operates on the UInt64 chunks instead of the individual bits (base/broadcast.jl#L841). The implementation then also does a transformation of the passed bc object to convert compatible bit-wise functions into their “chunkable” equivalents (base/broadcast.jl#L880-L886). For example, we transform the ! function (which only negates bools) to the bitwise ~ function to allow inversions to operate at the level of chunks.
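The tree rewrite can be illustrated with a toy version. This is not Base's implementation, which is internal and more involved; chunkify and bitwise are hypothetical helpers that rebuild a lazy Broadcasted tree, swapping ! for ~ at every level. Materializing the rewritten tree on BitVectors gives the same result as the original.

```julia
using Base.Broadcast: Broadcasted, materialize

# Map element-wise Bool functions to bitwise equivalents (only `!` here)
bitwise(f) = f
bitwise(::typeof(!)) = (~)

# Rebuild a lazy broadcast tree, replacing the function at every node
chunkify(x) = x
chunkify(bc::Broadcasted) = Broadcasted(bitwise(bc.f), map(chunkify, bc.args))
```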


So that’s a lot of detail, but I hope it helps you figure out where it’s best for you to hook into the system. The simple answer is broadcasted, but there may be cases where you’d want to consider the entire expression tree as a whole before doing your fusion opt-out.


Here is what I came up with in AutoGrad, please comment if you see anything obviously wrong or suboptimal:

  • Rec is my “tracked object” type, with the actual object stored in Rec.value.

  • For most functions (like abs) with a gradient, I want to compute the result right away and record it if one of my inputs is a Rec type, so broadcasted basically calls broadcast after stripping the Rec. broadcast_r below is the recording version of broadcast, but for this discussion assume it just strips the Rec and calls broadcast, i.e. broadcast_r(f,x) is equivalent to broadcast(f, x.value):

import Base.Broadcast: broadcasted   # needed to add methods to broadcasted

broadcast_r = recorder(broadcast)

broadcasted(f, x::Rec) = broadcast_r(f,x)
broadcasted(f, x::Rec, y) = broadcast_r(f,x,y)
broadcasted(f, x, y::Rec) = broadcast_r(f,x,y)
broadcasted(f, x::Rec, y::Rec) = broadcast_r(f,x,y)

  • This handles up to two arguments; I’ll have to handle functions with more arguments separately.

  • The alternative was to define specific broadcasted methods for each function:

broadcasted(::typeof(abs), x::Rec) = broadcast_r(abs,x)

But I assumed defining hundreds of “broadcasted” methods would be costly. Is it? If not, that would be cleaner.

  • For simple functions (like sign) with no gradient, I want to ignore the Rec type and do what is normally done, so I defined exceptions to the above generic broadcasted(f,…) methods:

broadcasted(f::typeof(sign), x::Rec) = broadcasted(f, x.value)
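A self-contained sketch of this eager approach. Rec, unwrap, and broadcast_r are hypothetical stand-ins: the real AutoGrad recorder also records the operation for the gradient, whereas here broadcast_r just strips the Rec, broadcasts, and rewraps the result.

```julia
import Base.Broadcast: broadcasted

struct Rec{T}            # hypothetical stand-in for AutoGrad's Rec
    value::T
end

unwrap(x) = x
unwrap(x::Rec) = x.value

# Hypothetical stand-in for recorder(broadcast): eager broadcast, then rewrap
broadcast_r(f, args...) = Rec(broadcast(f, map(unwrap, args)...))

broadcasted(f, x::Rec) = broadcast_r(f, x)
broadcasted(f, x::Rec, y) = broadcast_r(f, x, y)
broadcasted(f, x, y::Rec) = broadcast_r(f, x, y)
broadcasted(f, x::Rec, y::Rec) = broadcast_r(f, x, y)

# No gradient for sign: drop the Rec and fall back to ordinary broadcasting
broadcasted(f::typeof(sign), x::Rec) = broadcasted(f, x.value)
```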

Yes, that seems reasonable, although note that you’re discarding broadcast fusion by immediately calling broadcast(f, x.value). If instead broadcast_r returns a Broadcasted(f, (x.value,)), then this will layer atop broadcast fusion (and others’ specializations thereof) more gracefully.

Another way you can more easily handle the dispatch on broadcasted with at least one Rec argument is to define a BroadcastStyle for Recs, and then instead dispatch on broadcasted(::Style{Rec}, f, args...). You need to set up a style “promotion” such that Style{Rec} always wins — or would throw an ambiguity/promotion error if it encounters another such “always wins” style. I think something like this should do the trick (untested):

Broadcast.BroadcastStyle(::Type{<:Rec}) = Broadcast.Style{Rec}()
Broadcast.BroadcastStyle(::Broadcast.BroadcastStyle, s::Broadcast.Style{Rec}) = s
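A runnable version of that idea, assuming Julia 1.x and a hypothetical Rec wrapper that is not an AbstractArray. Two details are worth noting: BroadcastStyle is defined on the type (::Type{<:Rec}), and broadcastable returns the Rec itself so that broadcast does not try to collect it. With that in place, a single broadcasted(::Style{Rec}, f, args...) method covers any number of arguments.

```julia
using Base.Broadcast: Style
import Base.Broadcast: BroadcastStyle, broadcasted

struct Rec{T}            # hypothetical stand-in for AutoGrad's Rec
    value::T
end

Base.broadcastable(x::Rec) = x                       # keep Rec as-is in broadcasts
BroadcastStyle(::Type{<:Rec}) = Style{Rec}()
BroadcastStyle(::BroadcastStyle, s::Style{Rec}) = s  # Rec wins every promotion

unwrap(x) = x
unwrap(x::Rec) = x.value

# Catches any arity, as long as at least one argument is a Rec (eager, no fusion)
broadcasted(::Style{Rec}, f, args...) = Rec(broadcast(f, map(unwrap, args)...))
```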

I tried @mbauman’s first suggestion: keeping the results lazy:

broadcast_r = recorder(broadcasted)
broadcasted(f, x::Rec) = broadcast_r(f,x)

This way the boxed values stay as lazy Broadcasted objects longer. However, when I try to take a gradient involving sum(x.*x,dims=1), I now get errors like:

ERROR: MethodError: no method matching sum(::Base.Broadcast.Broadcasted; dims=1)

I don’t get the error without the dims argument.

Can you help me understand how and when the Broadcasted objects materialize in the two cases sum(x.*x) vs. sum(x.*x,dims=1), and why the two cases may differ?

To partially answer my own question: Base defines materialize(x)=x for x types other than Broadcasted. This prevents arrays from materializing if they are boxed. I solved this by overriding materialize in AutoGrad. Now trying to figure out why sum(x.*x) works without this.
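A sketch of the lazy variant together with the materialize override, using a hypothetical Rec and no recording. Each broadcasted method wraps the lazy Broadcasted node in a Rec, so nested dot calls keep fusing. materialize(::Rec) then unwraps, materializes, and rewraps once the whole dot expression has been built.

```julia
using Base.Broadcast: Broadcasted
import Base.Broadcast: broadcasted, materialize

struct Rec{T}            # hypothetical stand-in for AutoGrad's Rec
    value::T
end

# Lazy: wrap the Broadcasted node instead of computing it
broadcasted(f, x::Rec) = Rec(broadcasted(f, x.value))
broadcasted(f, x::Rec, y) = Rec(broadcasted(f, x.value, y))
broadcasted(f, x, y::Rec) = Rec(broadcasted(f, x, y.value))
broadcasted(f, x::Rec, y::Rec) = Rec(broadcasted(f, x.value, y.value))

# Without this, materialize(::Rec) is the identity and the lazy node escapes
materialize(x::Rec) = Rec(materialize(x.value))
```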

@mbauman sorry to bug you every 6 months with a broadcasting issue. My users started complaining that even though the above solution we discussed takes care of the primitive functions e.g. sin.(::KnetArray), it does not generalize to user defined functions. i.e. if f(x)=sin(x), then sin.(::KnetArray) works but f.(::KnetArray) does not.

In Julia 0.6 I used to do the following to take care of user defined functions:

broadcast(f,x::KnetArray) = f(Bcast(x))   # Bcast is a user defined box type
sin(b::Bcast) = sin.(b.x)  # repeated for all primitive functions; b.x is the boxed KnetArray (field name illustrative)

In Julia 1.0 I can do the same thing replacing broadcast with broadcasted. However (1) using my own Bcast type, and (2) defining it for each primitive function seems a bit old fashioned and I have a feeling I am not taking full advantage of your new architecture. Any advice?
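For reference, a sketch of the 0.6-style boxing trick ported to broadcasted. GPUVec (standing in for KnetArray), Bcast, and KERNEL_LOG are all hypothetical. Each primitive method on Bcast stands in for launching one kernel, so calling the user's f on a boxed argument runs sin and then cos in order. Only the one-argument case is handled.

```julia
import Base.Broadcast: broadcasted

struct GPUVec            # hypothetical stand-in for KnetArray
    data::Vector{Float64}
end
struct Bcast             # box meaning "we are inside a broadcast"
    x::GPUVec
end

const KERNEL_LOG = Symbol[]   # records which simulated "kernels" ran

# Route any one-argument broadcast through the box, then unbox the result
broadcasted(f, x::GPUVec) = f(Bcast(x)).x

# One method per primitive; each simulates launching its own kernel
Base.sin(b::Bcast) = (push!(KERNEL_LOG, :sin); Bcast(GPUVec(sin.(b.x.data))))
Base.cos(b::Bcast) = (push!(KERNEL_LOG, :cos); Bcast(GPUVec(cos.(b.x.data))))
```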

Yeah, that’s fair. The alternative, then, would be to discard the broadcasted methods and instead implement a BroadcastStyle and promotion such that broadcasting over KnetArrays creates a Broadcasted{KnetStyle} object, which represents the operation with full generality. You then write the implementation for copy(::Broadcasted{KnetStyle}).

The pain that comes with this generality, of course, is that you then must completely implement the generic broadcast kernel, applying (and potentially fusing) the broadcast expression appropriately. We do have some tools to help you along the way, like Broadcast.flatten, but I’ve lost track of what you need to do here and why you need to opt-out in the first place, so I’m not sure it’d be helpful.
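To give a sense of what Broadcast.flatten does, here is a small sketch assuming Julia 1.x. Note that flatten is unexported and internal, so its details may change between versions. It collapses a nested tree into a single Broadcasted whose args are the leaves and whose function composes the nested calls.

```julia
using Base.Broadcast: Broadcasted, broadcasted, flatten, materialize

A = [1.0, 2.0]; B = [3.0, 4.0]; C = [2.0, 2.0]

# Nested tree for (A .+ B) ./ C
bc = broadcasted(/, broadcasted(+, A, B), C)

# One flat node: args are the leaves A, B, C; its function computes (a + b) / c
flat = flatten(bc)
```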

To respond to your last sentence: the issue is that you need a CUDA kernel defined for every different GPU array operation (unless you can compile one on the fly using CUDAnative). KnetArrays only have CUDA kernels for primitive operations. When a user-defined function is broadcasted, the right primitive kernels have to be called in the right order to get the right result.

If we define a KnetStyle and we have a user-defined function f(x)=cos(sin(x)), the broadcast expression f.(x) is going to give me Broadcasted{KnetStyle}(f, (x,)), is it not? How do I get Broadcasted(cos, Broadcasted(sin, x))?

Ah, then perhaps an intermediate solution: define your KnetStyle, and then you have complete control over the broadcasted(::KnetStyle, …) method subtree. You could define, for example, a fallback broadcasted(::KnetStyle, f, args...) = ... without fear of ambiguities. What goes on the RHS is an interesting question, as just doing f.(args...) will be infinitely recursive… but perhaps copy(Broadcasted(f, args)) would do the trick.

You would then also need to implement copy(::Broadcasted{KnetStyle}), but if I remember right we’ve implemented convert(Broadcasted{DefaultArrayStyle{N}}, bc) so you can use that and re-dispatch to get back to the default implementation.
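A sketch of that intermediate approach, assuming Julia 1.x. KVec is a hypothetical stand-in for KnetArray, and KnetStyle here is just an alias for its ArrayStyle. The fallback builds a Broadcasted with the default style explicitly (avoiding the recursion that f.(args...) would cause) and copies it on the CPU. Recording which function reaches the fallback shows what this method actually gets to see.

```julia
using Base.Broadcast: Broadcasted, DefaultArrayStyle, ArrayStyle, instantiate
import Base.Broadcast: broadcasted

struct KVec <: AbstractVector{Float64}   # hypothetical stand-in for KnetArray
    data::Vector{Float64}
end
Base.size(v::KVec) = size(v.data)
Base.getindex(v::KVec, i::Int) = v.data[i]

const KnetStyle = ArrayStyle{KVec}       # hypothetical style name from the thread
Base.BroadcastStyle(::Type{KVec}) = KnetStyle()

const SEEN = Any[]   # functions that reached the fallback

# Fallback for every function once a KVec is involved
function broadcasted(::KnetStyle, f, args...)
    push!(SEEN, f)
    # Re-dispatch to the default CPU implementation instead of f.(args...)
    bc = Broadcasted{DefaultArrayStyle{1}}(f, args)
    return KVec(copy(instantiate(bc)))
end
```

With g(x) = cos(sin(x)), the fallback receives g as a whole rather than sin and cos separately.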

So if my understanding is correct, styles do not help me get inside f(x)=cos(sin(x)), i.e. broadcasted(::KnetStyle, f, args...) even after copy, convert etc would not magically give me Broadcasted(cos, Broadcasted(sin, x)). I would still need to (1) box the args with a user-defined type, (2) call f on the boxed args, (3) define primitive methods that handle the boxed args as I did in 0.6. It is just that some of these steps could be implemented in a style specific manner.