How to customize the new broadcasting infrastructure in v0.7


#1

I want to ask about the new broadcasting operation in Julia 0.7 and the preferred way to customize it for custom types. I read the following:

https://docs.julialang.org/en/latest/manual/interfaces/#man-interfaces-broadcasting-1

I am a bit confused by the various options to stop fusion: do I override make, copy, define a new style for my custom type etc? What is the right way to define broadcasting methods for primitive functions like +, *, sin etc. specific to these types?


#2

Please see https://github.com/MikeInnes/TakingBroadcastSeriously.jl/issues/8

In particular, I’ve successfully stopped fusion for my TrackedArrays which should be pretty close to what you do in AutoGrad/Knet.


#3

That doesn’t seem complete; presumably you want to take advantage of BroadcastStyles somewhere, to avoid games catching things like broadcast(+, ::Any, ::Any, ::TrackedArray).


#4

Yes, there is a whole slew of new features here, and they aren’t totally orthogonal. I think it’s helpful to think of this in stages. All the code samples below assume using .Broadcast: broadcasted, Broadcasted.

Construction of the broadcast expression tree

The parser transforms an expression like (A .+ B) ./ C into broadcasted(/, broadcasted(+, A, B), C). By default, this constructs Broadcasted objects that just hold onto the function and tuple of arguments, but it’s using a lowercase function (instead of the constructor directly) to allow you to do something different. This means that, as dfdx suggested, you can simply overload broadcasted(::typeof(+), ::YourType, ::YourType) to either return an alternative array-like lazy representation or do the work immediately and return an intermediate array.
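
A minimal sketch of this first stage, assuming Julia 0.7 or later. Wrapped is a hypothetical type invented for this illustration; it is not part of Base or of any package mentioned in this thread. The first half writes out by hand what the dot syntax produces, and the second half opts one operation out of fusion by overloading broadcasted.

```julia
using Base.Broadcast: broadcasted, Broadcasted, materialize

x, y, z = [1.0, 2.0], [3.0, 4.0], [2.0, 2.0]

# `(x .+ y) ./ z` written out by hand: a lazy tree of Broadcasted nodes...
lazy = broadcasted(/, broadcasted(+, x, y), z)
# ...which only does the work once it is materialized.
fused = materialize(lazy)

# Hypothetical wrapper type, used only for illustration.
struct Wrapped
    data::Vector{Float64}
end

# Opt out of fusion for `+` on two Wrapped values: compute eagerly and
# return an intermediate result instead of a Broadcasted node.
Base.Broadcast.broadcasted(::typeof(+), a::Wrapped, b::Wrapped) =
    Wrapped(a.data .+ b.data)

w = Wrapped(x) .+ Wrapped(y)   # goes through the overload above
```

Because materialize is the identity for anything that is not a Broadcasted, the eagerly computed Wrapped value passes through the rest of the machinery unchanged.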

Sometimes, like in the case Mike mentions, you also have a BroadcastStyle promotion system set up and want to dispatch on the combined style of the arguments, often to solve that "at least one argument is a TrackedArray" dispatch problem. In that case you can instead overload broadcasted(::YourStyle, ::typeof(+), ...). Note, though, that this is just the style of the arguments passed to +, not the overall style of the entire fused expression. This is how ranges can now return ranges from broadcasting: they opt out of fusion to compute things in O(1) when possible (see base/broadcast.jl#L961-L1013).
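
The range behaviour is easy to observe from the outside. This small sketch only uses Base and assumes Julia 0.7 or later; the exact concrete range types returned may differ between versions, so it only checks for AbstractRange.

```julia
using Base.Broadcast: broadcasted

r1 = (1:3) .+ 1      # range plus scalar: stays a range, computed in O(1)
r2 = 2 .* (1:3)      # scalar times range: also a range

# The opt-out happens inside `broadcasted` itself, before any
# materialization: no Broadcasted node is ever built here.
eager = broadcasted(+, 1:3, 1)

# Functions without such a special case still fuse and produce an Array.
v = sin.(1:3)
```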

Note that one of the arguments you get in such a broadcasted implementation could be a lazy Broadcasted expression, too! For example, if C were a TrackedArray and A and B weren’t, you’d end up with a division between a Broadcasted(+, (A, B)) and your C. You’ll have to decide whether you want to manually materialize that nested Broadcasted into a temporary before executing the division by C or whether you’re able to fuse it with the division.
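
Here is a sketch of the "materialize the nested argument into a temporary" option. Tracked is a hypothetical stand-in, not the real TrackedArray, and only division with a Tracked right-hand side is handled, to keep the example short.

```julia
using Base.Broadcast: broadcasted, Broadcasted, materialize

# Hypothetical stand-in for a tracked array type.
struct Tracked
    data::Vector{Float64}
end

# The left operand may be a plain array or a lazy Broadcasted node.
function Base.Broadcast.broadcasted(::typeof(/), lhs, c::Tracked)
    lhs_val = materialize(lhs)      # identity unless lhs is a Broadcasted
    return Tracked(lhs_val ./ c.data)
end

A = [1.0, 2.0]; B = [3.0, 4.0]; C = Tracked([2.0, 2.0])

# The inner `A .+ B` arrives as a Broadcasted and is materialized
# into a temporary before the division runs.
R = (A .+ B) ./ C
```

The fusing alternative would instead build a new Broadcasted from lhs and c.data and materialize that once, which avoids the temporary.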

Execution of a broadcasted expression

If you don’t override broadcasted, Julia will create a Broadcasted representation of the expression. It computes the overall broadcast style of the entire expression by walking through all the nested Broadcasted nodes and combining their styles, and then it stores the result as the first type parameter of the Broadcasted object. This object is then passed to copy (or to copyto! in the case of .=), so you can customize copy(bc::Broadcasted{YourStyle}) or copyto!(::AbstractArray, bc::Broadcasted{YourStyle}). In these function bodies, you have the entire Broadcasted expression tree available for introspection: you can walk through bc.args and manually decide how you want to execute the functions. Of course, that can be a lot of work and hard to get right, so you may just want to inspect the broadcast tree to see if it’s an “easy” case (like a map-equivalent without broadcast expansion) and defer to a simpler optimized implementation. Or, where possible, you can walk through the passed bc object and transform it into an equivalent Broadcasted representation, which lets you keep using its simple outer API:

for I in eachindex(bc)
    result[I] = bc[I]
end
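
To make the copy hook concrete, here is a sketch built around that loop. MyVec is a hypothetical array type made up for this example, and the Float64 element type is hard-coded as a simplifying assumption; a real implementation would compute the element type from bc.

```julia
using Base.Broadcast: Broadcasted, ArrayStyle

# Hypothetical array type, used only for illustration.
struct MyVec <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(v::MyVec) = size(v.data)
Base.getindex(v::MyVec, i::Int) = v.data[i]
Base.IndexStyle(::Type{MyVec}) = IndexLinear()

# Report a custom style so whole expressions involving MyVec reach our copy.
Base.Broadcast.BroadcastStyle(::Type{MyVec}) = ArrayStyle{MyVec}()

# Called once for the entire fused expression tree.
function Base.copy(bc::Broadcasted{ArrayStyle{MyVec}})
    result = similar(Array{Float64}, axes(bc))  # assumes Float64 output
    for I in eachindex(bc)
        result[I] = bc[I]   # evaluates the fused expression at index I
    end
    return MyVec(result)
end

v = MyVec([1.0, 2.0, 3.0])
w = v .* 2 .+ 1    # one fused pass through the loop above
```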

This is how BitArray hooks into the broadcast system — it uses both aspects I mention here: it reports itself as a DefaultArrayStyle, and then in the DefaultArrayStyle implementation, we first walk through the Broadcasted expression tree to see if it’s a case where we can perform a “chunked” broadcast — that is, if we can use an implementation that operates on the UInt64 chunks instead of the individual bits (base/broadcast.jl#L841). The implementation then also does a transformation of the passed bc object to convert compatible bit-wise functions into their “chunkable” equivalents (base/broadcast.jl#L880-L886). For example, we transform the ! function (which only negates bools) to the bitwise ~ function to allow inversions to operate at the level of chunks.
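
From the outside, the chunked path is only visible as a speed difference, so the following sketch just checks that a chunk-friendly expression and an equivalent anonymous function (which I assume cannot be recognised as chunkable) give the same BitArray result.

```julia
a = trues(128)
b = falses(128)

# `!` and `&` have bitwise equivalents, so this expression is a
# candidate for the chunked UInt64 implementation.
c = .!(a .& b)

# An opaque anonymous function computes the same values bit by bit.
d = (x -> !x).(a .& b)
```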


So that’s a lot of detail, but I hope it helps you figure out where best to hook into the system. The simple answer is broadcasted, but there may be cases where you’d want to consider the entire expression tree as a whole before doing your fusion opt-out.


#5

Here is what I came up with in AutoGrad, please comment if you see anything obviously wrong or suboptimal:

  • Rec is my “tracked object” type, with the actual object stored in Rec.value.

  • For most functions (like abs) with a gradient, I want to compute the result right away and record it if one of my inputs is a Rec type, so broadcasted basically calls broadcast after stripping the Rec. broadcast_r below is the recording version of broadcast, but for this discussion assume it just strips the Rec and calls broadcast, i.e. broadcast_r(f,x) is equivalent to broadcast(f, x.value):

broadcast_r = recorder(broadcast)
broadcasted(f, x::Rec) = broadcast_r(f,x)
broadcasted(f, x::Rec, y) = broadcast_r(f,x,y)
broadcasted(f, x, y::Rec) = broadcast_r(f,x,y)
broadcasted(f, x::Rec, y::Rec) = broadcast_r(f,x,y)

  • This handles up to two arguments; I’ll have to handle functions with more arguments separately.

  • The alternative was to define specific broadcasted methods for each function:

broadcasted(::typeof(abs), x::Rec) = broadcast_r(abs,x)

But I assumed that defining hundreds of broadcasted methods would be costly. If it is not, that would be cleaner.

  • For simple functions (like sign) with no gradient, I want to ignore the Rec type and do what is normally done, so I defined exceptions to the above generic broadcasted(f,…) methods:

broadcasted(f::typeof(sign), x::Rec) = broadcasted(f, x.value)
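
For readers without AutoGrad at hand, here is a self-contained sketch of the scheme in the list above. The Rec, unbox and broadcast_r defined here are simplified stand-ins written for this illustration, not AutoGrad's actual definitions: this broadcast_r records nothing, it only strips the boxes, broadcasts eagerly, and re-boxes the result.

```julia
using Base.Broadcast: broadcasted

# Simplified stand-in for the tracked type.
struct Rec{T}
    value::T
end
unbox(x) = x
unbox(x::Rec) = x.value

# Stand-in for recorder(broadcast): strip boxes, compute eagerly, re-box.
broadcast_r(f, args...) = Rec(broadcast(f, map(unbox, args)...))

Base.Broadcast.broadcasted(f, x::Rec) = broadcast_r(f, x)
Base.Broadcast.broadcasted(f, x::Rec, y) = broadcast_r(f, x, y)
Base.Broadcast.broadcasted(f, x, y::Rec) = broadcast_r(f, x, y)
Base.Broadcast.broadcasted(f, x::Rec, y::Rec) = broadcast_r(f, x, y)

# Exception for a function without a gradient: ignore the box entirely.
Base.Broadcast.broadcasted(f::typeof(sign), x::Rec) = broadcasted(f, x.value)

x = Rec([-1.0, 2.0])
y = abs.(x)              # boxed result
s = sign.(x)             # plain array, the Rec is dropped
z = x .+ [1.0, 1.0]      # boxed result
```

The fourth method is what keeps the second and third from being ambiguous when both arguments are Rec.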


#6

Yes, that seems reasonable, although note that you’re discarding broadcast fusion by immediately calling broadcast(f, x.value). If instead broadcast_r returns a Broadcasted(f, (x.value,)), then this will layer atop broadcast fusion (and others’ specializations thereof) more gracefully.
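
A rough sketch of that lazy variant, using a simplified stand-in Rec rather than AutoGrad's type. It deliberately leaves out any recording: it only shows that returning a Broadcasted over the unboxed value lets the surrounding dot calls fuse as usual.

```julia
using Base.Broadcast: broadcasted, Broadcasted

# Simplified stand-in for the tracked type.
struct Rec{T}
    value::T
end

# Stay lazy: hand back a Broadcasted over the unboxed value.
Base.Broadcast.broadcasted(f, x::Rec) = broadcasted(f, x.value)

x = Rec([-1.0, 2.0])
bc = broadcasted(abs, x)   # a lazy node, nothing computed yet
y = sqrt.(abs.(x))         # abs and sqrt run in a single fused loop
```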

Another way you can more easily handle the dispatch on broadcasted with at least one Rec argument is to define a BroadcastStyle for Recs, and then instead dispatch on broadcasted(::Style{Rec}, f, args...). You need to set up a style “promotion” such that Style{Rec} always wins — or would throw an ambiguity/promotion error if it encounters another such “always wins” style. I think something like this should do the trick (untested):

Broadcast.BroadcastStyle(::Type{<:Rec}) = Broadcast.Style{Rec}()
Broadcast.BroadcastStyle(::Broadcast.BroadcastStyle, s::Broadcast.Style{Rec}) = s
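
A fuller sketch of the style-based approach, again with a simplified stand-in Rec and an unbox helper invented for this example. Two things here are my own assumptions rather than part of the suggestion above: the unary BroadcastStyle method is defined on the type, and broadcastable is overloaded so that a Rec is passed through as-is instead of being collected.

```julia
using Base.Broadcast: broadcasted, BroadcastStyle, Style

# Simplified stand-in for the tracked type.
struct Rec{T}
    value::T
end
unbox(x) = x
unbox(x::Rec) = x.value

# Keep Rec arguments intact when broadcast preprocesses its arguments.
Base.Broadcast.broadcastable(x::Rec) = x

# Style{Rec} wins against any other style it is combined with.
Base.Broadcast.BroadcastStyle(::Type{<:Rec}) = Style{Rec}()
Base.Broadcast.BroadcastStyle(::BroadcastStyle, s::Style{Rec}) = s

# One method covers "at least one argument is a Rec", for any arity.
Base.Broadcast.broadcasted(::Style{Rec}, f, args...) =
    Rec(broadcast(f, map(unbox, args)...))

x = Rec([1.0, -2.0])
r = max.(0.0, [0.5, 0.5], x)   # three arguments, Rec in last position
q = x .* x
```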

#7

I tried @mbauman’s first suggestion: keeping the results lazy:

broadcast_r = recorder(broadcasted)
broadcasted(f, x::Rec) = broadcast_r(f,x)

This way the boxed values stay as lazy Broadcasted objects for longer. However, when I try to take a gradient involving sum(x.*x,dims=1), I now get errors like:

ERROR: MethodError: no method matching sum(::Base.Broadcast.Broadcasted; dims=1)

I don’t get the error without the dims argument.

Can you help me understand how and when the Broadcasted objects materialize in the two cases, sum(x.*x) vs sum(x.*x,dims=1), and why the two cases may differ?


#8

To partially answer my own question: Base defines materialize(x) = x for any x that is not a Broadcasted. A Broadcasted that is boxed inside a Rec therefore never gets materialized into an array. I solved this by overriding materialize in AutoGrad. I am now trying to figure out why sum(x.*x) works without this.
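
A small sketch of that fix. Box is a hypothetical wrapper standing in for Rec, and the materialize method shown is my guess at the general shape of such an override, not AutoGrad's actual code.

```julia
using Base.Broadcast: broadcasted, Broadcasted, materialize

# Hypothetical box, standing in for the tracked type.
struct Box{T}
    value::T
end

# Without this method, materialize(::Box) would hit the identity fallback
# and the Broadcasted inside the box would stay lazy.
Base.Broadcast.materialize(b::Box) = Box(materialize(b.value))

x = [1.0 2.0; 3.0 4.0]
lazy = broadcasted(*, x, x)           # what x .* x builds before materializing
boxed = Box(lazy)                     # a lazy Broadcasted hidden in a box
unlazy = materialize(boxed)           # the box now holds a real Array
```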