Specify constraints on type parameters

For example, I have four labels that I use to parameterize a type:

const POOL = Union{:A, :B, :C, :D}

const CANNOT_BE_TOGETHER = (
    Union{:A, :C},
    Union{:B, :D}
)

There are 4 constraints on the type parameters:

  1. All type parameters must be in POOL.
  2. The order of A and B doesn’t matter.
  3. A and B cannot be the same in one instance.
  4. Some pairs of labels cannot appear together (as listed in CANNOT_BE_TOGETHER).

Here is my current implementation:

struct MyType{A, B}
    values::Matrix
    function MyType{A, B}(values) where {A, B}
        Union{A, B} in POOL || error("Unrecognized label!")
        A == B && error("$A and $B can't be the same!")
        Union{A, B} in CANNOT_BE_TOGETHER && error("$A and $B cannot appear together!")
        new(values)
    end
end

Is there a simpler, or more elegant way of specifying this? Thank you!

The recommended way is to check, in the inner constructor, any constraints that you cannot express through the type system.

Sorry, I don't quite understand this sentence. Does that mean it can't be improved?

You are using an inner constructor, which is the only way to do this kind of check. So in that sense, it is the best solution.
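To illustrate the distinction, here is a small sketch (the names `RealBox`, `Labeled`, and `LABELS` are hypothetical, not from the thread). A bound such as `T<:Real` is enforced by the type system itself, whereas "this Symbol parameter must be one of a fixed set" can only be checked when an instance is constructed:

```julia
# Expressible in the type system: T must be a subtype of Real.
struct RealBox{T<:Real}
    x::T
end

# Not expressible in the type system: S must be one of a few symbols.
const LABELS = (:A, :B, :C, :D)

struct Labeled{S}
    x::Int
    function Labeled{S}(x::Int) where {S}
        # Checked at construction time, inside the inner constructor.
        S in LABELS || throw(ArgumentError("unrecognized label $S"))
        new{S}(x)
    end
end
```

`RealBox{String}` is rejected with a `TypeError` as soon as the type is written, while `Labeled{:X}(1)` is a valid type application and only fails when the constructor runs.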

OTOH, the implementation is not right: Union only accepts types, and :A, :B, etc. are Symbol values, not types, so

julia> const POOL = Union{:A, :B, :C, :D}
ERROR: TypeError: in Union, expected Type, got Symbol
Stacktrace:
 [1] top-level scope at none:0

won’t work.
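If the intent behind `Union{:A, :B, :C, :D}` was a type-level set of labels, one possible workaround (a sketch, not something proposed in the thread; `POOL_T`, `CANNOT_BE_TOGETHER_T`, `in_pool`, and `forbidden_pair` are hypothetical names) is to wrap each symbol in `Val`. `Val{:A}` is a type, so it can appear in a `Union`, and membership becomes a subtype test:

```julia
# Val{:A} etc. are types, so unlike bare symbols they can be combined with Union.
const POOL_T = Union{Val{:A}, Val{:B}, Val{:C}, Val{:D}}
const CANNOT_BE_TOGETHER_T = (Union{Val{:A}, Val{:C}}, Union{Val{:B}, Val{:D}})

# Is label s one of the allowed ones?
in_pool(s::Symbol) = Val{s} <: POOL_T

# Union is order-insensitive, so Union{Val{:C}, Val{:A}} equals Union{Val{:A}, Val{:C}}.
forbidden_pair(a::Symbol, b::Symbol) = Union{Val{a}, Val{b}} in CANNOT_BE_TOGETHER_T
```

This mirrors the original question's layout, but the checks still have to be called from the inner constructor; it does not move them into the type system.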

I would just factor out the validation into a function that can be tested separately, e.g.

using ArgCheck: @argcheck

const POOL = (:A, :B, :C, :D)

const CANNOT_BE_TOGETHER = (Set((:A, :C)), Set((:B, :D)))

function check_valid_AB(A, B)
    @argcheck A ∈ POOL
    @argcheck B ∈ POOL
    @argcheck A ≠ B
    @argcheck Set([A, B]) ∉ CANNOT_BE_TOGETHER
end

struct MyType{A, B}
    values::Matrix
    function MyType{A, B}(values) where {A, B}
        check_valid_AB(A, B)
        new(values)
    end
end

then

julia> MyType{:C, :A}(ones(2,3))
ERROR: ArgumentError: Set([A, B]) ∉ CANNOT_BE_TOGETHER must hold. Got
Set([A, B]) => Set(Symbol[:A, :C])
CANNOT_BE_TOGETHER => (Set(Symbol[:A, :C]), Set(Symbol[:D, :B]))
Stacktrace:
 [1] macro expansion at /home/tamas/.julia/packages/ArgCheck/BUMkA/src/checks.jl:165 [inlined]
 [2] check_valid_AB(::Symbol, ::Symbol) at ./REPL[151]:5
 [3] MyType{:C,:A}(::Array{Float64,2}) at ./REPL[152]:4
 [4] top-level scope at none:0
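If you would rather avoid the ArgCheck dependency, the same validation can be written with plain `throw(ArgumentError(...))`. Because it is an ordinary function, it can be unit-tested without constructing any `MyType`. A minimal sketch using only Base and the `Test` standard library (`check_valid_pair`, `ALLOWED_LABELS`, and `FORBIDDEN_PAIRS` are hypothetical names):

```julia
using Test

const ALLOWED_LABELS = (:A, :B, :C, :D)
const FORBIDDEN_PAIRS = (Set((:A, :C)), Set((:B, :D)))

# Throws an ArgumentError for the first violated constraint; returns nothing otherwise.
function check_valid_pair(a::Symbol, b::Symbol)
    a in ALLOWED_LABELS || throw(ArgumentError("unrecognized label $a"))
    b in ALLOWED_LABELS || throw(ArgumentError("unrecognized label $b"))
    a == b && throw(ArgumentError("$a and $b can't be the same"))
    # A Set makes the check order-insensitive: Set((:C, :A)) == Set((:A, :C)).
    Set((a, b)) in FORBIDDEN_PAIRS &&
        throw(ArgumentError("$a and $b cannot appear together"))
    return nothing
end

# The validation logic is tested on its own.
@testset "check_valid_pair" begin
    @test check_valid_pair(:A, :B) === nothing
    @test_throws ArgumentError check_valid_pair(:A, :A)
    @test_throws ArgumentError check_valid_pair(:C, :A)
    @test_throws ArgumentError check_valid_pair(:A, :X)
end
```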

If you only have a few constant parameters, there’s a dispatch-based solution that’s far more efficient:

const POOL = (:A, :B, :C, :D)
const ILLEGAL = [Set((:A, :C)), Set((:B, :D))]

struct MyOtherType{X,Y,T}
    values::T

    function MyOtherType{X,Y}(values::T) where {X,Y,T<:AbstractMatrix}
        verify(X, Y)
        new{X,Y,T}(values)
    end
end

verify(x::Symbol, y::Symbol) = verify(Val{x}(), Val{y}())

# Fallback for any pair without a more specific method: reject it.
verify(::Val{X}, ::Val{Y}) where {X,Y} = throw(ArgumentError("($X, $Y) not allowed"))

for t1 = POOL, t2 = POOL
    (t1 ≡ t2 || Set((t1, t2)) ∈ ILLEGAL) && continue
    s1, s2 = "$t1", "$t2"
    @eval begin
        verify(x::Val{Symbol($s1)}, y::Val{Symbol($s2)}) = nothing
    end
end

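This works by using @eval to define, when the code is loaded, an empty verify method for each valid combination; any other combination falls through to the generic method, which throws. Since X and Y are type parameters, the compiler can typically resolve which verify method applies when it compiles the constructor. If you have a very large number of combinations, this is probably not a good idea, but in a case like yours, with only 8 valid combinations, it’s fine.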
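To see what the loop generates, here is a variant of the same dispatch technique (a sketch; `allowed`, `LABEL_POOL`, and `ILLEGAL_PAIRS` are hypothetical names) that returns `true`/`false` instead of throwing, and interpolates each symbol with `QuoteNode` instead of building it from a string. Counting the `true` results confirms that 8 of the 16 ordered pairs are valid:

```julia
const LABEL_POOL = (:A, :B, :C, :D)
const ILLEGAL_PAIRS = (Set((:A, :C)), Set((:B, :D)))

# Fallback: any combination without a more specific method is not allowed.
allowed(::Val, ::Val) = false
# Convenience entry point taking plain symbols.
allowed(x::Symbol, y::Symbol) = allowed(Val(x), Val(y))

# Define one method per valid ordered pair; QuoteNode(:A) splices in as the symbol :A.
for t1 in LABEL_POOL, t2 in LABEL_POOL
    (t1 === t2 || Set((t1, t2)) in ILLEGAL_PAIRS) && continue
    @eval allowed(::Val{$(QuoteNode(t1))}, ::Val{$(QuoteNode(t2))}) = true
end

valid_pairs = [(a, b) for a in LABEL_POOL, b in LABEL_POOL if allowed(a, b)]
println(length(valid_pairs))  # 8
```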

Unrelated to the checks, I also made the type of values fully specified, via the type parameter T.
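The reason this matters: `Matrix` on its own means `Array{T,2} where T`, which is not a concrete type, so a field declared `values::Matrix` does not tell the compiler the element type. A type parameter lets each instance carry a concrete field type. A quick check (`LooseField` and `TypedField` are hypothetical names):

```julia
struct LooseField
    values::Matrix          # Array{T,2} where T: not concrete
end

struct TypedField{T<:AbstractMatrix}
    values::T               # concrete once T is known
end

println(isconcretetype(fieldtype(LooseField, :values)))                  # false
println(isconcretetype(fieldtype(TypedField{Matrix{Float64}}, :values)))  # true
```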

Testing it:

julia> M = ones(2,3);

julia> MyOtherType{:A, :A}(M)
ERROR: ArgumentError: (A, A) not allowed

julia> MyOtherType{:C, :A}(M)
ERROR: ArgumentError: (C, A) not allowed

julia> MyOtherType{:A, :X}(M)
ERROR: ArgumentError: (A, X) not allowed

julia> MyOtherType{:A, :B}(M)
MyOtherType{:A,:B,Array{Float64,2}}([1.0 1.0 1.0; 1.0 1.0 1.0])

Comparing the performance with the previous implementation (@btime is from BenchmarkTools):

julia> @btime MyType{:A, :B}($M)
  263.465 ns (7 allocations: 608 bytes)
MyType{:A,:B}([1.0 1.0 1.0; 1.0 1.0 1.0])

julia> @btime MyOtherType{:A, :B}($M)
  6.345 ns (1 allocation: 16 bytes)
MyOtherType{:A,:B,Array{Float64,2}}([1.0 1.0 1.0; 1.0 1.0 1.0])