Overloading `iterate` breaks the shape of my wrapper of an `AbstractMatrix`

I have a type

struct A
    data::AbstractMatrix  
end

If I define the following methods:

Base.iterate(a::A, i...) = Base.iterate(a.data, i...)

Base.broadcast(+, a::A, b::A) = A(+(a.data, b.data))

Base.length(a::A) = length(a.data)

Then I run:

julia> A(rand(3,3)) .+ A(rand(3,3))
9-element Array{Float64,1}:
 0.8300257464980292 
 1.1634911678495374 
 0.36783900052780516
 1.3501522568416164 
 0.9627560365519914 
 1.0869740116275295 
 0.14735325532225652
 1.4899891724904388 
 1.0577007328227803 

julia> map(sin, A(rand(3,3)))
9-element Array{Float64,1}:
 0.3849989124585183 
 0.7682761408014985 
 0.2723893707469926 
 0.242139980682246  
 0.30454653464761694
 0.5890566304072592 
 0.3124315088158781 
 0.681773029698596  
 0.4027608840393328 

As you can see, the shape of data::AbstractMatrix inside A is lost: both results are 9-element vectors. How can I keep its shape, as a plain Matrix does?

julia> map(sin, rand(3, 3))
3×3 Array{Float64,2}:
 0.000690277  0.363378  0.0697955
 0.587629     0.742007  0.297582 
 0.271736     0.411759  0.46423  

The best approach is to declare your type as a subtype of AbstractMatrix (<: AbstractMatrix) and define the appropriate methods for broadcasting, rather than just the one for +. See https://docs.julialang.org/en/latest/manual/interfaces/
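A minimal sketch of the subtyping approach, assuming a hypothetical wrapper named MatWrap. Parameterizing on the storage type M keeps the field concretely typed (unlike data::AbstractMatrix), and only `size` and `getindex`, the two methods the AbstractArray interface requires for the default Cartesian indexing style, are forwarded to the wrapped matrix.

```julia
# Hypothetical wrapper that *is* an AbstractMatrix. Parameterizing on the
# storage type M keeps the field concretely typed (unlike `data::AbstractMatrix`).
struct MatWrap{T,M<:AbstractMatrix{T}} <: AbstractMatrix{T}
    data::M
end

# Minimal AbstractArray interface for the default IndexCartesian style:
# both methods simply forward to the wrapped matrix.
Base.size(w::MatWrap) = size(w.data)
Base.getindex(w::MatWrap, i::Int, j::Int) = w.data[i, j]

w = MatWrap([1.0 2.0; 3.0 4.0])

s = map(sin, w)   # a 2×2 Matrix{Float64}: the shape is preserved
p = w .+ w        # a 2×2 Matrix{Float64}: [2.0 4.0; 6.0 8.0]
```

Here map and broadcasting come from the generic AbstractArray fallbacks and return plain Matrix objects; getting a MatWrap back from broadcasting would take more work (a custom BroadcastStyle, as described in the interfaces manual).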


Unfortunately I can’t: A has other things in it as well, e.g. two more fields.

You can still implement the AbstractArray interface by forwarding to a.data; the extra fields don’t get in the way, and your A struct then effectively works as a matrix.


That’s what I am doing: I implement Base.iterate, Base.broadcast, and Base.length. But after a .+ or a map, the shape of A.data is lost.

I was thinking of getindex, setindex!, and the other methods listed here:

https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array-1
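A sketch of a wrapper that keeps extra fields and still subtypes AbstractMatrix. The type name Tagged and the fields label and unit are hypothetical stand-ins for the two extra fields mentioned above. Declaring IndexLinear means only the single-Int getindex and setindex! methods are needed; Base converts Cartesian indices to linear ones.

```julia
# Hypothetical wrapper with extra fields; it can still subtype AbstractMatrix.
struct Tagged{T,M<:AbstractMatrix{T}} <: AbstractMatrix{T}
    data::M
    label::String   # hypothetical extra field
    unit::Symbol    # hypothetical extra field
end

# Interface methods, all forwarded to `data`. With IndexLinear, Base turns
# Cartesian indices such as t[2, 1] into linear ones before calling these.
Base.size(t::Tagged) = size(t.data)
Base.IndexStyle(::Type{<:Tagged}) = IndexLinear()
Base.getindex(t::Tagged, i::Int) = t.data[i]
function Base.setindex!(t::Tagged, v, i::Int)
    t.data[i] = v
    return t
end

t = Tagged([1.0 2.0; 3.0 4.0], "temperature", :K)
t .= [5.0 6.0; 7.0 8.0]   # in-place broadcast writes through setindex!
t[2, 1] = 0.0             # Cartesian index, converted to linear index 2
r = sin.(t)               # a 2×2 Matrix{Float64}
# t.data is now [5.0 6.0; 0.0 8.0]; t.label and t.unit are untouched
```

Because setindex! forwards to the wrapped matrix, in-place operations on t mutate t.data directly, while the extra fields ride along unchanged.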

There are a number of things going on here. First, this definition:

Base.broadcast(+, a::A, b::A) = A(+(a.data, b.data))

isn’t doing what you think it’s doing. In that signature, + is just the name of the first argument, so the method will be called even for things like broadcast(*, A(rand(3,3)), A(rand(3,3))), and inside the body + then means *! What I think you intend is something like broadcast(::typeof(+), ...), which will only be called when you pass the + function itself to broadcast. But it still won’t work the way you want for .+, because .+ can fuse with other operations, so it doesn’t actually call broadcast at all!
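Both points can be checked in a few lines. This sketch uses a hypothetical wrapper W holding a concrete Matrix{Float64} and mirrors the methods from the question; y is the identity matrix, so a matrix product and an elementwise product give visibly different results.

```julia
struct W
    data::Matrix{Float64}
end
Base.iterate(w::W, state...) = iterate(w.data, state...)
Base.length(w::W) = length(w.data)

# Pitfall: in this signature `+` is only the *name* of the first argument.
# The method matches broadcast(f, ::W, ::W) for ANY f, and inside the body
# `+` refers to whatever function was passed in.
Base.broadcast(+, a::W, b::W) = W(+(a.data, b.data))

x = W([1.0 2.0; 3.0 4.0])
y = W([1.0 0.0; 0.0 1.0])
bad = broadcast(*, x, y)   # runs *(x.data, y.data): a matrix product, not .*

# Dispatching on the function's type restricts the method to `+` itself.
Base.broadcast(::typeof(+), a::W, b::W) = W(a.data + b.data)
good = broadcast(+, x, y)  # a W, via the ::typeof(+) method

# Dot syntax does not call `broadcast`: `x .+ y` lowers to
# Base.materialize(Base.broadcasted(+, x, y)), so neither method above runs.
# The arguments are collected as iterables, giving a flat Vector.
dotted = x .+ y
```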

Instead of trying to override the behavior of broadcast, it is much simpler to make sure your type exposes the interface that broadcast uses. The easiest way to do this is indeed to make your wrapper a subtype of AbstractArray that defines its indexing and size in terms of the data array. You needn’t be an AbstractArray, though: broadcast can work with any iterable structure. Just use the optional IteratorSize method to say that you have a shape.

Defining iterate isn’t “breaking” the shape: it wasn’t working at all before you defined iterate! You just need to keep going and tell Julia a bit more about your struct for it to behave the way you want.
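A sketch of that progression with a hypothetical type Grid. With no methods, broadcasting fails because it falls back to collect; with iterate and length it works but only knows a length, so it yields a flat Vector; declaring a 2-D shape through IteratorSize and size yields a Matrix. The trait is defined on the type here, which is the usual convention (Base's fallback IteratorSize(x) = IteratorSize(typeof(x)) forwards instance queries to it).

```julia
struct Grid
    data::Matrix{Float64}
end

g = Grid([1.0 2.0; 3.0 4.0])

# Step 1: no iteration methods yet. Broadcasting falls back to collect(g),
# which needs length/iterate, so this throws a MethodError.
step1 = try
    sin.(g)
catch err
    err
end

# Step 2: make Grid iterable. Broadcasting now works, but collect(g) only
# knows a length, so the result is a flat 4-element Vector.
Base.iterate(g::Grid, state...) = iterate(g.data, state...)
Base.length(g::Grid) = length(g.data)
step2 = sin.(g)

# Step 3: declare a 2-D shape; collect(g) now builds a 2×2 Matrix.
Base.IteratorSize(::Type{Grid}) = Base.HasShape{2}()
Base.size(g::Grid) = size(g.data)
step3 = sin.(g)
```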


Thank you, I have done what I wanted with:

struct A
    data::AbstractMatrix  
end

julia> Base.iterate(a::A, i...) = Base.iterate(a.data, i...)

julia> Base.length(a::A) = length(a.data)

julia> Base.broadcast(::typeof(+), a::A, b::A) = A(+(a.data, b.data))

julia> Base.IteratorSize(::A) = Base.HasShape{2}()

julia> Base.size(a::A) = size(a.data)

julia> a = A(rand(3,3))
A([0.204899 0.819475 0.963472; 0.694807 0.592251 0.562226; 0.894419 0.790337 0.196264])

julia> sin.(a)
3×3 Array{Float64,2}:
 0.203468  0.730788  0.821178
 0.640237  0.55823   0.533071
 0.779846  0.71059   0.195006

julia> A(rand(3,3)) .+ A(rand(3,3))
3×3 Array{Float64,2}:
 0.360492  0.542811  0.946277
 1.58507   1.13376   0.864045
 1.02378   0.965837  1.02973 
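A related hook not used in this thread is Base.broadcastable, which broadcasting calls on every argument before doing anything else. Returning the wrapped matrix from it makes dot calls operate on the matrix directly, without first collecting the wrapper. This is a sketch with a hypothetical type Wrapped and a hypothetical extra field label.

```julia
# Hypothetical wrapper; it is neither an AbstractArray nor iterable.
struct Wrapped{M<:AbstractMatrix}
    data::M
    label::String   # hypothetical extra field
end

# Broadcasting passes every argument through Base.broadcastable first;
# returning the wrapped matrix makes dot calls operate on it directly.
Base.broadcastable(w::Wrapped) = w.data

w = Wrapped([1.0 2.0; 3.0 4.0], "demo")
s = sin.(w)   # a 2×2 Matrix{Float64}
p = w .+ w    # [2.0 4.0; 6.0 8.0]
```

This only affects broadcasting: map(sin, w) still goes through iteration, so it still needs iterate (plus IteratorSize and size to keep the shape).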