Hi all,
I have a function g(a,b) with many methods based on input types of a and b. A subset of the methods have common code, eg:
g(a::T1, b::T2) where {T1<:AbstractA,T2<:AbstractB} = (fa(a),fb(b))
where fa and fb will specialize on the input type, but clearly g does not need to specialize on the input type, other than to ensure that a is a subtype of AbstractA and b is a subtype of AbstractB.
I’d like to tell the compiler to just compile g once, rather than compiling for every possible combination of subtypes of AbstractA and AbstractB. My understanding is that I do this using @nospecialize. However, I can’t seem to get it to work properly. Specifically the following does work:
g(@nospecialize(a), @nospecialize(b)) = (fa(a), fb(b))
but this does not restrict a to be a subtype of AbstractA and b to be a subtype of AbstractB. I could also try the following:
g(@nospecialize(a::T1), @nospecialize(b::T2)) where {T1<:AbstractA,T2<:AbstractB} = (fa(a),fb(b))
but based on my MWE below this appears to compile for every combination of subtype of AbstractA and AbstractB.
Apologies, there is probably a smarter way to test this but I don’t know what it is. I had to resort to building 10 subtypes of abstract types AbA, AbB, and AbC, and then looping over calls to a trivial function for every combination of subtypes. I did this for four different functions below f1, f2, f3, and f4, and hoped that the loop over calls to f3 and f4 would be fast. In practice, only the loop over f4 is fast. Any ideas how I can speed up the f3 case so it works like the f4 case, while preserving the subtype restrictions? Code below:
let
K = 10;
abstract type AbA ; end
for ka = 1:K
namesymbol = Symbol("A$(ka)")
@eval struct $(namesymbol) <: AbA ; x::Int ; end
@eval $(Symbol("fa"))(x::$(namesymbol))::Int = x.x + $(ka)
end
abstract type AbB ; end
for kb = 1:K
namesymbol = Symbol("B$(kb)")
@eval struct $(namesymbol) <: AbB ; x::Int ; end
@eval $(Symbol("fb"))(x::$(namesymbol))::Int = x.x + $(kb)
end
abstract type AbC ; end
for kc = 1:K
namesymbol = Symbol("C$(kc)")
@eval struct $(namesymbol) <: AbC ; x::Int ; end
@eval $(Symbol("fc"))(x::$(namesymbol))::Int = x.x + $(kc)
end
function f1(a::Ta, b::Tb, c::Tc) where {Ta<:AbA,Tb<:AbB,Tc<:AbC}
return fa(a) + fb(b) + fc(c)
end
function f2(a, b, c)
return fa(a) + fb(b) + fc(c)
end
function f3(@nospecialize(a::Ta), @nospecialize(b::Tb), @nospecialize(c::Tc))::Float64 where {Ta<:AbA,Tb<:AbB,Tc<:AbC}
return fa(a) + fb(b) + fc(c)
end
function f4(@nospecialize(a), @nospecialize(b), @nospecialize(c))
return fa(a) + fb(b) + fc(c)
end
println("working on f1")
v1 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v1[ka*kb*kc] = f1(a, b, c)
end ; end ; end
println("working on f2")
v2 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v2[ka*kb*kc] = f2(a, b, c)
end ; end ; end
println("working on f3")
v3 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v3[ka*kb*kc] = f3(a, b, c)
end ; end ; end
println("working on f4")
v4 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v4[ka*kb*kc] = f4(a, b, c)
end ; end ; end
end