Dispatch on AxisArrays with Unitful axis

Hello,
I’m relying heavily on Unitful units together with AxisArrays. When I try to dispatch on a specific AxisArray from the physics point of view, I end up writing unreadable code full of `where` clauses that digs into the inner guts of AxisArray. My code:

```julia
using AxisArrays
using Unitful

function test(a::AxisArray{T,1,V,Tuple{A}}) where {T, V, A <:Axis{:time, AX}} where AX<:AbstractArray{Q} where Q<:Unitful.Time
    println("representing time")
end
function test(a::AxisArray{T,1,V,Tuple{A}}) where {T, V, A <:Axis{:energy, AX}} where AX<:AbstractArray{Q} where Q<:Unitful.Energy
    println("representing energy")
end
function test(a::AxisArray)
    println("unrecognized")
end

atime = AxisArray(collect(1:10), Axis{:time}((0:1:9)*u"ns"))       # time with correct units
aenergy = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"eV"))   # energy with correct units
anonsense = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"ns")) # units mismatch
test(atime)
test(aenergy)
test(anonsense)
```

```
representing time
representing energy
unrecognized
```

Is there a more readable way to specify that the axis has a given physical dimension (like time, length, …)?
I definitely need this kind of dispatch so that a Plots.jl recipe can dispatch on the right kind of array.
And I don’t want to wrap the AxisArray in a custom struct just to get the correct dispatch.
I’d appreciate your opinion.
Petr


How about just breaking up the `where` clauses into more understandable blocks? One nice feature of Julia v0.6 is that we now have syntax for naming these sorts of `UnionAll` types:

```julia
julia> using Unitful, AxisArrays

julia> const QAxis{Name, Dim} = Axis{Name, AX} where AX <: AbstractArray{Q} where Q <: Dim;

julia> const TimeAxis = QAxis{:time, Unitful.Time};

julia> const TimeArray{T, N, V} = AxisArray{T, N, V, <:NTuple{N, TimeAxis}};

julia> const TimeVector{T, V} = TimeArray{T, 1, V};

julia> test(a::TimeVector) = println("representing time")
test (generic function with 1 method)

julia> atime = AxisArray(collect(1:10), Axis{:time}((0:1:9)*u"ns"));

julia> test(atime)
representing time

julia> const EnergyAxis = QAxis{:energy, Unitful.Energy};

julia> const EnergyArray{T, N, V} = AxisArray{T, N, V, <:NTuple{N, EnergyAxis}};

julia> const EnergyVector{T, V} = EnergyArray{T, 1, V};

julia> test(a::EnergyVector) = println("representing energy")
test (generic function with 2 methods)

julia> aenergy = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"eV"));

julia> test(aenergy)
representing energy
```
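A plain `::AxisArray` method still works as a fallback next to these aliases, so the mismatched `anonsense` case from the question keeps landing in the "unrecognized" branch. The sketch below assumes current AxisArrays and Unitful releases. `physical_kind` is a hypothetical name chosen only for this example. It returns a `Symbol` instead of printing so the result is easy to check. The 1-D aliases use `Tuple{...}` directly instead of going through `TimeArray`/`NTuple`.

```julia
using AxisArrays, Unitful

# Aliases as in the answer above: an axis called `Name` whose values have physical dimension `Dim`
const QAxis{Name, Dim} = Axis{Name, AX} where AX <: AbstractArray{Q} where Q <: Dim
const TimeAxis   = QAxis{:time, Unitful.Time}
const EnergyAxis = QAxis{:energy, Unitful.Energy}
const TimeVector{T, V}   = AxisArray{T, 1, V, <:Tuple{TimeAxis}}
const EnergyVector{T, V} = AxisArray{T, 1, V, <:Tuple{EnergyAxis}}

# `physical_kind` is a hypothetical helper name used only in this sketch
physical_kind(a::TimeVector)   = :time
physical_kind(a::EnergyVector) = :energy
physical_kind(a::AxisArray)    = :unrecognized  # fallback: wrong axis name or wrong units

atime     = AxisArray(collect(1:10), Axis{:time}((0:1:9)*u"ns"))
aenergy   = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"eV"))
anonsense = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"ns"))  # name says energy, units say time

@show physical_kind(atime) physical_kind(aenergy) physical_kind(anonsense)
```

Because `TimeVector` and `EnergyVector` are both subtypes of `AxisArray`, Julia picks them over the generic method whenever they match. No extra ordering or wrapper type is needed.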

Or, if that’s too complicated, you can create whatever shorthand you want:

```julia
julia> using Unitful, AxisArrays

julia> const UnitfulVector{Name, Dim, T, V} = AxisArray{T, 1, V, Tuple{A}} where A <: Axis{Name, AX} where AX <: AbstractArray{Q} where Q <: Dim
AxisArrays.AxisArray{T,1,V,Tuple{A}} where A<:AxisArrays.Axis{Name,AX} where AX<:(AbstractArray{Q,N} where N) where Q<:Dim where V where T where Dim where Name

julia> function test(a::UnitfulVector{:time, Unitful.Time, T, V}) where {T, V}
         println("representing time")
       end
test (generic function with 1 method)

julia> atime = AxisArray(collect(1:10), Axis{:time}((0:1:9)*u"ns"));

julia> test(atime)
representing time

# if you want, you can even omit `T` and `V`:
julia> function test(a::UnitfulVector{:energy, Unitful.Energy})
         println("representing energy")
       end
test (generic function with 2 methods)

julia> aenergy = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"eV"));

julia> test(aenergy)
representing energy
```
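The same named shorthand covers any other dimension Unitful defines, such as the "length" case mentioned in the question. The sketch below is hypothetical in its details: the `:position` axis name and the `describe` function are illustrative choices, not anything from the thread. It shows that any unit of the right dimension matches (metres or millimetres), while a wrong dimension falls through to the fallback.

```julia
using AxisArrays, Unitful

# Named shorthand from the answer above: axis `Name` carrying quantities of dimension `Dim`
const UnitfulVector{Name, Dim, T, V} =
    AxisArray{T, 1, V, Tuple{A}} where A <: Axis{Name, AX} where AX <: AbstractArray{Q} where Q <: Dim

# Hypothetical example: a `:position` axis measured in any length unit
describe(a::UnitfulVector{:position, Unitful.Length}) = "position"
describe(a::AxisArray) = "unrecognized"  # fallback

apos_m  = AxisArray(collect(1:10), Axis{:position}((0:1:9)*u"m"))
apos_mm = AxisArray(collect(1:10), Axis{:position}((0:1:9)*u"mm"))  # different unit, same dimension
apos_s  = AxisArray(collect(1:10), Axis{:position}((0:1:9)*u"s"))   # wrong dimension

@show describe(apos_m) describe(apos_mm) describe(apos_s)
```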

Or, if you don’t care about the name of the axis but only about its units:

```julia
julia> using Unitful, AxisArrays

julia> const UnitfulVector{Dim, T, V} = AxisArray{T, 1, V, Tuple{A}} where {Q <: Dim, AX <: AbstractArray{Q}, Name, A <: Axis{Name, AX}}
AxisArrays.AxisArray{T,1,V,Tuple{A}} where A<:AxisArrays.Axis{Name,AX} where Name where AX<:(AbstractArray{Q,N} where N) where Q<:Dim where V where T where Dim

julia> function test(a::UnitfulVector{Unitful.Energy})
         println("representing energy")
       end
test (generic function with 1 method)

julia> aenergy = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"eV"));

julia> test(aenergy)
representing energy

julia> function test(a::UnitfulVector{Unitful.Time, T, V}) where {T, V}
         println("representing time")
       end
test (generic function with 2 methods)

julia> atime = AxisArray(collect(1:10), Axis{:time}((0:1:9)*u"ns"));

julia> test(atime)
representing time
```
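One consequence of ignoring the axis name is worth checking: the question's `anonsense` array (axis named `:energy`, values in ns) now dispatches as *time*, because only the dimension of the axis values is inspected. An axis without units still falls back to the generic method. In this sketch the name-agnostic alias is copied under the hypothetical name `DimVector` so it does not clash with the named `UnitfulVector` above. `quantity_kind` is likewise a hypothetical function name.

```julia
using AxisArrays, Unitful

# Name-agnostic shorthand from the answer above, renamed to `DimVector` for this sketch
const DimVector{Dim, T, V} =
    AxisArray{T, 1, V, Tuple{A}} where {Q <: Dim, AX <: AbstractArray{Q}, Name, A <: Axis{Name, AX}}

# `quantity_kind` is a hypothetical helper name used only here
quantity_kind(a::DimVector{Unitful.Time})   = :time
quantity_kind(a::DimVector{Unitful.Energy}) = :energy
quantity_kind(a::AxisArray)                 = :unrecognized  # fallback

atime     = AxisArray(collect(1:10), Axis{:time}((0:1:9)*u"ns"))
aenergy   = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"eV"))
anonsense = AxisArray(collect(1:10), Axis{:energy}((0:1:9)*u"ns"))  # named :energy, but in ns
aplain    = AxisArray(collect(1:10), Axis{:energy}(0:1:9))          # plain numbers, no units

@show quantity_kind(atime) quantity_kind(aenergy) quantity_kind(anonsense) quantity_kind(aplain)
```

If the axis name should also guard against such mix-ups, the named variant from the previous answer is the safer choice.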

These are great suggestions, thank you rdeits. I didn’t know a `const` alias could represent parametric (`UnionAll`) types. This construct simplifies things a lot and really makes the source code readable. It is great to be able to dispatch on these aliases without defining my own struct just to guide dispatch: this way I can use plain AxisArrays and still write several methods that dispatch on their physical meaning. Cool.