Constant propagation fails with a single-field NamedTuple

I’m trying to use the type of a NamedTuple to statically compile some parts of functions. I’ve found some unexpected behavior that I don’t really understand: when I use a NamedTuple with only one field, constant propagation doesn’t seem to happen.

I’m using Julia 1.10.5.

"""
    get_field_type(::Type{T}, field::Symbol) where T <: NamedTuple

Return the declared type of the field named `field` in the `NamedTuple` type `T`.

Throws an `ArgumentError` if `T` has no field named `field`.
"""
function get_field_type(::Type{T}, field::Symbol) where T <: NamedTuple
    names = fieldnames(T)
    types = fieldtypes(T)

    # `findfirst` yields `nothing` when `field` is not among the field names.
    index = findfirst(==(field), names)

    return index === nothing ? throw(ArgumentError("Field $field not found")) : types[index]
end

# Value-level convenience method: look up a field type from a NamedTuple *instance*
# by forwarding its concrete type to the `Type{<:NamedTuple}` method.
get_field_type(@specialize(nt::NamedTuple), field::Symbol) =
    get_field_type(typeof(nt), field)

# Minimal reproducer: choose an operation based on the declared type of field `:a`
# of `nt`. The result of `get_field_type` depends only on `typeof(nt)`, so ideally
# both comparisons fold to constants and only one branch survives in the typed code.
# The two separate `get_field_type` calls are intentional; they mirror the
# `@code_warntype` output posted in this thread.
#
# Returns `a + b` if field `:a` is declared `Int64`, `a * b` if it is `Float64`,
# and `nothing` otherwise (there is no `else` branch).
function nttest(a, b , @specialize(nt))
    if get_field_type(nt, :a) == Int64
        return a+b
    elseif get_field_type(nt, :a) == Float64
        return a*b
    end
end
# Show the typed code of `nttest` for a single-field NamedTuple (the case where the
# branch was reported not to be folded away).
test1() = @code_warntype nttest(1, 2, (; a = 1))

# Show the typed code of `nttest` for a two-field NamedTuple (the case where the
# branch was reported to be folded away correctly).
test2() = @code_warntype nttest(1, 2, (; a = 1, b = 2))

In this example, the `if` statement is still present in the typed code for `test1()`, but for `test2()` it’s correctly compiled away. Why is this?

julia> test1()
MethodInstance for nttest(::Int64, ::Int64, ::@NamedTuple{a::Int64})
  from nttest(a, b, nt) @ Main ~/Library/Mobile Documents/com~apple~CloudDocs/Documents/PhD/JuliaProjects/InteractiveIsing/Tests/NT.jl:24
Arguments
  #self#::Core.Const(nttest)
  a::Int64
  b::Int64
  nt::@NamedTuple{a::Int64}
Body::Union{Nothing, Int64}
1 ─       nothing
│   %2  = Main.get_field_type(nt, :a)::DataType
│   %3  = (%2 == Main.Int64)::Bool
└──       goto #3 if not %3
2 ─ %5  = (a + b)::Int64
└──       return %5
3 ─ %7  = Main.get_field_type(nt, :a)::DataType
│   %8  = (%7 == Main.Float64)::Bool
└──       goto #5 if not %8
4 ─ %10 = (a * b)::Int64
└──       return %10
5 ─       return nothing

julia> test2()
MethodInstance for nttest(::Int64, ::Int64, ::@NamedTuple{a::Int64, b::Int64})
  from nttest(a, b, nt) @ Main ~/Library/Mobile Documents/com~apple~CloudDocs/Documents/PhD/JuliaProjects/InteractiveIsing/Tests/NT.jl:24
Arguments
  #self#::Core.Const(nttest)
  a::Int64
  b::Int64
  nt::@NamedTuple{a::Int64, b::Int64}
Body::Int64
1 ─      nothing
│   %2 = Main.get_field_type(nt, :a)::Core.Const(Int64)
│   %3 = (%2 == Main.Int64)::Core.Const(true)
└──      goto #3 if not %3
2 ─ %5 = (a + b)::Int64
└──      return %5
3 ─      Core.Const(:(Main.get_field_type(nt, :a)))
│        Core.Const(:(%7 == Main.Float64))
│        Core.Const(:(goto %12 if not %8))
│        Core.Const(:(a * b))
│        Core.Const(:(return %10))
└──      Core.Const(:(return nothing))
1 Like

Okay, I’m guessing this is a bug.

If I add a `println` statement to the function `get_field_type`, it suddenly compiles correctly:

"""
    get_field_type(::Type{T}, field::Symbol) where T <: NamedTuple

Return the declared type of the field named `field` in the `NamedTuple` type `T`,
printing `"Bla"` first.

Throws an `ArgumentError` if `T` has no field named `field`.
"""
function get_field_type(::Type{T}, field::Symbol) where T <: NamedTuple
    # Deliberate side effect: as reported in this thread, with this `println` present
    # the branch in `nttest` was folded away even for a single-field NamedTuple.
    println("Bla")
    names = fieldnames(T)
    types = fieldtypes(T)

    # `findfirst` yields `nothing` when `field` is not among the field names.
    index = findfirst(==(field), names)

    return index === nothing ? throw(ArgumentError("Field $field not found")) : types[index]
end

Works fine.

1 Like

Wrapping the symbol in a `Val` type also works. Still, I don’t know why this is necessary or what makes NamedTuples with a single field behave differently. Any help would be appreciated.

# Symbol entry point: wrap the field name in `Val` so it becomes the type parameter
# `S` of the method below (the workaround reported in this thread).
get_field_type(nt, field::Symbol) = get_field_type(nt, Val{field}())

"""
    get_field_type(::Type{T}, ::Val{S}) where {T <: NamedTuple, S}

Return the declared type of the field named `S` in the `NamedTuple` type `T`.

Throws an `ArgumentError` if `T` has no field named `S`.
"""
function get_field_type(::Type{T}, field::Val{S}) where {T <: NamedTuple,S}
    names = fieldnames(T)
    types = fieldtypes(T)

    # `findfirst` yields `nothing` when `S` is not among the field names.
    index = findfirst(==(S), names)

    # Interpolate `S` (the field name), not `field` (the `Val` wrapper), so the message
    # reads e.g. "Field c not found" rather than "Field Val{:c}() not found".
    return index === nothing ? throw(ArgumentError("Field $S not found")) : types[index]
end

Opened an issue:

Do you need a workaround or something?

2 Likes

It seems the workaround with the `Val` type works for now, so it’s not a huge deal for me, assuming it will work otherwise. Still, I don’t understand why a `println` statement suddenly makes it compile correctly.

Reproducible on Julia 1.10.6 and 1.11.1 too.

1 Like

Fixed already!

1 Like