Is there a way to seamlessly inherit vector arithmetic on new types?

I’m sure this is doable with a macro that iterates through the fields of a type, though I don’t have enough fluency with metaprogramming to write such a macro easily. But say, e.g., that I am going to create a bunch of types:

struct Type1
    field1::Float64
    field2::Float64
end

struct Type2
    field1::Float64
    field2::Float64
    field3::Float64
end

I could manually do

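import Base: *, +, -, /   # `-` must be imported too, or the method below won't extend Base.-

# Scalar multiplication, in both argument orders
function *(x::Type1, r::Real)
    Type1(x.field1*r, x.field2*r)
end
function *(r::Real, x::Type1)
    x*r
end
# Field-by-field addition and subtraction
function +(x::Type1, y::Type1)
    Type1(x.field1+y.field1, x.field2+y.field2)
end
function -(x::Type1, y::Type1)
    x + ((-1) * y)
end
# Parenthesize 1/r: `x * 1/r` parses as `(x*1)/r`, which would call this method forever
function /(x::Type1, r::Real)
    x * (1/r)
end
# etc., and again for Type2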

But it seems there should be a one-liner way to just say “treat Type1 and its fields like a vector over Reals” and get all this for free. (Perhaps with NamedTuples?)
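For comparison, the field-iterating metaprogramming mentioned above does not strictly need a custom macro. The following is a minimal sketch, assuming every field is numeric and each type keeps its default positional constructor. A top-level loop uses @eval and fieldnames to generate, for each listed type, the same methods as the manual version. The loop is just an illustration, not an API from any package.

```julia
struct Type1
    field1::Float64
    field2::Float64
end

struct Type2
    field1::Float64
    field2::Float64
    field3::Float64
end

# For each type, build method definitions that apply the operation field by field,
# e.g. +(x::Type1, y::Type1) = Type1(x.field1 + y.field1, x.field2 + y.field2)
for T in (Type1, Type2)
    fnames = fieldnames(T)
    @eval begin
        Base.:+(x::$T, y::$T) = $T($([:(x.$n + y.$n) for n in fnames]...))
        Base.:-(x::$T, y::$T) = $T($([:(x.$n - y.$n) for n in fnames]...))
        Base.:*(x::$T, r::Real) = $T($([:(x.$n * r) for n in fnames]...))
        Base.:*(r::Real, x::$T) = x * r
        Base.:/(x::$T, r::Real) = $T($([:(x.$n / r) for n in fnames]...))
    end
end

z = 4.0 * Type1(1.0, 2.0) - Type1(2.0, 3.0) / 4.0   # Type1(3.5, 7.25)
```

To support another type, add it to the tuple in the loop. The generated methods are ordinary methods, equivalent to the hand-written ones.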

Thanks!

This sounds like a good fit for StaticArrays.FieldVector: http://juliaarrays.github.io/StaticArrays.jl/stable/pages/api.html#FieldVector-1.

It does look relevant, but it seems like it might be overkill and non-performant? I’ll give it a try; thanks!

It should actually be extremely performant (as in, it’ll generate the same or better code than the manual examples you coded up).

Cool 👍

Or just put an SVector inside your object instead of storing the 3 coordinates as separate fields.
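A minimal sketch of that suggestion follows, assuming the StaticArrays package is installed. The type name Point3 is hypothetical. Each operator is a one-line forwarding method, so the code no longer grows with the number of coordinates, and rewrapping the result keeps the wrapper type.

```julia
using StaticArrays  # third-party package, assumed to be installed

# Hypothetical wrapper type: all coordinates live in a single SVector field
struct Point3
    coords::SVector{3,Float64}
end

# Each operator forwards to SVector arithmetic and rewraps the result,
# so the return type stays Point3
Base.:+(a::Point3, b::Point3) = Point3(a.coords + b.coords)
Base.:-(a::Point3, b::Point3) = Point3(a.coords - b.coords)
Base.:*(a::Point3, r::Real) = Point3(a.coords * r)
Base.:*(r::Real, a::Point3) = a * r
Base.:/(a::Point3, r::Real) = Point3(a.coords / r)

p = Point3(SVector(1.0, 2.0, 3.0))
q = Point3(SVector(2.0, 3.0, 4.0))
z = 4.0 * p - q / 4.0   # z.coords == SVector(3.5, 7.25, 11.0)
```

The trade-off compared with FieldVector is that components are accessed by index (p.coords[1]) rather than by name, unless you add accessor functions.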

As an alternative to StaticArrays.FieldVector, you can use https://github.com/JuliaDiffEq/LabelledArrays.jl

For vector types the differences are mostly minor. The big one is that a StaticArrays.FieldVector type is a distinct type defined in your package, whereas the type of a LabelledArrays.SLVector is derived from the field names themselves.

The LabelledArrays package also generalizes to matrices, which is pretty cool.

Both FieldVector and SLVector are likely to be extremely efficient.

OK, I’ve tested this, and StaticArrays.FieldVector is indeed just as fast as the manual version I wrote, and it plays nicely with https://github.com/mauro3/Parameters.jl for free.

Accordingly, I think this is a better solution (for me) than LabelledArrays just because I want to keep my custom types around rather than defining static LabelledArrays.

The only thing is, the FieldVector version seems to “forget” the original type in the computations (see below).

Any thoughts on whether this is doable while preserving the original type? I realize I can cast back after the fact; perhaps that’s the best option?

Thanks!

Code:

import Base.*, Base.+, Base.-, Base./

using Parameters
using StaticArrays
using BenchmarkTools

@with_kw struct TypeOnePlain
    field1::Float64
    field2::Float64
end

@with_kw struct TypeOneFV <: FieldVector{2,Float64}
    field1::Float64
    field2::Float64
end


function +(x::TypeOnePlain,y::TypeOnePlain)
    TypeOnePlain(x.field1+y.field1,x.field2+y.field2)
end
function *(x::TypeOnePlain,r::Real)
    TypeOnePlain(x.field1*r,x.field2*r)
end
function *(r::Real,x::TypeOnePlain)
    x*r
end
function -(x::TypeOnePlain,y::TypeOnePlain)
    x + ((-1) * y)
end
function /(x::TypeOnePlain,r::Real)
    x  * (1/r)
end


x_plain = TypeOnePlain(field2=2,field1=1)
y_plain = TypeOnePlain(2,3)
r = 4.0

x_fv = TypeOneFV(field2=2,field1=1)
y_fv = TypeOneFV(2,3)

Results:

julia> @btime r * x_plain - y_plain / r
  67.119 ns (3 allocations: 96 bytes)
TypeOnePlain
  field1: Float64 3.5
  field2: Float64 7.25


julia> @btime r * x_fv - y_fv / r
  65.119 ns (3 allocations: 96 bytes)
2-element SArray{Tuple{2},Float64,1,2}:
 3.5
 7.25

julia> z_plain  = r * x_plain - y_plain / r
TypeOnePlain
  field1: Float64 3.5
  field2: Float64 7.25


julia> typeof(z_plain)
TypeOnePlain

julia> z_fv = r * x_fv - y_fv / r
2-element SArray{Tuple{2},Float64,1,2}:
 3.5
 7.25

julia> typeof(z_fv)
SArray{Tuple{2},Float64,1,2}

julia> TypeOneFV(z_fv)
2-element TypeOneFV:
 3.5
 7.25

StaticArrays uses similar_type to decide which type an operation returns, so you can control this by overloading StaticArrays.similar_type. In this case:

StaticArrays.similar_type(::Type{TypeOneFV}, ::Type{Float64}, s::Size{(2,)}) = TypeOneFV

results in

julia> z_fv = r * x_fv - y_fv / r
2-element TypeOneFV:
 3.5
 7.25