Type instability with a parametric struct and tuples

Here is a MWE:

# Parametric wrapper around a buffer of complex samples.
# `T` is the component type of each `Complex{T}` element stored in `buffer`.
struct Foo{T}
    buffer::Vector{Complex{T}}
end

# Convenience constructor: allocate an uninitialized (`undef`) buffer of
# `num_samples` `Complex{T}` values and wrap it in a `Foo`.
# `T` is an untyped positional argument; the callers in this snippet pass a type
# (`Float32`).
function Foo(T, num_samples)
    Foo(Vector{Complex{T}}(undef, num_samples))
end

# Build one `Foo` per element of every inner collection in `states`, keeping the
# nesting of `states`. The element values themselves are ignored (`x` is unused).
# NOTE(review): `T` is a plain untyped argument that is only forwarded into the
# closures. This is the method whose result is reported as `Tuple{Vector}` in the
# `@code_warntype` output below. Presumably Julia does not specialize on the type
# argument here; confirm against the "Be aware of when Julia avoids specializing"
# section of the Performance Tips.
function foos(T, states, num_samples)
    map(y -> map(x -> Foo(T, num_samples), y), states)
end
# Convenience method: forwards to the three-argument `foos` with the element
# type fixed to `Float32`.
function foos(states, num_samples)
    foos(Float32, states, num_samples)
end
julia> @code_warntype foos(([1],), 10)
MethodInstance for foos(::Tuple{Vector{Int64}}, ::Int64)
  from foos(states, num_samples::Integer) in Main at REPL[4]:1
Arguments
  #self#::Core.Const(foos)
  states::Tuple{Vector{Int64}}
  num_samples::Int64
Body::Tuple{Vector} # This is red
1 ─ %1 = Main.foos(Main.Float32, states, num_samples)::Tuple{Vector}
└──      return %1

It looks like the compiler cannot infer the complete tuple type.
I used Cthulhu to find the root cause and found that the `DataType` argument (`T`) is the problem.
If I make the following change:

# Diagnostic variant of the three-argument `foos`: the element type is hard-coded
# to `Float32`, so the `T` argument is still accepted but no longer used.
function foos(T, states, num_samples)
    map(y -> map(x -> Foo(Float32, num_samples), y), states) # previously: map(y -> map(x -> Foo(T, num_samples), y), states)
end

everything is type stable:

julia> @code_warntype foos(([1],), 10)
MethodInstance for foos(::Tuple{Vector{Int64}}, ::Int64)
  from foos(states, num_samples::Integer) in Main at REPL[4]:1
Arguments
  #self#::Core.Const(foos)
  states::Tuple{Vector{Int64}}
  num_samples::Int64
Body::Tuple{Vector{Foo{Float32}}}
1 ─ %1 = Main.foos(Main.Float32, states, num_samples)::Tuple{Vector{Foo{Float32}}}
└──      return %1

How can I fix this type instability without hard coding T to be Float32?

julia> versioninfo()
Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × AMD Ryzen 7 PRO 5850U with Radeon Graphics
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
  Threads: 1 on 16 virtual cores

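Do it this way: annotate the type argument as `::Type{T}` and add `where {T}`, so that `T` becomes a static parameter of the method: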

# Parametric wrapper around a buffer of complex samples.
# `T` is the component type of each `Complex{T}` element stored in `buffer`.
struct Foo{T}
    buffer::Vector{Complex{T}}
end

 # Convenience constructor: allocate an uninitialized (`undef`) buffer of
 # `num_samples` `Complex{T}` values and wrap it in a `Foo{T}`.
 # The `::Type{T}` ... `where {T}` signature binds the requested element type as
 # a static parameter of the method instead of an untyped value.
 function Foo(::Type{T}, num_samples) where {T}
     Foo(Vector{Complex{T}}(undef, num_samples))
 end

# Build one `Foo{T}` per element of every inner collection in `states`, keeping
# the nesting of `states`. The element values themselves are ignored (`x` is unused).
# Compared with the question's version only the signature changes: `T` is bound
# through `::Type{T}` ... `where {T}`. The `@code_warntype` output below reports a
# concrete `Tuple{Vector{Foo{Float32}}}` for this definition.
function foos(::Type{T}, states, num_samples) where {T}
    map(y -> map(x -> Foo(T, num_samples), y), states)
end

# Convenience method: forwards to the three-argument `foos` with the element
# type fixed to `Float32`.
function foos(states, num_samples)
    foos(Float32, states, num_samples)
end

Then the return type is inferred concretely:

julia> @code_warntype foos(([1],), 10)
MethodInstance for foos(::Tuple{Vector{Int64}}, ::Int64)
  from foos(states, num_samples) @ Main REPL[9]:1
Arguments
  #self#::Core.Const(foos)
  states::Tuple{Vector{Int64}}
  num_samples::Int64
Body::Tuple{Vector{Foo{Float32}}}
1 ─ %1 = Main.foos(Main.Float32, states, num_samples)::Tuple{Vector{Foo{Float32}}}
└──      return %1
2 Likes