How to remove one of type parameters in a method signature?

I have a bunch of method signatures like this:

sig1 = Tuple{typeof(rrule), typeof(sum), AbstractArray{T, N} where N} where T<:Number
sig2 = Tuple{typeof(rrule), typeof(mean), AbstractArray{var"#s256", N} where {var"#s256"<:Real, N}}
...

Now I want to get similar signatures but without the first type parameter, i.e.:

Tuple{typeof(sum), AbstractArray{T, N} where N} where T<:Number
Tuple{typeof(mean), AbstractArray{var"#s256", N} where {var"#s256"<:Real, N}}

I tried to do it in place:

function remove_first_parameter(sig)
    subsig = sig
    # unroll UnionAll wrappers
    while subsig isa UnionAll
        subsig = subsig.body
    end
    subsig.parameters = Core.svec(subsig.parameters[2:end])
    return sig
end

But it corrupts the method table, and deepcopy() doesn’t work on types.
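A short sketch shows why in-place mutation is dangerous. Type objects are cached, so writing the same type twice gives you the identical object. Assigning to its `parameters` therefore changes it for every method signature that refers to it. The safe alternative is to build a new type from a copy of the parameters. Here `Vector{Int}` is just an illustrative stand-in:

```julia
# Types are interned: writing the same type twice yields the *same* object,
# so mutating `.parameters` in place would affect every user of that type.
a = Tuple{typeof(sum), Vector{Int}}
b = Tuple{typeof(sum), Vector{Int}}
println(a === b)   # one shared object, not two copies

# Safe alternative: copy the parameters out and construct a fresh type.
params = collect(a.parameters)       # Vector{Any} copy of the Core.SimpleVector
new_type = Tuple{params[2:end]...}   # new Tuple type without the first parameter
println(new_type)
println(a)                           # the original type is left untouched
```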

I managed to write a function that removes the first parameter:

remove_first_parameter(::Type{Tuple{T1, T2, T3}}) where {T1, T2, T3} = Tuple{T2, T3}

It works with sig2, which is a DataType, but not with sig1, which is a UnionAll.
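You can see the difference directly. A `where` that wraps the whole `Tuple` makes the signature a `UnionAll`. A `where` that appears only inside one parameter leaves the outer type a `DataType`. `Base.unwrap_unionall` is an internal, unexported Base function that strips the outer wrappers so you can inspect the underlying `Tuple`. In this sketch `identity` stands in for `rrule` and `prod` for `mean`, so no packages are needed:

```julia
# Stand-ins: `identity` replaces `rrule`, `prod` replaces `mean`.
s1 = Tuple{typeof(identity), typeof(sum), AbstractArray{T, N} where N} where T<:Number
s2 = Tuple{typeof(identity), typeof(prod), AbstractArray{S, N} where {S<:Real, N}}

println(typeof(s1))   # UnionAll: `where T` wraps the whole Tuple
println(typeof(s2))   # DataType: `where` only occurs inside the 3rd parameter

# This is why a `::Type{Tuple{T1, T2, T3}}` method only catches the s2 shape.
# Unwrapping exposes the Tuple body, with T left as a free TypeVar:
body = Base.unwrap_unionall(s1)
println(typeof(body))              # DataType
println(length(body.parameters))   # 3
```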

Is there a method which works with UnionAll as well?

Tuple{Base.tail(fieldtypes(sig1))...}

Unfortunately, this doesn’t work for UnionAll signatures: it turns them into plain Tuples (DataTypes) and loses the outer where wrapper:

julia> dump(sig1)
UnionAll
  var: TypeVar
    name: Symbol T
    lb: Union{}
    ub: Number <: Any
  body: Tuple{typeof(rrule), typeof(sum), AbstractArray{T<:Number, N} where N} <: Any

julia> dump(Tuple{Base.tail(fieldtypes(sig2))...})
Tuple{typeof(mean), AbstractArray{var"#s256", N} where {var"#s256"<:Real, N}} <: Any
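Base also contains an internal, unexported helper, `Base.tuple_type_tail`, that appears to do exactly this. It drops the first parameter of a `Tuple` type and re-wraps any outer `UnionAll`. It is not public API, so its behavior may change between Julia versions. Treat this as a sketch to verify on your own Julia version. As before, `identity` stands in for `rrule` and `prod` for `mean`:

```julia
s1 = Tuple{typeof(identity), typeof(sum), AbstractArray{T, N} where N} where T<:Number
s2 = Tuple{typeof(identity), typeof(prod), AbstractArray{S, N} where {S<:Real, N}}

# Internal Base helper (not exported): drops the first Tuple parameter,
# keeping any outer `where` wrappers around the remaining parameters.
t1 = Base.tuple_type_tail(s1)
t2 = Base.tuple_type_tail(s2)

println(t1)
println(t2)
```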

I’m currently testing the following implementation, which specifically addresses this issue:

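function remove_first_parameter(sig)
    if sig isa UnionAll
        # peel off one `where`, process the body, then re-wrap it with the same TypeVar
        new_body = remove_first_parameter(sig.body)
        return UnionAll(sig.var, new_body)
    elseif sig isa DataType
        # build a new Tuple type from parameters 2:end (the original is not mutated)
        params = sig.parameters
        return Tuple{params[2:end]...}
    else
        error("Unsupported type: $sig")
    end
end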
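A quick check runs the function on stand-in signatures shaped like sig1 and sig2. It also adds one hypothetical edge case, where the type variable occurs only in the dropped parameter. There, the resulting `where T` is vacuous, and the result compares equal (`==`) to plain `Tuple{Int}`. The function is repeated so the block runs on its own, and `identity` and `prod` stand in for `rrule` and `mean`:

```julia
# Copy of the function above, so this block is self-contained.
function remove_first_parameter(sig)
    if sig isa UnionAll
        # peel one `where`, recurse into the body, re-attach the same TypeVar
        return UnionAll(sig.var, remove_first_parameter(sig.body))
    elseif sig isa DataType
        # rebuild a fresh Tuple from parameters 2:end (no mutation)
        return Tuple{sig.parameters[2:end]...}
    else
        error("Unsupported type: $sig")
    end
end

# Stand-in signatures (`identity` replaces `rrule`, `prod` replaces `mean`).
s1 = Tuple{typeof(identity), typeof(sum), AbstractArray{T, N} where N} where T<:Number
s2 = Tuple{typeof(identity), typeof(prod), AbstractArray{S, N} where {S<:Real, N}}
# Hypothetical edge case: T occurs only in the parameter being dropped.
s3 = Tuple{Vector{T}, Int} where T

println(remove_first_parameter(s1))
println(remove_first_parameter(s2))
println(remove_first_parameter(s3))
```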

Here’s an example:

julia> dump(sig1)
UnionAll
  var: TypeVar
    name: Symbol T
    lb: Union{}
    ub: Number <: Any
  body: Tuple{typeof(rrule), typeof(sum), AbstractArray{T<:Number, N} where N} <: Any

julia> dump(remove_first_parameter(sig1))
UnionAll
  var: TypeVar
    name: Symbol T
    lb: Union{}
    ub: Number <: Any
  body: Tuple{typeof(sum), AbstractArray{T<:Number, N} where N} <: Any

I suspect there’s a more elegant way to write it, but dispatching properly on UnionAll, Tuple{T}, and friends can be tricky, so I settled on this simple solution. So far it works fine.
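One non-recursive formulation uses two of Base's internal helpers. `Base.unwrap_unionall` strips every outer `where`, and `Base.rewrap_unionall` re-applies the `where`s of a reference type. It is offered here as an alternative sketch, not a drop-in replacement. Both helpers are unexported, so they may change between Julia versions, and the name `drop_first_param` is hypothetical:

```julia
# Hypothetical helper name; relies on internal, unexported Base functions.
function drop_first_param(sig::Type)
    body = Base.unwrap_unionall(sig)   # remove all outer `where` wrappers
    (body isa DataType && body.name === Tuple.name) ||
        throw(ArgumentError("expected a Tuple signature, got $sig"))
    rest = Base.tail((body.parameters...,))          # every parameter except the first
    return Base.rewrap_unionall(Tuple{rest...}, sig) # restore the original `where`s
end

# Stand-ins: `identity` replaces `rrule`, `prod` replaces `mean`.
s1 = Tuple{typeof(identity), typeof(sum), AbstractArray{T, N} where N} where T<:Number
s2 = Tuple{typeof(identity), typeof(prod), AbstractArray{S, N} where {S<:Real, N}}
println(drop_first_param(s1))
println(drop_first_param(s2))
```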
