I have a type:
struct A
    data::AbstractMatrix
end
If I define these methods:
Base.iterate(a::A, i...) = Base.iterate(a.data, i...)
Base.broadcast(+, a::A, b::A) = A(+(a.data, b.data))
Base.length(a::A) = length(a.data)
Then I run:
julia> A(rand(3,3)) .+ A(rand(3,3))
9-element Array{Float64,1}:
0.8300257464980292
1.1634911678495374
0.36783900052780516
1.3501522568416164
0.9627560365519914
1.0869740116275295
0.14735325532225652
1.4899891724904388
1.0577007328227803
julia> map(sin, A(rand(3,3)))
9-element Array{Float64,1}:
0.3849989124585183
0.7682761408014985
0.2723893707469926
0.242139980682246
0.30454653464761694
0.5890566304072592
0.3124315088158781
0.681773029698596
0.4027608840393328
As you can see, the shape of data::AbstractMatrix in A is lost: both results are flat 9-element vectors. How can I keep its shape, the way a plain Matrix does?
julia> map(sin, rand(3, 3))
3×3 Array{Float64,2}:
0.000690277 0.363378 0.0697955
0.587629 0.742007 0.297582
0.271736 0.411759 0.46423
The best approach is to declare your type as <: AbstractMatrix and define appropriate methods for broadcast (rather than just the one for +). See https://docs.julialang.org/en/latest/manual/interfaces/
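A minimal sketch of the AbstractMatrix route, assuming a hypothetical wrapper named MatWrap (not the poster's A). For a read-only array only size and getindex are required; the default broadcast and map machinery then keeps the 3×3 shape.

```julia
# Hypothetical wrapper type for illustration; subtyping AbstractMatrix{T}
# tells Julia it is a 2-D array with element type T.
struct MatWrap{T} <: AbstractMatrix{T}
    data::Matrix{T}
end

# The two required methods of the AbstractArray interface.
Base.size(a::MatWrap) = size(a.data)
Base.getindex(a::MatWrap, i::Int, j::Int) = a.data[i, j]

# Optional: allows writing through the wrapper (a[i, j] = v).
Base.setindex!(a::MatWrap, v, i::Int, j::Int) = (a.data[i, j] = v)

a = MatWrap(rand(3, 3))
b = MatWrap(rand(3, 3))

s = a .+ b         # default array broadcasting: a 3×3 Matrix{Float64}
m = map(sin, a)    # map on an AbstractArray also keeps the 3×3 shape
```

The results here are plain Matrix objects, not MatWrap. Getting the wrapper type back out of broadcast takes a custom BroadcastStyle, which the same interfaces page covers.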
Unfortunately I can’t: A has more things in it, e.g., two more fields.
You can still implement the AbstractArray interface on top of a.data, so that your A struct effectively works as a matrix.
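A sketch of that idea, assuming a hypothetical type Tagged whose extra fields name and scale are invented here for illustration. It is modeled on the ArrayAndChar example in the broadcasting section of the interfaces manual: extra fields do not stop a type from being an AbstractMatrix, and a custom BroadcastStyle plus a matching similar method make broadcast results come back as Tagged, carrying the metadata of the first Tagged argument.

```julia
using Base.Broadcast: Broadcasted, ArrayStyle

# Hypothetical wrapper: a matrix plus two illustrative extra fields.
struct Tagged{T} <: AbstractMatrix{T}
    data::Matrix{T}
    name::String
    scale::Float64
end

# AbstractArray interface, forwarded to the wrapped matrix.
Base.size(t::Tagged) = size(t.data)
Base.getindex(t::Tagged, i::Int, j::Int) = t.data[i, j]
Base.setindex!(t::Tagged, v, i::Int, j::Int) = (t.data[i, j] = v)

# Give Tagged its own broadcast style so broadcast can allocate a Tagged result.
Base.BroadcastStyle(::Type{<:Tagged}) = ArrayStyle{Tagged}()

# Return the first Tagged found among the (possibly nested) broadcast arguments.
find_tagged(bc::Broadcasted) = find_tagged(bc.args)
find_tagged(args::Tuple) = find_tagged(find_tagged(args[1]), Base.tail(args))
find_tagged(x) = x
find_tagged(::Tuple{}) = nothing
find_tagged(t::Tagged, rest) = t
find_tagged(::Any, rest) = find_tagged(rest)

# Allocate the output as a Tagged that copies that argument's extra fields.
function Base.similar(bc::Broadcasted{ArrayStyle{Tagged}}, ::Type{ElType}) where {ElType}
    t = find_tagged(bc)
    return Tagged(similar(Array{ElType}, axes(bc)), t.name, t.scale)
end

t = Tagged(rand(3, 3), "example", 2.0)
u = t .+ t       # a 3×3 Tagged with name "example" and scale 2.0
v = sin.(t)      # also a Tagged
```

Both extra fields are taken from the first Tagged argument; if two operands carry different metadata, how to merge them is a design decision left to you.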
That’s what I am doing: I implement Base.iterate, Base.broadcast, and Base.length. But after doing a .+ or a map, the shape of A.data is lost.
There are a number of things going on here. First, the definition
Base.broadcast(+, a::A, b::A) = A(+(a.data, b.data))
isn’t doing what you think it’s doing. In that signature + is just the name of the first argument, so the method will be called even for things like broadcast(*, A(rand(3,3)), A(rand(3,3)))! What I think you intend here is something like broadcast(::typeof(+), ...), which will only get called when you pass the + function directly to broadcast. But it still won’t work the way you want for .+, because .+ can fuse with other operations, so it doesn’t actually call broadcast at all!
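A small sketch of both pitfalls, using a hypothetical function combine and a hypothetical type W in place of Base.broadcast so that no Base method gets redefined. In the first method, + is merely an argument name, so it matches any function; dispatching on typeof(+) restricts it. The last lines show that dot syntax goes through Broadcast.broadcasted and Broadcast.materialize rather than a broadcast call (Meta.@lower shows the lowered form, whose exact text varies between Julia versions).

```julia
# Hypothetical wrapper and function, used only for illustration.
struct W
    data::Vector{Float64}
end

# `+` here is just the *name* of the first argument, so this method
# accepts any first argument whatsoever.
combine(+, a::W, b::W) = "catch-all method"

# Dispatching on the function's type matches only the `+` function itself.
combine(::typeof(+), a::W, b::W) = "method for + only"

w = W([1.0, 2.0])
r1 = combine(*, w, w)   # "catch-all method": it swallowed `*` too
r2 = combine(+, w, w)   # "method for + only": the more specific method wins

# Dot syntax does not call `broadcast`; inspect what `x .+ y` lowers to.
x, y = rand(3, 3), rand(3, 3)
Meta.@lower x .+ y
z = Base.Broadcast.materialize(Base.Broadcast.broadcasted(+, x, y))  # same as x .+ y
```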
Instead of trying to override the behavior of broadcast, it is much simpler to make sure your type exposes the interface that broadcast uses. The easiest way to do this is indeed to make your wrapper a subtype of AbstractArray that simply defines its indexing and size in terms of the data array. You needn’t be an AbstractArray, though: broadcast can work with any iterable structure. Just use the optional IteratorSize method to say that you have a shape (an iterator declared HasShape must also define size).
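A sketch of the non-AbstractArray route, assuming a hypothetical iterable Grid with an illustrative extra label field. Besides iterate and length it declares HasShape{2}, a size method, and eltype so that collected results get a concrete element type. IteratorSize is defined on the type here, which is the form its docstring describes (IteratorSize(itertype::Type)); as far as I can tell, the Generator that map builds queries the wrapped iterator's type, so a method defined only on instances may not be seen there.

```julia
# Hypothetical iterable wrapper that is *not* an AbstractArray.
struct Grid{T}
    data::Matrix{T}
    label::String   # illustrative extra field
end

# Iteration simply forwards to the wrapped matrix (column-major order).
Base.iterate(g::Grid, state...) = iterate(g.data, state...)
Base.length(g::Grid) = length(g.data)
Base.eltype(::Type{Grid{T}}) where {T} = T

# Declare a 2-D shape on the type, plus the size it refers to.
Base.IteratorSize(::Type{<:Grid}) = Base.HasShape{2}()
Base.size(g::Grid) = size(g.data)

g = Grid(rand(3, 3), "example")
s1 = sin.(g)       # broadcast collects g into a 3×3 Matrix first
s2 = map(sin, g)   # map over a generic iterable also yields a 3×3 Matrix
```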
Defining iterate isn’t “breaking” the shape: it wasn’t working at all before you defined iterate! You just need to keep going and tell Julia a bit more about your struct so that it behaves the way you want.
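To see the default in action, here is a sketch with a hypothetical type OnlyIter that defines just iterate and length, like the original A. Without an IteratorSize method Julia falls back to HasLength, so collect, and therefore broadcasting (which collects generic iterables first), can only produce a flat vector.

```julia
# Hypothetical iterable that only knows how to iterate and how long it is.
struct OnlyIter
    data::Matrix{Float64}
end

Base.iterate(o::OnlyIter, state...) = iterate(o.data, state...)
Base.length(o::OnlyIter) = length(o.data)

o = OnlyIter(rand(3, 3))

# No IteratorSize method was defined, so the fallback applies.
sz = Base.IteratorSize(OnlyIter)    # Base.HasLength()

# broadcast collects non-array iterables first, yielding a 9-element vector.
v = sin.(o)
```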
Thank you, I got what I wanted with:
struct A
    data::AbstractMatrix
end
julia> Base.iterate(a::A, i...) = Base.iterate(a.data, i...)
julia> Base.length(a::A) = length(a.data)
julia> Base.broadcast(::typeof(+), a::A, b::A) = A(+(a.data, b.data))
julia> Base.IteratorSize(::A) = Base.HasShape{2}()
julia> Base.size(a::A) = size(a.data)
julia> a = A(rand(3,3))
A([0.204899 0.819475 0.963472; 0.694807 0.592251 0.562226; 0.894419 0.790337 0.196264])
julia> sin.(a)
3×3 Array{Float64,2}:
0.203468 0.730788 0.821178
0.640237 0.55823 0.533071
0.779846 0.71059 0.195006
julia> A(rand(3,3)) .+ A(rand(3,3))
3×3 Array{Float64,2}:
0.360492 0.542811 0.946277
1.58507 1.13376 0.864045
1.02378 0.965837 1.02973
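A quick check of this solution, sketched below by repeating the poster's definitions: the result of .+ is a plain Array built from the declared shape, while the broadcast(::typeof(+), ...) method is only reached by an explicit broadcast(+, a, b) call, which returns an A.

```julia
# The definitions from the post above (A is the poster's type).
struct A
    data::AbstractMatrix
end
Base.iterate(a::A, i...) = Base.iterate(a.data, i...)
Base.length(a::A) = length(a.data)
Base.broadcast(::typeof(+), a::A, b::A) = A(+(a.data, b.data))
Base.IteratorSize(::A) = Base.HasShape{2}()
Base.size(a::A) = size(a.data)

a, b = A(rand(3, 3)), A(rand(3, 3))

dotted = a .+ b                # goes through broadcasted/materialize: a 3×3 Array
explicit = broadcast(+, a, b)  # dispatches to the ::typeof(+) method: an A
```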