I'm trying to write some functions that work on arrays of 2D or 3D points. Generally I'm planning to use `Vector{SVector{3,T}}`, but I would like the functions to be generic enough to also handle a 3×N matrix. I know that, type-wise, I can use `AbstractArray` and both should satisfy it, so my question is really about interfaces.
If I want to iterate over each point of a `Vector{SVector{3,T}}`, I can just do `for pt in pts`, but if `pts` is a 3×N array I need to do `for pt in eachcol(pts)`. How do I abstract iteration over these two types?
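The mismatch is easiest to see by running the loop. The sketch below shows that iterating a matrix directly visits its scalar entries in column-major order, not its columns. It uses plain `Vector{Int}` points as a stand-in for `SVector`s so that it runs without StaticArrays, and the variable names are illustrative only.

```julia
# Two representations of the same three 3-D points.
pts_vec = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]   # vector of vectors
pts_mat = [1 4 7; 2 5 8; 3 6 9]                # 3×N matrix, one point per column

# Iterating the vector of vectors yields whole points.
vec_items = [pt for pt in pts_vec]

# Iterating the matrix directly yields its scalar entries in column-major order.
mat_items = Int[]
for x in pts_mat
    push!(mat_items, x)
end

# eachcol recovers one point (a column view) per iteration.
col_items = [collect(pt) for pt in eachcol(pts_mat)]
```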
Don't declare a type at all in your lowest-level computation routine; just have it take an arbitrary iterator of points. You can then wrap it with functions that dispatch appropriately on vectors of vectors vs. matrices:
```julia
function _dostuff(itr)
    for pt in itr
        # do stuff with pt
    end
end

# vector of vectors:
dostuff(pts::AbstractVector{<:AbstractVector}) = _dostuff(pts)
# d × N array of d-vectors:
dostuff(ptmatrix::AbstractMatrix{<:Number}) = _dostuff(eachcol(ptmatrix))
```
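Here is the same pattern with something concrete in place of `# do stuff`: a hypothetical `centroid` (not from the original post), sketched under the assumption that points are numeric vectors. Plain vectors stand in for `SVector`s; because `SVector` is a subtype of `AbstractVector`, a `Vector{SVector{3,T}}` would dispatch to the first method as well.

```julia
# Lowest-level routine: accepts any iterator whose elements are point-like vectors.
function _centroid(itr)
    acc = nothing   # running sum; allocated from the first point
    n = 0
    for pt in itr
        acc = acc === nothing ? float.(collect(pt)) : acc .+ pt
        n += 1
    end
    n > 0 || throw(ArgumentError("need at least one point"))
    return acc ./ n
end

# vector of vectors: iterate it directly
centroid(pts::AbstractVector{<:AbstractVector}) = _centroid(pts)
# d × N matrix: iterate its columns
centroid(ptmatrix::AbstractMatrix{<:Number}) = _centroid(eachcol(ptmatrix))
```

Because `_centroid` itself is untyped, a generator of points can also be passed to it directly.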
That being said, I would ordinarily recommend picking a single interface and sticking with it. For example:
- If you need a specific data structure, then just support `AbstractVector{<:AbstractVector}` rather than supporting completely different data structures.
- Or just require the user to pass an iterator of points, in which case the user is responsible for calling `eachcol` if that's what they want.
It depends on what operations you need to perform; for example, if you need to mutate the data in place, then you can't accept an arbitrary iterator.
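To illustrate the mutation caveat, here is a sketch of a hypothetical in-place `translate!` (not from the thread). Writing results back requires an indexable container, so it is defined only for vectors of vectors and for d × N matrices; a generator simply has no applicable method.

```julia
# Vector of vectors: replace each element. This also works when the elements
# are immutable (e.g. SVectors), because it is the container that is mutated.
function translate!(pts::AbstractVector{<:AbstractVector}, offset)
    for i in eachindex(pts)
        pts[i] = pts[i] + offset
    end
    return pts
end

# d × N matrix: eachcol yields views, so broadcasting into them
# modifies the matrix itself.
function translate!(ptmatrix::AbstractMatrix{<:Number}, offset)
    for col in eachcol(ptmatrix)
        col .+= offset
    end
    return ptmatrix
end
```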
4 Likes
OK, I ended up taking the advice from your first bullet: I work in terms of vectors of vectors, and I made an adapter that converts a matrix input into a view and forwards that view of the matrix to the original function.
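One way such an adapter might look, as a sketch under assumptions: `maxnorm` is a hypothetical core routine written against vectors of points, and the matrix method wraps each column in a view before forwarding. With StaticArrays loaded, `reinterpret(reshape, SVector{3,T}, M)` (Julia 1.6 and later) is another option that presents a 3 × N `Matrix{T}` as a vector of `SVector`s without copying.

```julia
# Core routine, written only against vectors of points (hypothetical example:
# the largest Euclidean distance from the origin).
maxnorm(pts::AbstractVector{<:AbstractVector}) = maximum(pt -> sqrt(sum(abs2, pt)), pts)

# Adapter: present a d × N matrix as a vector of column views and forward it.
# collect(eachcol(M)) allocates a small vector of views; the point data
# itself is not copied.
maxnorm(ptmatrix::AbstractMatrix{<:Number}) = maxnorm(collect(eachcol(ptmatrix)))
```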
1 Like
Another solution: define `eachpoint(pts)` and use `for pt in eachpoint(pts) ... end`. The function `eachpoint(pts)` could be used for other purposes as well.
MWE:
Input:
```julia
using StaticArrays

eachpoint(itr) = itr
eachpoint(ptmatrix::AbstractMatrix) = eachcol(ptmatrix)

function f(pts)
    for pt in eachpoint(pts)
        println(pt)
    end
end

f(SVector(k+1, k+2, k+3) for k in 0:3:12)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10, 11, 12]
[13, 14, 15]
```
Input:
```julia
f(reshape(1:15, 3, 5))
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10, 11, 12]
[13, 14, 15]
```
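As an example of reusing `eachpoint` for another purpose, here is a sketch of a hypothetical `boundingbox` that accepts a generator, a vector of vectors, or a d × N matrix. The helper is redefined so that the block runs on its own, and plain vectors are used instead of `SVector`s.

```julia
# The eachpoint helper from the MWE above.
eachpoint(itr) = itr
eachpoint(ptmatrix::AbstractMatrix) = eachcol(ptmatrix)

# Hypothetical second use: the axis-aligned bounding box of any point collection.
function boundingbox(pts)
    lo = hi = nothing
    for pt in eachpoint(pts)
        if lo === nothing
            lo, hi = collect(pt), collect(pt)   # first point initializes both corners
        else
            lo = min.(lo, pt)                   # elementwise minimum so far
            hi = max.(hi, pt)                   # elementwise maximum so far
        end
    end
    lo === nothing && throw(ArgumentError("need at least one point"))
    return lo, hi
end
```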
3 Likes