# Overloading `iterate` breaks the shape of my wrapper of an `AbstractMatrix`

I have a type

``````
struct A
    data::AbstractMatrix
end
``````

If I define several methods

``````
Base.iterate(a::A, i...) = Base.iterate(a.data, i...)

Base.broadcast(+, a::A, b::A) = A(+(a.data, b.data))

Base.length(a::A) = length(a.data)
``````

Then I do

``````
julia> A(rand(3,3)) .+ A(rand(3,3))
9-element Array{Float64,1}:
 0.8300257464980292
 1.1634911678495374
 0.36783900052780516
 1.3501522568416164
 0.9627560365519914
 1.0869740116275295
 0.14735325532225652
 1.4899891724904388
 1.0577007328227803

julia> map(sin, A(rand(3,3)))
9-element Array{Float64,1}:
 0.3849989124585183
 0.7682761408014985
 0.2723893707469926
 0.242139980682246
 0.30454653464761694
 0.5890566304072592
 0.3124315088158781
 0.681773029698596
 0.4027608840393328
``````

As you can see, the results no longer have the shape of `data::AbstractMatrix` in `A`. How can I keep that shape, the way a plain `Matrix` does?

``````
julia> map(sin, rand(3, 3))
3×3 Array{Float64,2}:
 0.000690277  0.363378  0.0697955
 0.587629     0.742007  0.297582
 0.271736     0.411759  0.46423
``````

The best approach is to declare your type as `<: AbstractMatrix`, and define appropriate methods for `broadcast` (rather than just the one for `+`). See https://docs.julialang.org/en/latest/manual/interfaces/
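A minimal sketch of the subtyping suggestion, assuming the wrapper can store a concrete `Matrix{T}`; the name `MatWrap` is hypothetical. For a read-only matrix only `size` and `getindex` are required, and with just those two methods `map` and `.+` keep the 3×3 shape:

``````julia
# Hypothetical wrapper; subtyping AbstractMatrix{T} opts into the
# AbstractArray interface, so map and broadcast know the shape.
struct MatWrap{T} <: AbstractMatrix{T}
    data::Matrix{T}
end

# The two methods the AbstractArray interface requires (read-only case),
# both forwarding to the wrapped matrix.
Base.size(w::MatWrap) = size(w.data)
Base.getindex(w::MatWrap, i::Int, j::Int) = w.data[i, j]

w = MatWrap(rand(3, 3))
m = map(sin, w)   # a 3×3 Matrix{Float64}
s = w .+ w        # also a 3×3 Matrix{Float64}
``````

Both results are plain `Matrix` objects rather than `MatWrap`s; getting the wrapper type itself back from broadcasting requires customizing the broadcast machinery, which the linked interfaces page describes.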


Unfortunately I can’t. `A` has more things in it, e.g., two more fields.

You can still implement the `AbstractArray` interface over `a.data`, so that your `A` struct effectively works as a matrix.
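Extra fields do not get in the way of this. A sketch, assuming two hypothetical extra fields `name` and `unit` in place of whatever `A` really carries: the type still subtypes `AbstractMatrix`, and every interface method simply forwards to `data`.

``````julia
# Hypothetical wrapper with extra metadata fields; the array interface
# is implemented entirely by forwarding to `data`.
struct Tagged{T} <: AbstractMatrix{T}
    data::Matrix{T}
    name::String     # hypothetical extra field
    unit::Symbol     # hypothetical extra field
end

Base.size(t::Tagged) = size(t.data)
Base.getindex(t::Tagged, i::Int, j::Int) = t.data[i, j]
# Optional: makes the wrapper writable through indexing.
Base.setindex!(t::Tagged, v, i::Int, j::Int) = (t.data[i, j] = v)

t = Tagged(zeros(2, 3), "pressure", :Pa)
t[1, 2] = 5.0     # writes through to t.data
c = cos.(t)       # a 2×3 Matrix{Float64}
``````

The extra fields stay ordinary fields (`t.name`, `t.unit`); the generic array functions never look at them.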


That’s what I am doing: I implement `Base.iterate`, `Base.broadcast`, and `Base.length`. But after `.+` or `map`, the result no longer has the shape of `A.data`.

I was thinking of `getindex`, `setindex!`, and the other methods listed here:

https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array-1
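The same page also has a section on customizing broadcasting. The sketch below adapts its `ArrayAndChar` example to a hypothetical wrapper `Labeled` with one extra field, so that dotted operations return a `Labeled` (keeping the extra field) instead of a plain `Matrix`. `find_labeled` is a hypothetical helper mirroring the manual's `find_aac`.

``````julia
# Hypothetical wrapper with one extra field.
struct Labeled{T} <: AbstractMatrix{T}
    data::Matrix{T}
    label::String
end

Base.size(x::Labeled) = size(x.data)
Base.getindex(x::Labeled, i::Int, j::Int) = x.data[i, j]
Base.setindex!(x::Labeled, v, i::Int, j::Int) = (x.data[i, j] = v)

# Any Labeled argument selects this custom style; ArrayStyle wins over
# the default styles of plain arrays and numbers.
Base.BroadcastStyle(::Type{<:Labeled}) = Broadcast.ArrayStyle{Labeled}()

# Allocate the output as a Labeled, reusing the label of the first
# Labeled argument found in the (possibly fused) broadcast expression.
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Labeled}},
                      ::Type{ElType}) where {ElType}
    src = find_labeled(bc)
    Labeled(similar(Array{ElType}, axes(bc)), src.label)
end

# Recursive search through the broadcast arguments.
find_labeled(bc::Broadcast.Broadcasted) = find_labeled(bc.args)
find_labeled(args::Tuple) = find_labeled(find_labeled(args[1]), Base.tail(args))
find_labeled(x) = x
find_labeled(::Tuple{}) = nothing
find_labeled(x::Labeled, rest) = x
find_labeled(::Any, rest) = find_labeled(rest)

a = Labeled(rand(3, 3), "a")
b = Labeled(rand(3, 3), "b")
r = a .+ 2 .* b   # fused; the result is a Labeled with label "a"
``````

`map` does not use the broadcast machinery, so `map(sin, a)` still returns a plain `Matrix`; the interfaces page lists `similar(A, ::Type{S}, dims::Dims)` as the method to overload for that.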

There are a few things going on here. First, defining

``````
Base.broadcast(+, a::A, b::A) = A(+(a.data, b.data))
``````

isn’t doing what you think it’s doing. In that signature `+` is just the name of the first argument, so the method is called for any function, even `broadcast(*, A(rand(3,3)), A(rand(3,3)))`! What I think you intend is something like `Base.broadcast(::typeof(+), a::A, b::A)`, which will only be called when you pass the `+` function itself to `broadcast`. But it still won’t work the way you want for `.+`: because `.+` can fuse with other dotted operations, it doesn’t actually call `broadcast` at all!
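The difference between naming an argument `+` and dispatching on `::typeof(+)` can be checked directly. A small sketch with a hypothetical wrapper `W` and a hypothetical call counter `calls`:

``````julia
# Hypothetical wrapper, used only to show which methods get called.
struct W
    data::Matrix{Float64}
end

# Here `+` is only the NAME of the first argument, so this defines
# broadcast(::Any, ::W, ::W); inside the body `+` is whatever was passed.
Base.broadcast(+, a::W, b::W) = W(+(a.data, b.data))

x = W([1.0 2.0; 3.0 4.0])
y = W([1.0 0.0; 0.0 1.0])
p = broadcast(*, x, y)   # hits the method above, computing x.data * y.data

# Dispatching on the function's type matches only the `+` function itself.
const calls = Ref(0)
Base.broadcast(::typeof(+), a::W, b::W) = (calls[] += 1; W(a.data + b.data))
q = broadcast(+, x, y)   # hits the typeof(+) method; calls[] becomes 1

# With iteration defined, `.+` runs, but it lowers to
# Base.materialize(Base.broadcasted(+, x, y)) and never calls `broadcast`.
Base.iterate(w::W, s...) = iterate(w.data, s...)
Base.length(w::W) = length(w.data)
v = x .+ y               # a flat 4-element Vector{Float64}; calls[] is still 1
``````

The flat `Vector` in the last line is the same kind of result as the 9-element output in the question.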

Instead of trying to override the behavior of `broadcast`, it is much simpler to make sure your type exposes the interface that broadcast uses. The easiest way to do this is indeed to make your wrapper a subtype of `AbstractArray` that defines its indexing and size in terms of the `data` array. You needn’t be an `AbstractArray`, though: broadcast can work with any iterable structure. Just use the optional `IteratorSize` method (together with `size`) to say that you have a shape.
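A sketch of the non-`AbstractArray` route, using two hypothetical wrappers `Flat` and `Shaped`. For a type that is not an array, broadcasting falls back to `Base.broadcastable(x) = collect(x)`, so the result takes whatever shape `collect` produces; declaring `IteratorSize` as `HasShape{2}` (plus `size`) turns that into a matrix:

``````julia
# Two hypothetical wrappers that are iterable but are NOT AbstractArrays.
struct Flat
    data::Matrix{Float64}
end
struct Shaped
    data::Matrix{Float64}
end

Base.iterate(w::Flat, s...) = iterate(w.data, s...)
Base.length(w::Flat) = length(w.data)    # default IteratorSize: HasLength

Base.iterate(w::Shaped, s...) = iterate(w.data, s...)
Base.length(w::Shaped) = length(w.data)
Base.IteratorSize(::Type{Shaped}) = Base.HasShape{2}()   # "I have a 2-D shape"
Base.size(w::Shaped) = size(w.data)                     # "and this is it"

m = rand(2, 3)
f = sin.(Flat(m))     # a Vector{Float64} with 6 elements
g = sin.(Shaped(m))   # a 2×3 Matrix{Float64}
``````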

Defining `iterate` isn’t “breaking” the shape; broadcasting wasn’t working at all before you defined `iterate`! You just need to keep going and describe a bit more about your struct to Julia for it to behave the way you want.
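For example, with a hypothetical wrapper `Bare` that defines no iteration methods, broadcasting does not return a wrongly shaped result; it fails outright:

``````julia
# Hypothetical wrapper with no iteration or array methods at all.
struct Bare
    data::Matrix{Float64}
end

# Broadcasting over a non-array falls back to collect(x), which needs
# `length`/`iterate`, so this throws a MethodError instead of
# returning a result of any shape.
err = try
    sin.(Bare(rand(3, 3)))
    nothing
catch e
    e
end
``````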


Thank you, I got what I wanted with:

``````
julia> struct A
           data::AbstractMatrix
       end

julia> Base.iterate(a::A, i...) = Base.iterate(a.data, i...)

julia> Base.length(a::A) = length(a.data)

julia> Base.broadcast(::typeof(+), a::A, b::A) = A(+(a.data, b.data))

julia> Base.IteratorSize(::A) = Base.HasShape{2}()

julia> Base.size(a::A) = size(a.data)

julia> a = A(rand(3,3))
A([0.204899 0.819475 0.963472; 0.694807 0.592251 0.562226; 0.894419 0.790337 0.196264])

julia> sin.(a)
3×3 Array{Float64,2}:
 0.203468  0.730788  0.821178
 0.640237  0.55823   0.533071
 0.779846  0.71059   0.195006

julia> A(rand(3,3)) .+ A(rand(3,3))
3×3 Array{Float64,2}:
 0.360492  0.542811  0.946277
 1.58507   1.13376   0.864045
 1.02378   0.965837  1.02973
``````