Broadcasting setindex!

Let’s define the following composite type of an array and a counter. The counter is supposed to keep track of the number of changes made to the array.

import Base: getindex, setindex!, size, length

struct ArrayWithCounter
    counter::Ref{Int64}
    data::Array{Int64, 2}
end
ArrayWithCounter(n) = ArrayWithCounter(0, fill(0, n,n))

size(A::ArrayWithCounter, i...) = size(A.data,i...)
length(A::ArrayWithCounter) = length(A.data)

getindex(A::ArrayWithCounter, i...) = A.data[i...]

function setindex!(A::ArrayWithCounter, v, i...)
    A.counter[] += 1
    setindex!(A.data, v, i...)
end

A = ArrayWithCounter(5)

While A[5,5] = 1 works as intended, the broadcasted assignment A[:,:] .= 1 does not.
It neither sets all elements of A.data to 1 nor advances the counter.

Both are expected. The elements of data are not set because A[:,:] returns a copy, not a view into the array. This could be remedied by defining
Base.maybeview(A::ArrayWithCounter, i...) = Base.maybeview(A.data, i...). The problem runs deeper, though, because as far as I can tell, A[:,:] .= 1 is equivalent to broadcast!(identity, A.data[:,:], 1), with no reference to the original ArrayWithCounter.
Edit: that statement wasn’t correct, as pointed out below.
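The copy-versus-view behavior described above can be checked on a small, hypothetical wrapper called `Box` (not part of the original code). This is a sketch assuming Julia 1.x, where a dotted assignment calls `Base.dotview`. That function defers to the internal `Base.maybeview`, which falls back to `getindex` for types that are not `AbstractArray`s:

```julia
# Hypothetical wrapper used only for illustration: like ArrayWithCounter, but without a counter.
struct Box
    data::Matrix{Float64}
end
Base.getindex(b::Box, i...) = b.data[i...]

b = Box(zeros(2, 2))
b[:, :] .= 1.0                 # dotview -> maybeview -> getindex: fills a temporary copy
@assert all(iszero, b.data)    # the wrapped matrix is untouched

# Remedy from the question: let maybeview return a view into the wrapped data.
Base.maybeview(b::Box, i...) = view(b.data, i...)
b[:, :] .= 1.0                 # now the write goes through the view
@assert all(==(1.0), b.data)
```

Even with the `maybeview` method in place, a counter like the one in ArrayWithCounter would not advance. The writes go straight into the view of the wrapped matrix and never reach the wrapper's `setindex!`.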

Question: How, if at all possible, can broadcasting be implemented so that it modifies data in place and advances counter correctly?

Note that this is a contrived example. I have a more complicated structure in mind, but the approach should be transferable.

Thanks for your input!

That’s not true. You can look at the lowered code for A[:,:] .= 1; for f(x) = (x[:,:] .= 1) it is more or less

julia> @code_lowered f(a)
CodeInfo(
1 ─ %1 = Base.dotview(x, Main.:(:), Main.:(:))
│   %2 = Base.broadcasted(Base.identity, 1)
│   %3 = Base.materialize!(%1, %2)
└──      return %3
)
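As a quick check, the lowered calls can be invoked by hand. The following sketch uses a plain `Matrix` (the variable names are illustrative) and spells out `x[:, 1] .= 1.0`. Because `dotview` returns a view for arrays, `materialize!` writes into `x` itself:

```julia
x = zeros(3, 3)
# Spelled-out form of `x[:, 1] .= 1.0`, mirroring the lowered code above
dest = Base.dotview(x, :, 1)            # a SubArray (view), not a copy
bc = Base.broadcasted(identity, 1.0)    # lazy Broadcasted object for the right-hand side
Base.materialize!(dest, bc)             # writes through the view into x
```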

You could make ArrayWithCounter a subtype of AbstractArray, which would make things easier. If you can’t, you most likely need to define your own ArrayWithCounterView and overload dotview, maybeview or even just view, so that viewing an ArrayWithCounter returns an ArrayWithCounterView and the counter isn’t lost.
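A minimal sketch of the `AbstractArray` route might look like this. `CountingMatrix` is a hypothetical name, and the sketch assumes the data can be stored in a plain `Matrix`. Because the type subtypes `AbstractMatrix`, `dotview` returns an ordinary `SubArray`, and Base's generic broadcasting loop writes element by element through the custom `setindex!`. The counter therefore advances without any custom broadcasting code:

```julia
# Hypothetical AbstractArray-based variant of ArrayWithCounter (the name is illustrative).
mutable struct CountingMatrix{T} <: AbstractMatrix{T}
    counter::Int        # number of writes that actually changed a value
    data::Matrix{T}
end
CountingMatrix(n::Integer) = CountingMatrix(0, zeros(n, n))

Base.size(A::CountingMatrix) = size(A.data)
Base.getindex(A::CountingMatrix, i::Int, j::Int) = A.data[i, j]
function Base.setindex!(A::CountingMatrix, v, i::Int, j::Int)
    if A.data[i, j] != v
        A.counter += 1
        A.data[i, j] = v
    end
    return A
end

A = CountingMatrix(3)
A[1:2, 1:2] .= 1.0   # dotview returns a SubArray of A; each write calls setindex! above
A.counter            # 4
A .= 1.0             # only the five entries that are still 0.0 change
A.counter            # 9
```

The trade-off is that every write goes through the scalar `setindex!`, so the fast bulk-copy paths Base uses for plain `Array`s are lost. Whether that matters depends on the real structure behind the example.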

Thanks for your reply. Having dotview wrap the object into an ArrayWithCounterView and defining an appropriate materialize! or copyto! seems the way to go. I’ll experiment a bit.

After some experimentation, I have come up with a solution that I am happy with and that I think is extensible.

The problem splits into four parts:

  1. Broadcasting over ArrayWithCounter
  2. View into ArrayWithCounter, i.e. ArrayWithCounterView
  3. Broadcasting over ArrayWithCounterView
  4. Interoperability of both types.

This is what I came up with

import Base: axes, broadcastable, copyto!, eltype, getindex, maybeview, setindex!, similar, size, length, show, view
import Base.Broadcast

######################
## ArrayWithCounter ##

struct ArrayWithCounter
    counter::Ref{Int64}
    data::Array{Float64, 2}
end
ArrayWithCounter(n) = ArrayWithCounter(0, fill(0.0, n,n))

eltype(A::ArrayWithCounter) = eltype(A.data)
size(A::ArrayWithCounter, i...) = size(A.data,i...)
length(A::ArrayWithCounter) = length(A.data)
axes(A::ArrayWithCounter) = axes(A.data)
similar(A::ArrayWithCounter) = ArrayWithCounter(A.counter[], similar(A.data))
getindex(A::ArrayWithCounter, i...) = A.data[i...]
setindex!(A::ArrayWithCounter, v, i::CartesianIndex{2}) = setindex!(A, v, Tuple(i)...)
function setindex!(A::ArrayWithCounter, v, i::Vararg{Int64})
    if A.data[i...] == v
        return v
    end
    A.counter[] += 1
    setindex!(A.data, v, i...)
end

broadcastable(x::ArrayWithCounter) = x

Base.BroadcastStyle(::Type{ArrayWithCounter}) = Broadcast.Style{ArrayWithCounter}()
# Custom broadcast style should take precedence over ArrayStyle to enable custom `copyto!` method below.
Base.BroadcastStyle(::Broadcast.Style{ArrayWithCounter}, ::Broadcast.AbstractArrayStyle) = Broadcast.Style{ArrayWithCounter}()

function similar(bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}, ::Type{ElType}) where ElType
    A = find_awc(bc)
    ArrayWithCounter(A.counter[], similar(A.data, ElType)) # keep the old counter
end

"""`A = find_awc(As)` returns the first ArrayWithCounter among the arguments of a Broadcasted object
Lifted from the Julia documentation
https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
"""
find_awc(bc::Base.Broadcast.Broadcasted) = find_awc(bc.args)
find_awc(args::Tuple) = find_awc(find_awc(args[1]), Base.tail(args))
find_awc(x) = x
find_awc(::Tuple{}) = nothing
find_awc(a::ArrayWithCounter, rest) = a
find_awc(::Any, rest) = find_awc(rest)

copyto!(dest::ArrayWithCounter, bc::Broadcast.Broadcasted{T}) where {T<:Broadcast.AbstractArrayStyle} = _copyto!(dest, bc)
copyto!(dest::ArrayWithCounter, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}) = _copyto!(dest, bc)
function _copyto!(dest::ArrayWithCounter, bc)
    old_data = copy(dest.data)
    copyto!(dest.data, unpack(bc))
    changed = count(i->old_data[i]!=dest.data[i], eachindex(dest.data))
    dest.counter[] += changed
    dest
end

function copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}})
    copyto!(dest, unpack(bc))
    dest
end

"""Recursively replace occurrences of `A::ArrayWithCounter` with `A.data` in any (nested)
`Broadcasted` object in order to defer `copyto!` to the existing methods for arrays.
Copied from https://discourse.julialang.org/t/only-specializing-broadcast-on-likewise-types/20117
"""
@inline unpack(bc::Broadcast.Broadcasted) = Broadcast.Broadcasted(bc.f, unpack_args(bc.args))
unpack(x) = x
unpack(x,::Any) = x
unpack(x::ArrayWithCounter) = x.data
@inline unpack_args(::Tuple{}) = ()
@inline unpack_args(args::Tuple) = (unpack(args[1]), unpack_args(Base.tail(args))...)
unpack_args(args::Tuple{Any}) = (unpack(args[1]),)
unpack_args(::Any, args::Tuple{}) = ()

##########################
## ArrayWithCounterView ##

struct ArrayWithCounterView
    parent::ArrayWithCounter
    view::SubArray
end

view(A::ArrayWithCounter, i...) = ArrayWithCounterView(A, view(A.data, i...))
maybeview(A::ArrayWithCounter, args...) = view(A, args...)

size(A::ArrayWithCounterView, i...) = size(A.view, i...)
length(A::ArrayWithCounterView) = length(A.view)
eltype(A::ArrayWithCounterView) = eltype(A.parent)

similar(A::ArrayWithCounterView) = similar(A.parent)
function similar(bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}}, T::Type{ElType}) where ElType
    A = find_awcv(bc)
    similar(A.parent)
end
find_awcv(bc::Base.Broadcast.Broadcasted) = find_awcv(bc.args)
find_awcv(args::Tuple) = find_awcv(find_awcv(args[1]), Base.tail(args))
find_awcv(x) = x
find_awcv(::Tuple{}) = nothing
find_awcv(::Any, rest) = find_awcv(rest)
find_awcv(a::ArrayWithCounterView, rest) = a


broadcastable(A::ArrayWithCounterView) = A

Base.BroadcastStyle(::Type{ArrayWithCounterView}) = Broadcast.Style{ArrayWithCounterView}()
Base.BroadcastStyle(::Broadcast.Style{ArrayWithCounterView}, ::Broadcast.AbstractArrayStyle) = Broadcast.Style{ArrayWithCounterView}()
Base.BroadcastStyle(::Broadcast.Style{ArrayWithCounterView}, ::Broadcast.Style{ArrayWithCounter}) = Broadcast.Style{ArrayWithCounter}()

function copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}})
    copyto!(dest, unpack(bc))
end
unpack(x::ArrayWithCounterView) = x.view

function copyto!(dest::ArrayWithCounter, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}})
    copyto!(dest, unpack(bc))
end

function copyto!(dest::ArrayWithCounterView, bc::Union{Broadcast.Broadcasted{T}, Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}},Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}}) where {T<:Broadcast.AbstractArrayStyle}
    _copyto!(dest, bc)
end
function _copyto!(dest, bc)
    old_data = copy(dest.view)
    copyto!(dest.view, bc)
    changed = count(i->old_data[i]!=dest.view[i], eachindex(dest.view))
    dest.parent.counter[] += changed
    dest
end

## pretty printing ##

function show(io::IO, ::MIME"text/plain", A::ArrayWithCounter)
    print(io, "$(typeof(A)) with $(A.counter[]) changes recorded.\n")
    show(io, MIME("text/plain"), A.data)
end

function show(io::IO, ::MIME"text/plain", A::ArrayWithCounterView)
    print(io, "view into an $(typeof(A.parent)) with $(A.parent.counter[]) changes recorded.\n")
    show(io, MIME("text/plain"), A.view)
end


###########
## Tests ##

using Test

A = ArrayWithCounter(5)
B = ArrayWithCounter(5)
B2 = ArrayWithCounter(2)
A .+= B
@test A.counter[] == 0
A[1:2, 1:2] .+= cos.(view(B2,:,:))
B3 = cos.(B2)
@test sum(A[1:2,1:2]) == sum(B3.data) && A.counter[] == count(x->x!=0, B3.data)
C = deepcopy(A)
rbc = Broadcast.Broadcasted(randn,())
A .+= rbc .+ C.*B
@test A.counter[] == length(A)+length(B2)
@test typeof(view(B,:,:) .+ 0) == ArrayWithCounter

It has become rather extensive, but you can mix and match basically anything in broadcasts.

If there are suggestions for improvement, I’d be happy to hear them. In particular, I find the whole unpack and find_x story a little cumbersome, mostly because it is boilerplate that will appear again and again in any custom broadcasting.
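One way to trim the `unpack` boilerplate, offered as a sketch rather than a tested drop-in replacement, relies on the fact that `map` over a `Tuple` returns a `Tuple`. The recursive `unpack_args` helpers can probably be replaced by a single `map(unpack, bc.args)`. The wrapper `Wrapped` below is a hypothetical stand-in for ArrayWithCounter, with the counter left out:

```julia
import Base.Broadcast: AbstractArrayStyle, Broadcasted, BroadcastStyle, Style,
                       broadcastable, broadcasted, materialize

# Hypothetical wrapper standing in for ArrayWithCounter (counter omitted for brevity).
struct Wrapped
    data::Matrix{Float64}
end
broadcastable(w::Wrapped) = w
BroadcastStyle(::Type{Wrapped}) = Style{Wrapped}()
BroadcastStyle(s::Style{Wrapped}, ::AbstractArrayStyle) = s

# map over the argument tuple replaces the recursive unpack_args helpers:
# nested Broadcasted arguments are handled because unpack calls itself through map.
unpack(bc::Broadcasted) = Broadcasted(bc.f, map(unpack, bc.args))
unpack(w::Wrapped) = w.data
unpack(x) = x

w = Wrapped([1.0 2.0; 3.0 4.0])
bc = broadcasted(+, broadcasted(*, 2, w), 1)   # lazy form of 2 .* w .+ 1
plain = unpack(bc)                             # no Wrapped left, so a default array style
result = materialize(plain)                    # [3.0 5.0; 7.0 9.0]
```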

That looks quite good. I know how to get rid of find_x: you just need to implement similar for Broadcasted like this:

function Base.similar(::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}, ::Type{ElType}, dims) where {ElType}
    return ArrayWithCounter(0, similar(Array{ElType}, dims))
end
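Below is a self-contained sketch of this idea, using a hypothetical `Tagged` type in place of ArrayWithCounter. It assumes a Julia version in which Base forwards the two-argument `similar(bc, T)` to `similar(bc, T, axes(bc))`, which is what causes the three-argument method above to be called. The output is allocated from the element type and axes alone, so no `find_x` search through the arguments is needed:

```julia
import Base.Broadcast: AbstractArrayStyle, Broadcasted, BroadcastStyle, Style,
                       broadcastable, instantiate

# Hypothetical stand-in for ArrayWithCounter.
struct Tagged
    counter::Base.RefValue{Int}
    data::Matrix{Float64}
end
broadcastable(t::Tagged) = t
Base.axes(t::Tagged) = axes(t.data)
Base.eltype(::Type{Tagged}) = Float64
BroadcastStyle(::Type{Tagged}) = Style{Tagged}()
BroadcastStyle(s::Style{Tagged}, ::AbstractArrayStyle) = s

# Allocate the result from element type and axes only; no search through the arguments.
Base.similar(::Broadcasted{Style{Tagged}}, ::Type{T}, dims) where {T} =
    Tagged(Ref(0), similar(Array{T}, dims))

unpack(bc::Broadcasted) = Broadcasted(bc.f, map(unpack, bc.args))
unpack(t::Tagged) = t.data
unpack(x) = x

function Base.copyto!(dest::Tagged, bc::Broadcasted)
    copyto!(dest.data, instantiate(unpack(bc)))   # defer to Base's methods for plain arrays
    return dest
end

t = Tagged(Ref(5), [1.0 2.0; 3.0 4.0])
u = t .+ 1     # a new Tagged whose counter starts at 0, not at t's 5
```

Note that the result starts with a fresh counter instead of inheriting the one from `t`.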

Full code:

import Base: axes, broadcastable, copyto!, eltype, getindex, maybeview, setindex!, similar, size, length, show, view
import Base.Broadcast

######################
## ArrayWithCounter ##

struct ArrayWithCounter
    counter::Ref{Int64}
    data::Array{Float64, 2}
end
ArrayWithCounter(n) = ArrayWithCounter(0, fill(0.0, n,n))

eltype(A::ArrayWithCounter) = eltype(A.data)
size(A::ArrayWithCounter, i...) = size(A.data,i...)
length(A::ArrayWithCounter) = length(A.data)
axes(A::ArrayWithCounter) = axes(A.data)
similar(A::ArrayWithCounter) = ArrayWithCounter(A.counter[], similar(A.data))
getindex(A::ArrayWithCounter, i...) = A.data[i...]
setindex!(A::ArrayWithCounter, v, i::CartesianIndex{2}) = setindex!(A, v, Tuple(i)...)
function setindex!(A::ArrayWithCounter, v, i::Vararg{Int64})
    if A.data[i...] == v
        return v
    end
    A.counter[] += 1
    setindex!(A.data, v, i...)
end

broadcastable(x::ArrayWithCounter) = x

Base.BroadcastStyle(::Type{ArrayWithCounter}) = Broadcast.Style{ArrayWithCounter}()
# Custom broadcast style should take precedence over ArrayStyle to enable custom `copyto!` method below.
Base.BroadcastStyle(::Broadcast.Style{ArrayWithCounter}, ::Broadcast.AbstractArrayStyle) = Broadcast.Style{ArrayWithCounter}()

function Base.similar(::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}, ::Type{ElType}, dims) where {ElType}
    return ArrayWithCounter(0, similar(Array{ElType}, dims))
end

copyto!(dest::ArrayWithCounter, bc::Broadcast.Broadcasted{T}) where {T<:Broadcast.AbstractArrayStyle} = _copyto!(dest, bc)
copyto!(dest::ArrayWithCounter, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}) = _copyto!(dest, bc)
function _copyto!(dest::ArrayWithCounter, bc)
    old_data = copy(dest.data)
    copyto!(dest.data, unpack(bc))
    changed = count(i->old_data[i]!=dest.data[i], eachindex(dest.data))
    dest.counter[] += changed
    dest
end

function copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}})
    copyto!(dest, unpack(bc))
    dest
end

"""Recursively replace occurrences of `A::ArrayWithCounter` with `A.data` in any (nested)
`Broadcasted` object in order to defer `copyto!` to the existing methods for arrays.
Copied from https://discourse.julialang.org/t/only-specializing-broadcast-on-likewise-types/20117
"""
@inline unpack(bc::Broadcast.Broadcasted) = Broadcast.Broadcasted(bc.f, unpack_args(bc.args))
unpack(x) = x
unpack(x,::Any) = x
unpack(x::ArrayWithCounter) = x.data
@inline unpack_args(::Tuple{}) = ()
@inline unpack_args(args::Tuple) = (unpack(args[1]), unpack_args(Base.tail(args))...)
unpack_args(args::Tuple{Any}) = (unpack(args[1]),)
unpack_args(::Any, args::Tuple{}) = ()

##########################
## ArrayWithCounterView ##

struct ArrayWithCounterView
    parent::ArrayWithCounter
    view::SubArray
end

view(A::ArrayWithCounter, i...) = ArrayWithCounterView(A, view(A.data, i...))
maybeview(A::ArrayWithCounter, args...) = view(A, args...)

size(A::ArrayWithCounterView, i...) = size(A.view, i...)
length(A::ArrayWithCounterView) = length(A.view)
eltype(A::ArrayWithCounterView) = eltype(A.parent)

function Base.similar(::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}}, ::Type{ElType}, dims) where {ElType}
    return ArrayWithCounter(0, similar(Array{ElType}, dims))
end

broadcastable(A::ArrayWithCounterView) = A

Base.BroadcastStyle(::Type{ArrayWithCounterView}) = Broadcast.Style{ArrayWithCounterView}()
Base.BroadcastStyle(::Broadcast.Style{ArrayWithCounterView}, ::Broadcast.AbstractArrayStyle) = Broadcast.Style{ArrayWithCounterView}()
Base.BroadcastStyle(::Broadcast.Style{ArrayWithCounterView}, ::Broadcast.Style{ArrayWithCounter}) = Broadcast.Style{ArrayWithCounter}()

function copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}})
    copyto!(dest, unpack(bc))
end
unpack(x::ArrayWithCounterView) = x.view

function copyto!(dest::ArrayWithCounter, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}})
    copyto!(dest, unpack(bc))
end

function copyto!(dest::ArrayWithCounterView, bc::Union{Broadcast.Broadcasted{T}, Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}},Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}}) where {T<:Broadcast.AbstractArrayStyle}
    _copyto!(dest, bc)
end
function _copyto!(dest, bc)
    old_data = copy(dest.view)
    copyto!(dest.view, bc)
    changed = count(i->old_data[i]!=dest.view[i], eachindex(dest.view))
    dest.parent.counter[] += changed
    dest
end

## pretty printing ##

function show(io::IO, ::MIME"text/plain", A::ArrayWithCounter)
    print(io, "$(typeof(A)) with $(A.counter[]) changes recorded.\n")
    show(io, MIME("text/plain"), A.data)
end

function show(io::IO, ::MIME"text/plain", A::ArrayWithCounterView)
    print(io, "view into an $(typeof(A.parent)) with $(A.parent.counter[]) changes recorded.\n")
    show(io, MIME("text/plain"), A.view)
end


###########
## Tests ##

using Test

A = ArrayWithCounter(5)
B = ArrayWithCounter(5)
B2 = ArrayWithCounter(2)
A .+= B
@test A.counter[] == 0
A[1:2, 1:2] .+= cos.(view(B2,:,:))
B3 = cos.(B2)
@test sum(A[1:2,1:2]) == sum(B3.data) && A.counter[] == count(x->x!=0, B3.data)
C = deepcopy(A)
rbc = Broadcast.Broadcasted(randn,())
A .+= rbc .+ C.*B
@test A.counter[] == length(A)+length(B2)
@test typeof(view(B,:,:) .+ 0) == ArrayWithCounter

I’m sure there is a nicer alternative to unpack. Overloading copyto! as in the JuliaManifolds/Manifolds.jl pull request #336 (“Copy and broadcasting for ProductRepr”) would probably work, but it is less straightforward and requires additional changes to indexing into ArrayWithCounter and ArrayWithCounterView.

True, your version of similar avoids find_x, but it loses the counter in an expression like B = A .+ 1 (there is no broadcasting over =). One could argue that this is the saner behavior.

I think that unpack is much simpler in this case than what happens in the PR you linked. I can certainly see why that is necessary there, but in my simple case one can fall back to the existing copyto! methods for arrays.

Thanks very much for your input! I learned a lot about Julia’s broadcasting internals.
