Using @nospecialize on a set of subtypes

Hi all,

I have a function g(a,b) with many methods based on input types of a and b. A subset of the methods have common code, e.g.:

g(a::T1, b::T2) where {T1<:AbstractA,T2<:AbstractB} = (fa(a),fb(b))

where fa and fb will specialize on the input type, but clearly g does not need to specialize on the input type, other than to ensure that a is a subtype of AbstractA and b is a subtype of AbstractB.

I’d like to tell the compiler to just compile g once, rather than compiling for every possible combination of subtypes of AbstractA and AbstractB. My understanding is that I do this using @nospecialize. However, I can’t seem to get it to work properly. Specifically, the following does work:

g(@nospecialize(a), @nospecialize(b)) = (fa(a), fb(b))

but this does not restrict a to be a subtype of AbstractA and b to be a subtype of AbstractB. I could also try the following:

g(@nospecialize(a::T1), @nospecialize(b::T2)) where {T1<:AbstractA,T2<:AbstractB} = (fa(a),fb(b))

but based on my MWE below this appears to compile for every combination of subtype of AbstractA and AbstractB.

Apologies, there is probably a smarter way to test this, but I don’t know what it is. I had to resort to building 10 subtypes of each of the abstract types AbA, AbB, and AbC, and then looping over calls to a trivial function for every combination of subtypes. I did this for the four functions below (f1, f2, f3, and f4), and hoped that the loops over calls to f3 and f4 would be fast. In practice, only the loop over f4 is fast. Any ideas how I can speed up the f3 case so it works like the f4 case, while preserving the subtype restrictions? Code below:

# Minimal working example: builds K concrete subtypes of each of three abstract
# types, then calls four variants of the same trivial function (f1..f4) once for
# every (A, B, C) subtype combination, so the four loops can be compared.
let
    # Number of concrete subtypes generated per abstract type.
    K = 10;
    abstract type AbA ; end
    # For each ka, define struct A<ka> <: AbA with a single Int field `x`, plus a
    # method fa(::A<ka>) returning x.x + ka. The type names are built at run time,
    # hence the @eval.
    for ka = 1:K
        namesymbol = Symbol("A$(ka)")
        @eval struct $(namesymbol) <: AbA ; x::Int ; end
        @eval $(Symbol("fa"))(x::$(namesymbol))::Int = x.x + $(ka)
    end
    abstract type AbB ; end
    # Same construction for B<kb> <: AbB and fb.
    for kb = 1:K
        namesymbol = Symbol("B$(kb)")
        @eval struct $(namesymbol) <: AbB ; x::Int ; end
        @eval $(Symbol("fb"))(x::$(namesymbol))::Int = x.x + $(kb)
    end
    abstract type AbC ; end
    # Same construction for C<kc> <: AbC and fc.
    for kc = 1:K
        namesymbol = Symbol("C$(kc)")
        @eval struct $(namesymbol) <: AbC ; x::Int ; end
        @eval $(Symbol("fc"))(x::$(namesymbol))::Int = x.x + $(kc)
    end
    # f1: arguments restricted through `where` type variables, no @nospecialize.
    function f1(a::Ta, b::Tb, c::Tc) where {Ta<:AbA,Tb<:AbB,Tc<:AbC}
        return fa(a) + fb(b) + fc(c)
    end
    # f2: no type restrictions, no @nospecialize.
    function f2(a, b, c)
        return fa(a) + fb(b) + fc(c)
    end
    # f3: @nospecialize on every argument, but each argument's type is still a
    # `where` type variable. This is the variant the accompanying discussion
    # reports as slow. It is also the only variant with a declared return type
    # (Float64), so the Int sum is converted on return.
    function f3(@nospecialize(a::Ta), @nospecialize(b::Tb), @nospecialize(c::Tc))::Float64 where {Ta<:AbA,Tb<:AbB,Tc<:AbC}
        return fa(a) + fb(b) + fc(c)
    end
    # f4: @nospecialize on every argument, no type restrictions. This is the
    # variant the accompanying discussion reports as fast.
    function f4(@nospecialize(a), @nospecialize(b), @nospecialize(c))
        return fa(a) + fb(b) + fc(c)
    end
    # Each loop nest below builds one instance of every A/B/C subtype (field value
    # drawn from rand(1:10)) and calls the variant under test on each of the K^3
    # combinations.
    # NOTE(review): the result index ka*kb*kc is not unique per combination (e.g.
    # (1,2,1) and (2,1,1) both map to 2), so entries are overwritten and most of
    # the vector stays 0. It stays in bounds (max K^3) and presumably does not
    # matter for the comparison, since only the calls themselves are of interest.
    println("working on f1")
    v1 = fill(0, K*K*K);
    for ka = 1:K ; for kb = 1:K ; for kc = 1:K
        a = @eval $(Symbol("A$(ka)"))(rand(1:10))
        b = @eval $(Symbol("B$(kb)"))(rand(1:10))
        c = @eval $(Symbol("C$(kc)"))(rand(1:10))
        v1[ka*kb*kc] = f1(a, b, c)
    end ; end ; end
    println("working on f2")
    v2 = fill(0, K*K*K);
    for ka = 1:K ; for kb = 1:K ; for kc = 1:K
        a = @eval $(Symbol("A$(ka)"))(rand(1:10))
        b = @eval $(Symbol("B$(kb)"))(rand(1:10))
        c = @eval $(Symbol("C$(kc)"))(rand(1:10))
        v2[ka*kb*kc] = f2(a, b, c)
    end ; end ; end
    println("working on f3")
    # v3 holds Ints while f3 returns Float64, so each stored value is converted
    # back to Int on assignment.
    v3 = fill(0, K*K*K);
    for ka = 1:K ; for kb = 1:K ; for kc = 1:K
        a = @eval $(Symbol("A$(ka)"))(rand(1:10))
        b = @eval $(Symbol("B$(kb)"))(rand(1:10))
        c = @eval $(Symbol("C$(kc)"))(rand(1:10))
        v3[ka*kb*kc] = f3(a, b, c)
    end ; end ; end
    println("working on f4")
    v4 = fill(0, K*K*K);
    for ka = 1:K ; for kb = 1:K ; for kc = 1:K
        a = @eval $(Symbol("A$(ka)"))(rand(1:10))
        b = @eval $(Symbol("B$(kb)"))(rand(1:10))
        c = @eval $(Symbol("C$(kc)"))(rand(1:10))
        v4[ka*kb*kc] = f4(a, b, c)
    end ; end ; end
end

Because the TypeVars T1 and T2 connect to something outside the @nospecialize scope you’ve essentially defeated the mechanism. If you write g as

g(@nospecialize(a::AbstractA), @nospecialize(b::AbstractB)) = ...

then I think you’ll get what you’re aiming for.

3 Likes

It feels like

julia> @nospecialize g(a::T1, b::T2) where {T1<:AbstractA,T2<:AbstractB} = (fa(a),fb(b))

should work, but worryingly it doesn’t do anything. It doesn’t define a method at all, and does not throw an error.

2 Likes

Yes, that absolutely works, thank you. Gah, I feel like I should have been able to work that one out on my own 🙂

1 Like

Yes, I plugged:

# Variant with @nospecialize applied to the whole function definition rather
# than to individual arguments. As reported in this thread, this form does not
# define f6: the statement runs silently and the first call to f6 raises an
# UndefVarError.
@nospecialize function f6(a::AbA, b::AbB, c::AbC)
    return fa(a) + fb(b) + fc(c)
end

into my MWE and get the same behaviour. Nothing until an UndefVarError when I first try to call f6. Feels odd…