Hi all,
I have a function g(a,b)
with many methods based on input types of a
and b
. A subset of the methods have common code, eg:
g(a::T1, b::T2) where {T1<:AbstractA,T2<:AbstractB} = (fa(a),fb(b))
where fa
and fb
will specialize on the input type, but clearly g
does not need to specialize on the input type, other than to ensure that a
is a subtype of AbstractA
and b
is a subtype of AbstractB
.
I’d like to tell the compiler to just compile g
once, rather than compiling for every possible combination of subtypes of AbstractA
and AbstractB
. My understanding is that I do this using @nospecialize
. However, I can’t seem to get it to work properly. Specifically the following does work:
g(@nospecialize(a), @nospecialize(b)) = (fa(a), fb(b))
but this does not restrict a
to be a subtype of AbstractA
and b
to be a subtype of AbstractB
. I could also try the following:
g(@nospecialize(a::T1), @nospecialize(b::T2)) where {T1<:AbstractA,T2<:AbstractB} = (fa(a),fb(b))
but based on my MWE below this appears to compile for every combination of subtype of AbstractA
and AbstractB
.
Apologies, there is probably a smarter way to test this but I don’t know what it is. I had to resort to building 10 subtypes of abstract types AbA
, AbB
, and AbC
, and then looping over calls to a trivial function for every combination of subtypes. I did this for four different functions below f1
, f2
, f3
, and f4
, and hoped that the loop over calls to f3
and f4
would be fast. In practice, only the loop over f4
is fast. Any ideas how I can speed up the f3
case so it works like the f4
case, while preserving the subtype restrictions? Code below:
let
K = 10;
abstract type AbA ; end
for ka = 1:K
namesymbol = Symbol("A$(ka)")
@eval struct $(namesymbol) <: AbA ; x::Int ; end
@eval $(Symbol("fa"))(x::$(namesymbol))::Int = x.x + $(ka)
end
abstract type AbB ; end
for kb = 1:K
namesymbol = Symbol("B$(kb)")
@eval struct $(namesymbol) <: AbB ; x::Int ; end
@eval $(Symbol("fb"))(x::$(namesymbol))::Int = x.x + $(kb)
end
abstract type AbC ; end
for kc = 1:K
namesymbol = Symbol("C$(kc)")
@eval struct $(namesymbol) <: AbC ; x::Int ; end
@eval $(Symbol("fc"))(x::$(namesymbol))::Int = x.x + $(kc)
end
function f1(a::Ta, b::Tb, c::Tc) where {Ta<:AbA,Tb<:AbB,Tc<:AbC}
return fa(a) + fb(b) + fc(c)
end
function f2(a, b, c)
return fa(a) + fb(b) + fc(c)
end
function f3(@nospecialize(a::Ta), @nospecialize(b::Tb), @nospecialize(c::Tc))::Float64 where {Ta<:AbA,Tb<:AbB,Tc<:AbC}
return fa(a) + fb(b) + fc(c)
end
function f4(@nospecialize(a), @nospecialize(b), @nospecialize(c))
return fa(a) + fb(b) + fc(c)
end
println("working on f1")
v1 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v1[ka*kb*kc] = f1(a, b, c)
end ; end ; end
println("working on f2")
v2 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v2[ka*kb*kc] = f2(a, b, c)
end ; end ; end
println("working on f3")
v3 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v3[ka*kb*kc] = f3(a, b, c)
end ; end ; end
println("working on f4")
v4 = fill(0, K*K*K);
for ka = 1:K ; for kb = 1:K ; for kc = 1:K
a = @eval $(Symbol("A$(ka)"))(rand(1:10))
b = @eval $(Symbol("B$(kb)"))(rand(1:10))
c = @eval $(Symbol("C$(kc)"))(rand(1:10))
v4[ka*kb*kc] = f4(a, b, c)
end ; end ; end
end