I’m trying to use the type of a NamedTuple to statically compile some parts of functions. I’ve found some unexpected behavior that I don’t really understand: When I use a NamedTuple with only one field it doesn’t seem to want to do constant propagation.
Im using Julia 1.10.5
Function get_field_type(::Type{T}, field::Symbol) where T <: NamedTuple
names = fieldnames(T)
types = fieldtypes(T)
index = findfirst(==(field), names)
return index === nothing ? throw(ArgumentError("Field $field not found")) : types[index]
end
function get_field_type(@specialize(nt::NamedTuple), field::Symbol)
return get_field_type(typeof(nt), field)
end
function nttest(a, b , @specialize(nt))
if get_field_type(nt, :a) == Int64
return a+b
elseif get_field_type(nt, :a) == Float64
return a*b
end
end
function test1()
nt = (;a = 1)
@code_warntype nttest(1, 2, nt)
end
function test2()
nt = (;a = 1, b = 2)
@code_warntype nttest(1, 2, nt)
end
In this example for test1(), there is an if statement in the function, but for test2() it’s correctly compiled away. Why is this?
julia> test1()
MethodInstance for nttest(::Int64, ::Int64, ::@NamedTuple{a::Int64})
from nttest(a, b, nt) @ Main ~/Library/Mobile Documents/com~apple~CloudDocs/Documents/PhD/JuliaProjects/InteractiveIsing/Tests/NT.jl:24
Arguments
#self#::Core.Const(nttest)
a::Int64
b::Int64
nt::@NamedTuple{a::Int64}
Body::Union{Nothing, Int64}
1 ─ nothing
│ %2 = Main.get_field_type(nt, :a)::DataType
│ %3 = (%2 == Main.Int64)::Bool
└── goto #3 if not %3
2 ─ %5 = (a + b)::Int64
└── return %5
3 ─ %7 = Main.get_field_type(nt, :a)::DataType
│ %8 = (%7 == Main.Float64)::Bool
└── goto #5 if not %8
4 ─ %10 = (a * b)::Int64
└── return %10
5 ─ return nothing
julia> test2()
MethodInstance for nttest(::Int64, ::Int64, ::@NamedTuple{a::Int64, b::Int64})
from nttest(a, b, nt) @ Main ~/Library/Mobile Documents/com~apple~CloudDocs/Documents/PhD/JuliaProjects/InteractiveIsing/Tests/NT.jl:24
Arguments
#self#::Core.Const(nttest)
a::Int64
b::Int64
nt::@NamedTuple{a::Int64, b::Int64}
Body::Int64
1 ─ nothing
│ %2 = Main.get_field_type(nt, :a)::Core.Const(Int64)
│ %3 = (%2 == Main.Int64)::Core.Const(true)
└── goto #3 if not %3
2 ─ %5 = (a + b)::Int64
└── return %5
3 ─ Core.Const(:(Main.get_field_type(nt, :a)))
│ Core.Const(:(%7 == Main.Float64))
│ Core.Const(:(goto %12 if not %8))
│ Core.Const(:(a * b))
│ Core.Const(:(return %10))
└── Core.Const(:(return nothing))