After some experimentation, I have come up with a solution that satisfies me and which I deem extendable.
The problem splits into four parts:
- Broadcasting over `ArrayWithCounter`
- View into `ArrayWithCounter`, i.e. `ArrayWithCounterView`
- Broadcasting over `ArrayWithCounterView`
- Interoperability of both types.
This is what I came up with:
import Base: axes, broadcastable, copyto!, eltype, getindex, maybeview, setindex!, similar, size, length, show, view
import Base.Broadcast
######################
## ArrayWithCounter ##
"""
    ArrayWithCounter

A `Float64` matrix bundled with a counter of recorded changes.

# Fields
- `counter`: reference cell holding the number of recorded changes. It is a
  reference so that it can be incremented although the struct is immutable.
- `data`: the wrapped matrix.

# Constructors
- `ArrayWithCounter(counter::Base.RefValue{Int64}, data)` stores `counter` as is.
- `ArrayWithCounter(counter::Integer, data)` wraps the initial count in a fresh
  reference.

In both cases `data` is converted to `Array{Float64, 2}`.
"""
struct ArrayWithCounter
    # `Base.RefValue{Int64}` is the concrete type produced by `Ref{Int64}(x)`;
    # the abstract `Ref{Int64}` would leave the field abstractly typed.
    counter::Base.RefValue{Int64}
    data::Array{Float64, 2}

    # `new` converts `data` to the declared field type.
    ArrayWithCounter(counter::Base.RefValue{Int64}, data) = new(counter, data)
    ArrayWithCounter(counter::Integer, data) = new(Ref{Int64}(counter), data)
end
# Convenience constructor: an `n`-by-`n` matrix of zeros whose counter starts at 0.
function ArrayWithCounter(n)
    return ArrayWithCounter(0, zeros(Float64, n, n))
end
# Array-like queries: each one is forwarded to the wrapped matrix `A.data`.

# Element type of the wrapped matrix.
eltype(A::ArrayWithCounter) = eltype(A.data)

# `size(A)` or `size(A, d)`; optional dimension arguments are passed through.
size(A::ArrayWithCounter, i...) = size(A.data, i...)

# `length` takes the container only, so no extra arguments are accepted or
# forwarded (same shape as the `ArrayWithCounterView` method).
length(A::ArrayWithCounter) = length(A.data)

# Axes of the wrapped matrix.
axes(A::ArrayWithCounter) = axes(A.data)

# New container with uninitialised data of the same shape; it starts from the
# current counter value of `A`.
similar(A::ArrayWithCounter) = ArrayWithCounter(A.counter[], similar(A.data))

# Read access goes straight to the data and does not touch the counter.
getindex(A::ArrayWithCounter, i...) = A.data[i...]
# A 2-d Cartesian index is expanded into its integer components and handled by
# the integer-index method below.
setindex!(A::ArrayWithCounter, v, i::CartesianIndex{2}) = setindex!(A, v, Tuple(i)...)

# Store `v` at the integer index `i...` of `A.data`.
#
# The write is skipped when the stored value already equals `v` (`==`);
# otherwise the counter is incremented by one and the value is written.
# Returns `A` in both cases.
function setindex!(A::ArrayWithCounter, v, i::Vararg{Integer})
    # The index tuple must be splatted: `A.data[i]` would index with the tuple
    # itself, which is not a valid array index.
    if A.data[i...] == v
        return A
    end
    A.counter[] += 1
    setindex!(A.data, v, i...)
    return A
end
# The wrapper itself is used as the broadcast argument.
function broadcastable(x::ArrayWithCounter)
    return x
end

# Broadcasts involving an `ArrayWithCounter` get their own style.
function Base.BroadcastStyle(::Type{ArrayWithCounter})
    return Broadcast.Style{ArrayWithCounter}()
end

# When combined with any array style the custom style wins, so that the custom
# `copyto!` methods defined for this style are the ones selected.
function Base.BroadcastStyle(
    ::Broadcast.Style{ArrayWithCounter}, ::Broadcast.AbstractArrayStyle
)
    return Broadcast.Style{ArrayWithCounter}()
end
# Allocate the destination for a broadcast whose style is
# `Style{ArrayWithCounter}`.
#
# The result starts from the counter value of the first `ArrayWithCounter`
# found among the broadcast arguments (see `find_awc`).
function similar(
    bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}, ::Type{ElType}
) where {ElType}
    A = find_awc(bc)
    # Size the buffer from the axes of the broadcast rather than from `A.data`:
    # the two differ when broadcasting expands a singleton dimension of `A`.
    buffer = similar(A.data, ElType, axes(bc))
    return ArrayWithCounter(A.counter[], buffer)  # keep the old counter
end
"""`A = find_awc(As)` returns the first ArrayWithCounter among the arguments of a Broadcasted object
Lifted from the Julia documentation
https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
"""
function find_awc(bc::Base.Broadcast.Broadcasted)
    # Search the argument tuple of the (possibly nested) broadcast.
    return find_awc(bc.args)
end

# An exhausted argument tuple means nothing was found.
find_awc(::Tuple{}) = nothing

# Inspect the head (recursing into nested broadcasts), then decide via the
# two-argument methods whether to stop or to continue with the tail.
function find_awc(args::Tuple)
    head = find_awc(first(args))
    return find_awc(head, Base.tail(args))
end

# Any other single argument is handed back unchanged.
find_awc(x) = x

# Two-argument form: stop at an `ArrayWithCounter`, otherwise search `rest`.
find_awc(a::ArrayWithCounter, rest) = a
find_awc(::Any, rest) = find_awc(rest)
# In-place broadcast into an `ArrayWithCounter`: both the plain array styles and
# the custom style are routed to `_copyto!`.
function copyto!(
    dest::ArrayWithCounter, bc::Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle}
)
    return _copyto!(dest, bc)
end

function copyto!(
    dest::ArrayWithCounter, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}
)
    return _copyto!(dest, bc)
end

# Evaluate `bc` into `dest.data` and add the number of entries whose value
# differs (`!=`) from before to `dest.counter`. Returns `dest`.
function _copyto!(dest::ArrayWithCounter, bc)
    snapshot = copy(dest.data)
    # `unpack` swaps the wrappers inside `bc` for their plain data.
    copyto!(dest.data, unpack(bc))
    dest.counter[] += count(snapshot .!= dest.data)
    return dest
end
# Broadcast with the custom style into an ordinary array: unpack the wrappers
# and copy the resulting broadcast into `dest`. No counter is involved.
function copyto!(
    dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}
)
    plain = unpack(bc)
    copyto!(dest, plain)
    return dest
end
"""Recursively replace occurences of `A::ArrayWithCounter` with `A::data` in any (nested)
`Broadcasted` object in order to defer `copyto!` to the existing methods for arrays.
Copied from https://discourse.julialang.org/t/only-specializing-broadcast-on-likewise-types/20117
"""
@inline function unpack(bc::Broadcast.Broadcasted)
    # Rebuild the broadcast with unpacked arguments. The axes already attached
    # to `bc` are carried over; rebuilding without them would discard them.
    return Broadcast.Broadcasted(bc.f, unpack_args(bc.args), bc.axes)
end

# Anything that is not a wrapper or a `Broadcasted` passes through unchanged.
unpack(x) = x
unpack(x, ::Any) = x

# The wrapper is replaced by its plain matrix.
unpack(x::ArrayWithCounter) = x.data

# Unpack every element of an argument tuple, head first.
@inline unpack_args(::Tuple{}) = ()
@inline unpack_args(args::Tuple) = (unpack(args[1]), unpack_args(Base.tail(args))...)
unpack_args(args::Tuple{Any}) = (unpack(args[1]),)
unpack_args(::Any, args::Tuple{}) = ()
##########################
## ArrayWithCounterView ##
"""
    ArrayWithCounterView

Pairs an `ArrayWithCounter` with a `SubArray`.

# Fields
- `parent`: the `ArrayWithCounter` whose counter records changes made through
  the view.
- `view`: the `SubArray` that is read from and written to.
"""
struct ArrayWithCounterView
    parent::ArrayWithCounter
    view::SubArray
end
# A view into an `ArrayWithCounter` wraps a view of the data and keeps a handle
# on `A` itself, so changes can be recorded on `A`'s counter.
function view(A::ArrayWithCounter, i...)
    sub = view(A.data, i...)
    return ArrayWithCounterView(A, sub)
end

# `maybeview` always produces the counting view.
# NOTE(review): presumably this is what makes `A[1:2, 1:2] .+= ...` write
# through an `ArrayWithCounterView` — confirm against Base.
function maybeview(A::ArrayWithCounter, args...)
    return view(A, args...)
end
# Shape queries are answered by the wrapped `SubArray`.
function size(A::ArrayWithCounterView, i...)
    return size(A.view, i...)
end

function length(A::ArrayWithCounterView)
    return length(A.view)
end

# The element type is that of the parent container.
function eltype(A::ArrayWithCounterView)
    return eltype(A.parent)
end

# Delegates to the parent, so the result is an `ArrayWithCounter` shaped like
# the parent (not like the view).
function similar(A::ArrayWithCounterView)
    owner = A.parent
    return similar(owner)
end
# Allocate the destination for a broadcast whose style is
# `Style{ArrayWithCounterView}`.
#
# The result is an `ArrayWithCounter` that starts from the counter value of the
# parent of the first view found among the broadcast arguments (see
# `find_awcv`).
function similar(
    bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}}, ::Type{ElType}
) where {ElType}
    A = find_awcv(bc)
    # Size the buffer from the axes of the broadcast. A parent-shaped buffer
    # only fits when the view spans the whole parent.
    buffer = similar(A.view, ElType, axes(bc))
    return ArrayWithCounter(A.parent.counter[], buffer)
end
# Return the first `ArrayWithCounterView` among the arguments of a (possibly
# nested) `Broadcasted` object, or `nothing` if the arguments are exhausted.
function find_awcv(bc::Base.Broadcast.Broadcasted)
    return find_awcv(bc.args)
end

# An exhausted argument tuple means nothing was found.
find_awcv(::Tuple{}) = nothing

# Inspect the head (recursing into nested broadcasts), then decide via the
# two-argument methods whether to stop or to continue with the tail.
function find_awcv(args::Tuple)
    head = find_awcv(first(args))
    return find_awcv(head, Base.tail(args))
end

# Any other single argument is handed back unchanged.
find_awcv(x) = x

# Two-argument form: stop at a view, otherwise search `rest`.
find_awcv(a::ArrayWithCounterView, rest) = a
find_awcv(::Any, rest) = find_awcv(rest)
# The view itself is used as the broadcast argument.
function broadcastable(A::ArrayWithCounterView)
    return A
end

# Broadcasts involving a view get their own style.
function Base.BroadcastStyle(::Type{ArrayWithCounterView})
    return Broadcast.Style{ArrayWithCounterView}()
end

# Combined with any array style, the view style wins.
function Base.BroadcastStyle(
    ::Broadcast.Style{ArrayWithCounterView}, ::Broadcast.AbstractArrayStyle
)
    return Broadcast.Style{ArrayWithCounterView}()
end

# Combined with the `ArrayWithCounter` style, the `ArrayWithCounter` style wins.
function Base.BroadcastStyle(
    ::Broadcast.Style{ArrayWithCounterView}, ::Broadcast.Style{ArrayWithCounter}
)
    return Broadcast.Style{ArrayWithCounter}()
end
# Broadcast with the view style into an ordinary array: unpack the wrappers and
# hand the resulting broadcast back to `copyto!`.
function copyto!(
    dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}}
)
    plain = unpack(bc)
    return copyto!(dest, plain)
end

# A view is unpacked to its wrapped `SubArray`.
unpack(x::ArrayWithCounterView) = x.view

# Broadcast with the view style into an `ArrayWithCounter`: unpack first, then
# dispatch again on the unpacked broadcast.
function copyto!(
    dest::ArrayWithCounter, bc::Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}}
)
    plain = unpack(bc)
    return copyto!(dest, plain)
end
# In-place broadcast into a view: plain array styles and both custom styles are
# routed to `_copyto!`.
function copyto!(
    dest::ArrayWithCounterView,
    bc::Union{
        Broadcast.Broadcasted{T},
        Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounterView}},
        Broadcast.Broadcasted{Broadcast.Style{ArrayWithCounter}}
    }
) where {T<:Broadcast.AbstractArrayStyle}
    return _copyto!(dest, bc)
end

# Evaluate `bc` into `dest.view` and add the number of entries whose value
# differs (`!=`) from before to the counter of `dest.parent`. Returns `dest`.
#
# `dest` is restricted to `ArrayWithCounterView`: the body reads `dest.view` and
# `dest.parent`, and the `ArrayWithCounter` sibling method is typed likewise.
function _copyto!(dest::ArrayWithCounterView, bc)
    old_data = copy(dest.view)
    # `dest.view` is a `SubArray`, so this call dispatches on the style of `bc`.
    copyto!(dest.view, bc)
    changed = count(i -> old_data[i] != dest.view[i], eachindex(dest.view))
    dest.parent.counter[] += changed
    return dest
end
## pretty printing ##
# Multi-line display: a header with the type and the recorded number of
# changes, followed by the standard display of the wrapped matrix.
function show(io::IO, ::MIME"text/plain", A::ArrayWithCounter)
    header = "$(typeof(A)) with $(A.counter[]) changes recorded.\n"
    print(io, header)
    return show(io, MIME"text/plain"(), A.data)
end

# Same layout for a view: the header reports the parent's type and counter, the
# body shows the wrapped `SubArray`.
function show(io::IO, ::MIME"text/plain", A::ArrayWithCounterView)
    header = "view into an $(typeof(A.parent)) with $(A.parent.counter[]) changes recorded.\n"
    print(io, header)
    return show(io, MIME"text/plain"(), A.view)
end
###########
## Tests ##
# Smoke tests for the broadcasting machinery. The statements depend on each
# other: the counter of `A` accumulates across the steps below.
using Test
# Three zero-filled containers, each with its counter at 0.
A = ArrayWithCounter(5)
B = ArrayWithCounter(5)
B2 = ArrayWithCounter(2)
# Adding zeros to zeros changes no entry, so nothing is counted.
A .+= B
@test A.counter[] == 0
# In-place broadcast into a 2x2 block of `A`, with a view on the right-hand
# side. `B2` holds zeros and cos(0) == 1, so all four entries of the block change.
A[1:2, 1:2] .+= cos.(view(B2,:,:))
# Out-of-place broadcast over an `ArrayWithCounter`.
B3 = cos.(B2)
# The block of `A` now matches `B3`, and one change was counted per non-zero entry.
@test sum(A[1:2,1:2]) == sum(B3.data) && A.counter[] == count(x->x!=0, B3.data)
C = deepcopy(A)
# A zero-argument broadcast of `randn`, used as a lazy operand below.
rbc = Broadcast.Broadcasted(randn,())
# Nested broadcast mixing a lazy `Broadcasted` object with two containers.
A .+= rbc .+ C.*B
# Expects every entry of `A` to change, on top of the length(B2) changes counted
# above (assumes no random draw leaves an entry unchanged).
@test A.counter[] == length(A)+length(B2)
# Broadcasting over a view produces an `ArrayWithCounter`.
@test typeof(view(B,:,:) .+ 0) == ArrayWithCounter
It has become rather extensive, but you can mix and match basically anything in broadcasts.
If there are suggestions for improvement, I’d be happy to hear them. In particular, I find the whole `unpack` and `find_x` story a little cumbersome, mostly because it is boilerplate that will appear again and again in any custom broadcasting.