Broadcasting in 0.7

Hi,

I am upgrading a package of mine to v0.7. So far the trickiest bit is updating the broadcasting code to the new broadcasting facilities of v0.7. I have read the relevant manual pages, plus other bits here and there, but still have some questions.

The idea is that I have a type, called AugmentedState, that wraps two AbstractArrays, x and q, i.e.

struct AugmentedState{X, Q}
    x::X
    q::Q
end

This type is part of the internals of the package and there is no risk that objects of this type collide with other objects. The use case is that objects of this type are used for in-place operations, exclusively involving other AugmentedState objects and optionally Numbers, as in a .= b .+ 2.0.*c.

Objects of this type do not support indexing: the only aim is to have arithmetic operations such as this one forwarded equally to the two wrapped arrays. This is what I have so far

# extract parts
@inline _state(x::AugmentedState) = x.x
@inline _quad(x::AugmentedState) = x.q
@inline _state(x::Base.RefValue) = _state(getindex(x))
@inline _quad(x::Base.RefValue) = _quad(getindex(x))
@inline _state(x) = x
@inline _quad(x) = x

# intercept in-place broadcast assignment into an AugmentedState destination
@inline function Base.Broadcast.materialize!(dest::AugmentedState, bc::Base.Broadcast.Broadcasted)
    bcf = Base.Broadcast.flatten(bc)
    return __broadcast(bcf.f, dest, bcf.args...)
end

# apply the flattened broadcast kernel to the state and quadrature parts separately
@generated function __broadcast(f, dest, args...)
    quote
        $(Expr(:meta, :inline))
        Base.Broadcast.broadcast!(f, _state(dest), map(_state, args)...)
        Base.Broadcast.broadcast!(f, _quad(dest),  map(_quad,  args)...)
        return dest
    end
end

This seems to work, but it incurs additional allocations, plus a warning. Demo:

a = AugmentedState([1, 2], [3])
b = AugmentedState([1, 2], [3])
c = AugmentedState([0, 0], [0])

foo(a, b, c) = (c .= a .+ b .+ 2.0.*a; c)

@time foo(a, b, c)
@time foo(a, b, c)
@time foo(a, b, c)
@time foo(a, b, c)

gives

┌ Warning: broadcast will default to iterating over its arguments in the future. Wrap arguments of
│ type `x::Flows.AugmentedState{Array{Int64,1},Array{Int64,1}}` with `Ref(x)` to ensure they broadcast as "scalar" elements.
│   caller = ip:0x0
└ @ Core :-1
  1.569558 seconds (3.24 M allocations: 156.342 MiB, 9.89% gc time)
  0.001696 seconds (852 allocations: 61.859 KiB)
  0.001987 seconds (852 allocations: 61.859 KiB)
  0.001628 seconds (852 allocations: 61.859 KiB)

In v0.6, I achieved this with this code

@generated function Base.Broadcast.broadcast!(f, dest::AugmentedState, args::Vararg{Any})
    quote
        $(Expr(:meta, :inline))
        broadcast!(f, _state(dest), map(_state, args)...)
        broadcast!(f,  _quad(dest), map(_quad,  args)...)
        return dest
    end
end    

and had no extra allocations.

Any help is appreciated!


This looks to be the same issue as:

https://github.com/JuliaLang/julia/issues/27988

I get the utility here, but I don't think broadcasting really makes sense for an object that doesn't have a shape or indexing.

Thanks, I had seen this one.

I can see why you think this might be an abuse of broadcasting, but as you say it serves the purpose. In any case, I do not see any other way to achieve what I want at the moment. In C++ one would probably use expression templates; in Julia all that is needed is sprinkling a few dots here and there plus one little function to do the job.

One question. Does the allocation come from the splatting?

There are two things going on here:

  1. Deprecation warnings are expensive. Define Broadcast.broadcastable(a::AugmentedState) = a to avoid it.
  2. Base.Broadcast.flatten(bc) is type-unstable due to the trouble inference has recursing through the nested Broadcasted object. Type instabilities cause allocations; one way to check this is sketched below.
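
One way to see the second point in isolation (a self-contained sketch on plain Vectors; the hand-built tree mimics the a .+ b .+ 2.0.*a expression from the demo above):

using InteractiveUtils   # provides @code_warntype

x = randn(10)
# build the lazy broadcast tree by hand, in the shape the fused expression produces
bc = Base.Broadcast.broadcasted(+, x, x, Base.Broadcast.broadcasted(*, 2.0, x))
# if the reported return type is not concrete, that instability is what
# shows up as the extra allocations in the timings above
@code_warntype Base.Broadcast.flatten(bc)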

I have come up with this approach, which avoids allocations when the broadcasted expression is simple enough not to trigger #27988

import Base.Broadcast: materialize!,
                       flatten,
                       broadcast!,
                       Broadcasted,
                       broadcastable

using BenchmarkTools

struct AugmentedState{X, Q}
    x::X
    q::Q
end

# extract parts from augmented state object
@inline _state(x::AugmentedState) = x.x
@inline _quad(x::AugmentedState) = x.q
@inline _state(x) = x
@inline _quad(x) = x

broadcastable(x::AugmentedState) = x
Base.ndims(::Type{<:AugmentedState}) = 1

# Arithmetic operations are forwarded to both components
@inline function materialize!(dest::AugmentedState, bc::Broadcasted)
    bcf = flatten(bc)
    broadcast!(bcf.f, _state(dest), map(_state, bcf.args)...)
    broadcast!(bcf.f, _quad(dest),  map(_quad,  bcf.args)...)
    return dest
end

a = AugmentedState(randn(10^3), randn(10^2))
b = AugmentedState(randn(10^3), randn(10^2))
c = AugmentedState(randn(10^3), randn(10^2))
d = AugmentedState(randn(10^3), randn(10^2))

foo(a::A, b::A, c::A) where {A<:AugmentedState} =
    (c .= a .+ b ; c)
foo_b(a::A, b::A, c::A) where {A<:AugmentedState} =
    (c.x .= a.x .+ b.x; 
     c.q .= a.q .+ b.q; c)

# allocations -> see #27988
bar(a::A, b::A, c::A) where {A<:AugmentedState} =
    (c .= a .+ b .+ (1 .+ a); c)
bar_b(a::A, b::A, c::A) where {A<:AugmentedState} =
    (c.x .= a.x .+ b.x .+ (1 .+ a.x); 
     c.q .= a.q .+ b.q .+ (1 .+ a.q); c)

# checks
foo(a, b, d)
foo_b(a, b, c)
@assert d.x == c.x
@assert d.q == c.q
bar(a, b, d)
bar_b(a, b, c)
@assert d.x == c.x
@assert d.q == c.q

# no allocations
@btime foo($a, $b, $c)
@btime foo_b($a, $b, $c)

# allocations -> see #27988
@btime bar($a, $b, $c)
@btime bar_b($a, $b, $c)

which gives:

  150.910 ns (0 allocations: 0 bytes)
  151.039 ns (0 allocations: 0 bytes)
  310.037 ns (11 allocations: 256 bytes)
  266.752 ns (0 allocations: 0 bytes)

Note how the nested broadcasted expression results in allocations. @mbauman, does your comment imply that that issue can be solved?

Yes, exactly. Complicated expressions currently make Broadcast.flatten type-unstable. I sure thought that I'd be able to fix the instability, but there seem to be other things going on there. We'll get it at some point.

That said, the (ab)use of broadcasting like this still makes me queasy. It seems like a macro would work better here and have a similar syntactic "weight".


The issue is that broadcasting expressions are used in other generic code that has to work regardless of whether its arguments are arrays or AugmentedState objects. A macro would not be able to distinguish between the two types, whereas intercepting the broadcasting facilities can.

My use case is numerical integration of differential equations, where the broadcasting expressions are basically update steps in, e.g., a Runge-Kutta method. In many cases, one is interested in the solution of a differential equation, say \mathbf{x}(t). In some other cases, one is interested in some integral quantity like \int_0^T g(\mathbf{x}(t))dt, for some function g. A common approach to estimate such integrals with the same accuracy as the main numerical scheme is to augment the system with a quadrature equation, an additional differential equation \dot{z}(t) = g(\mathbf{x}(t)) that is integrated alongside the main equations from the initial condition z(0) = 0. Integrating this augmented system up to time t=T gives z(T) = \int_0^T g(\mathbf{x}(t))dt, the desired quantity.

Being able to broadcast the update steps on both the state and "quadrature" parts is essential to have generic code, i.e. numerical methods for differential equations that work transparently whether we have a quadrature equation or not. For reference, the code is here.
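
For concreteness, the kind of update step meant here looks roughly like the following (a minimal sketch with illustrative names, not the package's actual API); the same broadcast line serves a plain Array and an AugmentedState carrying the quadrature variable:

# One explicit Euler step, written once for any state that supports .=
# (Array, AugmentedState, ...). f! writes d(state)/dt into dxdt in place.
function euler_step!(xnew, x, f!, dxdt, t, dt)
    f!(dxdt, x, t)
    xnew .= x .+ dt .* dxdt    # forwarded to both components for AugmentedState
    return xnew
end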

I would be interested in @ChrisRackauckas's opinion on this. As far as I know there is no such facility in the DifferentialEquations.jl ecosystem (or is there?). I know SUNDIALS uses a similar approach.

In any case, I agree with you that this might be pushing the dot notation too far. However, I find it a very elegant approach, leading to fast, generic and short code.

I'd be 100% behind using broadcasting here if we could make AugmentedStates actually be array-like, that is, supporting indexing and axes.


I have explored that design space previously, and found it much less elegant and generic than (ab)using broadcasting. The reason is that it forces me to consider the details of the fields x::X and q::Q (shape, length, indexing, bounds checking, ...) and treat them as AbstractVectors, while in this case I just delegate all such operations to the types X and Q separately.

Interesting, since all of the native Julia codes in DifferentialEquations.jl are built around abstract typing. I would say Sundials is highly limited in comparison to what we're doing. This is discussed in the blog post.

Like the Sundials NVector approach, the native Julia solvers in OrdinaryDiffEq.jl/StochasticDiffEq.jl/DelayDiffEq.jl allow for state types which broadcast (there was a caveat in v0.6 about the length of statements, but this is fixed in v0.7). As a result, these libraries are built around broadcastable types which don't necessarily have (good) indexing. GPUArrays are a type which broadcasts but doesn't index (well, there's a slow fallback, but you can even turn that off).

Normally the approach is to make the states array-like as @mbauman suggests. This is so that buffer arrays can be used. There are many practical examples of this. When using a solver for stiff equations, you will need to solve linear equations. If you can vec_tmp .= u the state u into a normal vector, then this can easily be used with BLAS via *. So it's much easier to do these kinds of "to array and back" workflows if the broadcast can work with arrays in there. Indeed, there is a lot of work being done by @dlfivefifty and @sacha to make multi-strided Julia array types able to use BLAS effectively (and in the latter case, a native Julia BLAS). But I still find that for things like LAPACK routines (\ after a factorization) you still need an array, so this finds many uses.
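
Schematically, that workflow looks something like this (just a sketch; u, W and vec_tmp are placeholders, not DiffEq internals):

using LinearAlgebra

u       = rand(4)          # stands in for any state that supports .= with a Vector
W       = lu(rand(4, 4))   # factorized operator, e.g. from a stiff solver's Jacobian
vec_tmp = zeros(4)         # reusable plain buffer

vec_tmp .= u               # state -> buffer
ldiv!(W, vec_tmp)          # in-place linear solve on the plain vector (LAPACK)
u .= vec_tmp               # buffer -> state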

RecursiveArrayTools.jl has an example of an array type which has been updated. The ArrayPartition is a "vector" made via the concatenation of the linear indexing of arrays in a tuple. Being housed in a tuple, these arrays can be heterogeneous. But the broadcast is defined recursively, so an update statement like the ones in our integrators (for example the RK4 implementation) will go through and broadcast only u.x[i] with other u.x[i], so even if the different arrays have different types in them this broadcast is still fast.

One use case for this is units. The linear index of the ArrayPartition is slow if you have arrays with different units because indexing is not type stable, but this broadcasting operation is type stable. Interestingly, because of how the cache is used, we've noticed that even ArrayPartitions with only Float64s can be faster on broadcasted operations than standard Float64 arrays... so that's a great showcase of broadcast used in this manner.
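
A rough usage sketch (assuming RecursiveArrayTools is installed; names are illustrative):

using RecursiveArrayTools

u = ArrayPartition(zeros(3), zeros(2))   # heterogeneous pieces stored in a tuple
k = ArrayPartition(ones(3),  ones(2))
u .= u .+ 0.5 .* k                       # broadcast recurses: u.x[i] with k.x[i]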

I'd like to see where your AugmentedState goes though. We have something similar in DiffEqJump.jl which is used to implement variable-rate Poisson jumps onto continuous problems. To do so, you solve a rootfinding problem where the jump occurs when a value hits zero, so we take the user's ODE state and pack it into an augmented vector with pieces on the end for calculating the extra jump coordinates and then just throw the whole thing into the ODE solver. It's defined by an integral so this is exactly like your AugmentedState case.

This still needs to be updated for v0.7, but you can see how this just chains indices in order to broadcast it like a big concatenated vector.

I always put a linear index on it too as a fallback, but just like the ArrayPartition it should be avoided. (Though one question we haven't answered well with these is implementation hiding. Because we save the state during the solution, we end up saving the full jump variables as well, and so the user doesn't get back the same array type they put in. For your case that would be fine though, since it sounds like g(x(t)) is something the user actually wants.)

One extreme case of this is MultiScaleArrays.jl

These are augmented arrays to the extreme, where you put a tree structure on them. But via tree traversal you can still define a linear index. That linear index is implemented with recursive binary searches. The broadcast is a much more sensible recursive broadcast.

In fact, this structure builds the most complicated types we've successfully solved differential equations on, including an almost thousand-line type definition (from this thread). Again, using tuples for units plus this broadcast overload allows it to perform just fine even though it's probably "the most type-unstable linear index imaginable". But the linear shape still makes it possible to .= over to a vector for stiff solvers, makes it easily interact with tools for printing vectors, etc.

So I think the summary of what I am saying is: @mbauman is right in that if you can just throw a linear index through the whole thing then you should, because it is a useful structure. Broadcast doesn't need to make use of it, and many times it's faster to avoid it, and maybe the ordering can be almost nonsensical, but it's an easy fix to a lot of things like Julia's basic printing capabilities. In some sense, if you treat these things as all "vectors", then broadcast between them all just means "element-wise", and using the broadcast mechanism you can define the element-wise implementation differently depending on what you need. I wouldn't count this as abuse; I think this is a super helpful use of the broadcast system!


Thanks. The examples you bring here are actually excellent. As far as I can see from them, you often define the AbstractVector interface for these types, but then rely on broadcasting for performance reasons. I could also define this interface for my "AugmentedState" type, but in practice I do not need it because it's part of the internals of the package, and I do not need to expose this interface to the outside world.


BTW, since you rely on broadcasting a lot, have you not hit #27988 so far?

I haven't tested enough on v0.7 yet. We're still a few days away from having a full stack ready to benchmark on v0.7.


Just to be clear: the only part that I find questionable is making these objects kinda/sorta broadcastable. One of the major impetuses for the broadcast revamp was that you can now just ask something for its broadcastable representation and then know precisely how it'll behave in broadcast by asking for its axes and such. Making something broadcastable without actually supporting axes/indexing puts it in a strange limbo and may cause trouble if it ends up getting passed to something that wants to work alongside broadcast. Plus I'd just find it confusing if I ran across it in the wild. :)
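
To make that concrete, here is a small sketch (assuming the AugmentedState definitions from earlier in the thread):

v  = [1.0, 2.0, 3.0]
bv = Base.Broadcast.broadcastable(v)
axes(bv)                                 # (Base.OneTo(3),): the shape is known up front

s  = AugmentedState([1.0], [2.0])
bs = Base.Broadcast.broadcastable(s)     # returns s itself with the override above
# axes(bs)                               # errors: AugmentedState defines no axes/size,
                                         # which is exactly the limbo described here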


I just want to underline this. Please don't do this. Broadcasting is all about having axes and indices. If you just want to "abuse" the broadcast notation, don't. There's plenty of other notation available.


Hi Stefan. I would definitely be inclined towards using other notation, but I can't see at the moment what that notation would be. I have not played enough with macros on this front, though. Perhaps a macro that translates code like

@statequad c = a + b

to broadcasted expressions such as

_state(c) .= _state(a) .+ _state(b);
_quad(c) .= _quad(a) .+ _quad(b); 
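
A rough sketch of such a macro (assuming the argument is a single assignment whose leaves are variables or literals, and that _state/_quad from the earlier posts are in scope at the call site; @statequad and _lift are just hypothetical names):

# Wrap every variable in f(...) and dot every call, so that `c = a + b`
# becomes `_state(c) .= _state(a) .+ _state(b)` (and likewise for _quad).
function _lift(ex, f)
    ex isa Symbol && return :($f($ex))
    if ex isa Expr && ex.head == :call
        op, args = ex.args[1], ex.args[2:end]
        return Expr(:., op, Expr(:tuple, (_lift(a, f) for a in args)...))
    end
    return ex                                # literals pass through untouched
end

macro statequad(ex)
    Meta.isexpr(ex, :(=), 2) || error("@statequad expects `lhs = rhs`")
    lhs, rhs = ex.args
    build(f) = Expr(Symbol(".="), :($f($lhs)), _lift(rhs, f))
    return esc(quote
        $(build(:_state))
        $(build(:_quad))
        $lhs
    end)
end

With this, @statequad c = a + b + 2.0*a expands to the two fused .= statements above (one over the state parts, one over the quadrature parts) and evaluates to c.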

Anyway, from the examples that Chris has shown, this kind of "abuse" seems useful in different ways and it is somewhat common too (although not widespread, I agree). Other people might feel a similar need. I think it is still beneficial to discuss this aspect and maybe come up with proper Julian ways to handle these cases.

Is the main issue the ability to hook into the = syntax?


I was hoping someone with better macro skills would suggest possible ways forward. I am not good enough with macros to understand if this is even technically possible. Hooking into the broadcasting infrastructure was something that was within my reach and very compact to write.

Maybe Cassette overdubbing is another way to do this well. In theory you can take all element-wise loops and overdub them to be the version for the type. But broadcast is a very simple way to get it done, so I'm not sure the enlarged complexity of Cassette should actually be used here.

As you say, I would not want to depend on a separate large package when ten lines of code suffice. Hooking into the broadcasting infrastructure is so simple, compared to e.g. C++ expression templates. However, I understand that this is probably not the intended usage of broadcasting and there are issues that come with it. The point is how to move forward.