Restricted parametric type

Hi, how could I define a parametric type where only a subset of types is allowed as parameters? E.g.:

struct MyType{S, T}
    content::Int64
end

How could I impose arbitrary restrictions so that

MyType{Float64, Int64}(123)
MyType{Int64, Float64}(123)

are valid, and

MyType{Float64, Float64}(123)
MyType{Int64, Int64}(123)

are invalid?

Thanks.


There are two approaches to constraining a struct’s field’s types. One involves constraining struct parameters and using those parameters for field types. The other involves constraining the types that struct constructors accept. They may be used separately or together.

If you have two fields and one should be a floating point value and the other should be an integer value:

struct Demo{A<:AbstractFloat, B<:Integer}
    floatfield::A
    intfield::B
end

Demo(12.0, 5) # works, giving a Demo{Float64, Int64}
Demo(12, 5) # MethodError: 12 is an Int64, not an AbstractFloat

If you want one field to hold a Float64 value and the other an Int64 value, where sometimes the first field is the Float64 and other times it is the Int64:

const IntFloat64 = Union{Int64, Float64}

struct Demo{A, B}
    field1::A
    field2::B

    function Demo(field1::A, field2::B) where {A<:IntFloat64, B<:IntFloat64}
        # ensure the types differ
        if A == B
           throw(ErrorException("types must differ"))
        end
        return new{A,B}(field1, field2)
    end
end

Because the constructor is defined inside the struct (an inner constructor), Julia does not generate the default constructors, and any outer constructor you add later must ultimately call it, so every instance passes through this check.
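
To see the inner constructor acting as the only way in, here is a self-contained sketch of the same Demo definition (the check is shortened with error, which throws the same ErrorException) together with a few calls. The comments state what each call is expected to do; the last one shows that the default Demo{A, B}(...) constructor really is gone.

```julia
const IntFloat64 = Union{Int64, Float64}

struct Demo{A, B}
    field1::A
    field2::B

    # Inner constructor: no default constructors are generated for Demo,
    # so every instance has to pass through this check.
    function Demo(field1::A, field2::B) where {A<:IntFloat64, B<:IntFloat64}
        A == B && error("types must differ")
        return new{A, B}(field1, field2)
    end
end

Demo(1, 2.0)                 # ok: a Demo{Int64, Float64}
Demo(1.0, 2)                 # ok: a Demo{Float64, Int64}
# Demo(1, 2)                 # ErrorException: types must differ
# Demo(1.0f0, 2)             # MethodError: Float32 is not in IntFloat64
# Demo{Int64, Int64}(1, 2)   # MethodError: the default constructor was not generated
```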


It does not look like the OP wants to restrict the struct's field types; the restriction is on the type parameters themselves.

struct MyType{S, T}
    content::Int64
end

function MyType(::Type{S}, ::Type{T}, content::Int) where {S, T}
    constraintypes(S, T)
    return MyType{S, T}(content)
end

function constraintypes(::Type{S}, ::Type{T}) where {S, T}
    s = supertype(S)
    t = supertype(T)
    if s <: t || t <: s
        throw(ErrorException("types must differ more"))
    end
    return nothing
end

julia> MyType(Int, Float64, 5)
MyType{Int64,Float64}(5)

julia> MyType(Float32, Int16, 5)
MyType{Float32,Int16}(5)

julia> MyType(Float32, Float64, 5)
ERROR: types must differ more
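
One limitation of the outer-constructor approach is that the default constructor still exists, so MyType{Float64, Float64}(123) succeeds if called directly. A minimal sketch, assuming the allowed combinations are exactly the two listed in the question (the ALLOWED constant is a hypothetical name), moves the check into an inner constructor so it cannot be bypassed:

```julia
# Hypothetical whitelist of allowed (S, T) combinations, taken from the question.
const ALLOWED = ((Float64, Int64), (Int64, Float64))

struct MyType{S, T}
    content::Int64

    # Inner constructor: it replaces the default MyType{S, T}(content),
    # so every construction goes through this check.
    function MyType{S, T}(content::Integer) where {S, T}
        (S, T) in ALLOWED ||
            throw(ArgumentError("MyType{$S, $T} is not an allowed combination"))
        return new{S, T}(content)
    end
end

MyType{Float64, Int64}(123)        # ok
MyType{Int64, Float64}(123)        # ok
# MyType{Float64, Float64}(123)    # ArgumentError
# MyType{Int64, Int64}(123)        # ArgumentError
```

The type MyType{Float64, Float64} itself can still be written down; only creating instances of it fails.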

Oh, I see: the trick is to use ::Type{}?
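
Roughly, yes: an argument annotated ::Type{T} receives a type itself (such as Float64) as its value, and the where clause binds that type so the method can use it or constrain it. A small sketch with a hypothetical function describe:

```julia
# ::Type{T} matches when the argument *is* the type T, not a value of type T.
describe(::Type{T}) where {T} = "some type: $T"

# A more specific method for one particular type.
describe(::Type{Int64}) = "exactly Int64"

# Bounds in the where clause work as they do for ordinary arguments.
describe(::Type{T}) where {T<:AbstractFloat} = "a float type: $T"

println(describe(Int64))    # exactly Int64
println(describe(Float32))  # a float type: Float32
println(describe(String))   # some type: String
```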

Actually, I was looking into the code of StaticArrays. It seems it does not prevent the definition of nonsense types, e.g.:

julia> SMatrix{2, 3}
SArray{Tuple{2,3},T,2,L} where L where T

julia> SMatrix{-1, 3}
SArray{Tuple{-1,3},T,2,L} where L where T

Note that the second type raises neither an error nor a warning, even though it is nonsense (a negative dimension).

julia> SMatrix{2, 3}(11.0, 12.0, 13.0, 14.0, 15.0, 16.0)
2×3 SArray{Tuple{2,3},Float64,2,6}:
 11.0  13.0  15.0
 12.0  14.0  16.0

julia> SMatrix{-2, 3}(11.0, 12.0, 13.0, 14.0, 15.0, 16.0)
ERROR: ArgumentError: Size mismatch in Static Array parameters. Got size Tuple{-2,3}, dimension 2 and length 6.

The error is raised only when the constructor is called.
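
The same behavior can be reproduced in plain Julia: an isbits value such as -1 is accepted as a type parameter when the type is written down, so the natural place to reject a bad value is the constructor. A minimal sketch with a hypothetical SizedVec{N} type (an illustration of the pattern, not how StaticArrays is actually implemented):

```julia
# Hypothetical fixed-length vector whose length is stored in the type parameter N.
struct SizedVec{N}
    data::Vector{Float64}

    function SizedVec{N}(data::AbstractVector{<:Real}) where {N}
        # Validate the value parameter here; the type SizedVec{-1} itself stays legal.
        N isa Int && N >= 0 ||
            throw(ArgumentError("N must be a non-negative Int, got $N"))
        length(data) == N ||
            throw(DimensionMismatch("expected $N elements, got $(length(data))"))
        return new{N}(collect(Float64, data))
    end
end

SizedVec{-1}                    # no error: this is just a type
SizedVec{3}([1.0, 2.0, 3.0])    # ok
# SizedVec{-1}([1.0])           # ArgumentError
# SizedVec{2}([1.0])            # DimensionMismatch
```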

I understand your perspective and have no quarrel with it. We do it like this a lot:

When used, argument validation tends to be encapsulated in a validation function or given as an inline test: arg1 > 0 || throw(DomainError(arg1)).

Type signatures are used to guide multiple dispatch and also to gate argument types.

Mostly, Julia code is written to the type system rather than to defend it.
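
As a rough illustration of those styles, here is a sketch combining a validation helper, an inline test, and a signature that gates argument types; the names checkpositive and scaled are made up for the example:

```julia
# Validation encapsulated in a reusable helper (hypothetical name).
function checkpositive(x::Real)
    x > 0 || throw(DomainError(x, "argument must be positive"))
    return x
end

# The ::Real annotations gate the argument types;
# the inline test guards the value of arg1.
function scaled(arg1::Real, factor::Real)
    arg1 > 0 || throw(DomainError(arg1, "arg1 must be positive"))
    return arg1 * checkpositive(factor)
end

scaled(2, 3.0)        # 6.0
# scaled(-2, 3.0)     # DomainError from the inline test
# scaled(2, -3.0)     # DomainError from checkpositive
# scaled("2", 3.0)    # MethodError: rejected by the ::Real signature
```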


Inspecting SMatrix gave me a big surprise (due to my ignorance): the “parameter” of a parametric type can be anything, i.e. it doesn't need to be a Type!

struct MyStruct1{T}
    content::Int64
end

function f(x::MyStruct1{T}) where {T}
    println("T is ", T, "   content is: ", x.content)
end

julia> f(MyStruct1{1}(123) )
T is 1   content is: 123

julia> f(MyStruct1{2.2}(222) )
T is 2.2   content is: 222

julia> f(MyStruct1{:sym}(333) )
T is sym   content is: 333

julia> typeof(MyStruct1{1}(123) )
MyStruct1{1}

julia> typeof(MyStruct1{2.2}(222) )
MyStruct1{2.2}

julia> typeof(MyStruct1{:sym}(333) )
MyStruct1{:sym}

So I do not understand why we need Val{} at all.

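Not quite: see the fourth bullet point in the introduction of the manual's Types chapter (Types · The Julia Language). Type parameters can be types, symbols, values of any isbits type (numbers, Bools and the like), or tuples of these, but not arbitrary objects.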
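
A small check of that rule, using a struct shaped like MyStruct1 above under the hypothetical name Tagged: types, isbits values, symbols and tuples of these are accepted, while a Vector (neither a type nor an isbits value) is rejected as soon as the type is formed.

```julia
struct Tagged{T}
    content::Int64
end

# Accepted parameters: a type, isbits values, a Symbol, a tuple of these.
Tagged{Float64}(1)
Tagged{1}(1)
Tagged{2.2}(1)
Tagged{:sym}(1)
Tagged{(1, :a)}(1)

# Rejected: forming the type itself throws.
try
    Tagged{[1, 2]}
catch err
    println(typeof(err))   # TypeError
end
```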

For convenience! Of course, you can define your own Val-like type, but if you only want to dispatch on values and nothing more, someone has already defined that type for you.
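
To illustrate, a home-made Val-like type takes only a couple of lines, and Base.Val behaves the same way with nothing to define. The names MyVal and describe_shape below are hypothetical:

```julia
# A hypothetical home-made equivalent of Base.Val.
struct MyVal{x} end
MyVal(x) = MyVal{x}()

# Dispatch on the value carried in the type parameter.
describe_shape(::MyVal{:circle}) = "round"
describe_shape(::MyVal{:square}) = "angular"

# The built-in Val works the same way.
describe_shape(::Val{:circle}) = "round (via Base.Val)"

println(describe_shape(MyVal(:circle)))   # round
println(describe_shape(MyVal(:square)))   # angular
println(describe_shape(Val(:circle)))     # round (via Base.Val)
```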

if you only want to dispatch on values, nothing more

I don’t understand. For example, the following uses the value directly for dispatch and works fine:

struct MyStruct1{T}
    content::Int64
end

function f1(x::MyStruct1{1})
    println("T is 1")
end

function f1(x::MyStruct1{2})
    println("T is 2")
end

julia> f1(MyStruct1{1}(123) )
T is 1

julia> f1(MyStruct1{2}(222) )
T is 2

Could you give an example where Val{1234} is necessary rather than the value 1234?

It is mostly for type stability. Some types have meaningful numbers in their parameters, for example NTuple{N, T}, where N is the number of elements in the tuple. Let’s say I want to generate five 3-tuples and return a Scatter struct with a points field.

julia> struct Scatter{N, T}
           points::Vector{NTuple{N, T}}
       end

julia> Scatter(N, n) = Scatter([ntuple(i->rand(), N) for i in 1:n])
Scatter

julia> Scatter(::Val{N}, n) where {N} = Scatter([ntuple(i->rand(), N) for i in 1:n])
Scatter

julia> Scatter(3, 5)
Scatter{3,Float64}(Tuple{Float64,Float64,Float64}[(0.599768, 0.863335, 0.0522125), (0.461746, 0.991876, 0.349595), (0.976401, 0.146929, 0.393038), (0.570525, 0.752721, 0.00339558), (0.0793879, 0.612754, 0.00491688)])

julia> @code_warntype Scatter(3, 5)
Body::Scatter{_1,_2} where _2 where _1
1 1 ─ %1 = %new(getfield(Main, Symbol("##3#5")){Int64}, N)::getfield(Main, Symbol("##3#5")){Int64}            │
  │   %2 = (Base.sle_int)(1, n)::Bool                                                                         │╻╷╷╷ Colon
  │        (Base.sub_int)(n, 1)                                                                               ││╻    Type
  │   %4 = (Base.ifelse)(%2, n, 0)::Int64                                                                     │││┃    unitrange_last
  │   %5 = %new(UnitRange{Int64}, 1, %4)::UnitRange{Int64}                                                    │││
  │   %6 = %new(Base.Generator{UnitRange{Int64},getfield(Main, Symbol("##3#5")){Int64}}, %1, %5)::Base.Generator{UnitRange{Int64},getfield(Main, Symbol("##3#5")){Int64}}
  │   %7 = invoke Base.collect(%6::Base.Generator{UnitRange{Int64},getfield(Main, Symbol("##3#5")){Int64}})::Array{_1,1} where _1
  │   %8 = (Main.Scatter)(%7)::Scatter{_1,_2} where _2 where _1                                               │
  └──      return %8                                                                                          │

julia> Scatter(Val{3}(), 5)
Scatter{3,Float64}(Tuple{Float64,Float64,Float64}[(0.414659, 0.267734, 0.0170115), (0.763122, 0.858378, 0.45464), (0.0388137, 0.489361, 0.0495321), (0.0409206, 0.00572096, 0.924081), (0.350058, 0.133318, 0.890979)])

julia> @code_warntype Scatter(Val{3}(), 5)
Body::Scatter{3,Float64}
1 1 ─ %1 = (Base.sle_int)(1, n)::Bool                                                                        │╻╷╷╷╷ Colon
  │        (Base.sub_int)(n, 1)                                                                              ││╻     Type
  │   %3 = (Base.ifelse)(%1, n, 0)::Int64                                                                    │││┃     unitrange_last
  │   %4 = %new(UnitRange{Int64}, 1, %3)::UnitRange{Int64}                                                   │││
  │   %5 = %new(Base.Generator{UnitRange{Int64},getfield(Main, Symbol("##7#9")){3}}, getfield(Main, Symbol("##7#9")){3}(), %4)::Base.Generator{UnitRange{Int64},getfield(Main, Symbol("##7#9")){3}}
  │   %6 = invoke Base.collect(%5::Base.Generator{UnitRange{Int64},getfield(Main, Symbol("##7#9")){3}})::Array{Tuple{Float64,Float64,Float64},1}
  │   %7 = %new(Scatter{3,Float64}, %6)::Scatter{3,Float64}                                                  ││╻     Type
  └──      return %7                                                                                         │

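With a plain integer, N is only known at run time, so the compiler can infer nothing better than Scatter{_1,_2} where _2 where _1. With Val{3}(), N is part of the argument's type, so the result is inferred as the concrete Scatter{3,Float64}. If this code is in a hot part of the program, that type instability can propagate and slow down everything that uses the result.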
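
A minimal sketch of the same comparison that can be checked without reading @code_warntype output, assuming a recent Julia version: Test.@inferred throws if the inferred return type differs from the actual one, and Base.return_types (a non-exported reflection function) reports what inference concluded from the argument types alone. Unlike the post, the Val method here passes Val(N) to ntuple, which makes the compile-time tuple length explicit.

```julia
using Test  # for @inferred

struct Scatter{N, T}
    points::Vector{NTuple{N, T}}
end

# N is a runtime value: the tuple length is unknown to the compiler.
Scatter(N::Integer, n::Integer) = Scatter([ntuple(_ -> rand(), N) for _ in 1:n])

# N is a type parameter of Val{N}: the tuple length is known at compile time.
Scatter(::Val{N}, n::Integer) where {N} = Scatter([ntuple(_ -> rand(), Val(N)) for _ in 1:n])

s = @inferred Scatter(Val(3), 5)    # passes: inferred as Scatter{3, Float64}
println(typeof(s))                  # Scatter{3, Float64}

# From the argument types (Int, Int) alone, inference cannot pin down N.
println(isconcretetype(only(Base.return_types(Scatter, (Int, Int)))))   # false
```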


Thanks.

I finally found the documentation here as well, though it’s kind of difficult to understand :sweat_smile: