Type inference problem with nested keyword calls

In some code for a structural model, I use keywords to pass around arguments, to avoid mistakes. But it causes type inference to fail. Below is an MWE (sorry, could not make it any smaller, I tried many ways but then the bug disappears).

Should I open an issue, or is there an existing one? Is there a workaround?

EDIT the original example was missing a piece; it also takes ForwardDiff or similar to stress the type inference sufficiently, with Float64 it works fine

Output

MethodInstance for f(::ModelParameters, ::ForwardDiff.Dual{Nothing, Float64, 1})
  from f(model_parameters, x) @ Main /tmp/tmp.jl:71
Arguments
  #self#::Core.Const(f)
  model_parameters::ModelParameters
  x::ForwardDiff.Dual{Nothing, Float64, 1}
Body::NamedTuple{(:ζ̄1, :ζ̄2, :n1, :n2), <:Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}}}
1 ─ %1 = Main.Val(:UU)::Core.Const(Val{:UU}())
│   %2 = (:M, :w1, :w2, :α̂1, :α̂2, :β̂1, :β̂2)::Core.Const((:M, :w1, :w2, :α̂1, :α̂2, :β̂1, :β̂2))
│   %3 = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:M, :w1, :w2, :α̂1, :α̂2, :β̂1, :β̂2)})
│   %4 = Core.tuple(x, x, x, x, x, x, x)::NTuple{7, ForwardDiff.Dual{Nothing, Float64, 1}}
│   %5 = (%3)(%4)::@NamedTuple{M::ForwardDiff.Dual{Nothing, Float64, 1}, w1::ForwardDiff.Dual{Nothing, Float64, 1}, w2::ForwardDiff.Dual{Nothing, Float64, 1}, α̂1::ForwardDiff.Dual{Nothing, Float64, 1}, α̂2::ForwardDiff.Dual{Nothing, Float64, 1}, β̂1::ForwardDiff.Dual{Nothing, Float64, 1}, β̂2::ForwardDiff.Dual{Nothing, Float64, 1}}
│   %6 = Core.kwcall(%5, Main._ζ̄nh, model_parameters, %1)::NamedTuple{(:ζ̄1, :ζ̄2, :n1, :n2), <:Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}}}
└──      return %6


julia> versioninfo()
Julia Version 1.10.0-DEV.1340
Commit c99d8393f5* (2023-05-18 19:18 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, tigerlake)
  Threads: 11 on 8 virtual cores

same with

julia> versioninfo()
Julia Version 1.9.0
Commit 8e63055292* (2023-05-07 11:25 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
  Threads: 8 on 8 virtual cores

MWE

using ForwardDiff

Base.@kwdef struct ModelParameters
    T::Float64
    n0::Float64
    ρ::Float64
    γ1::Float64
    γ2::Float64
end

function optimal_employed_hours(model_parameters::ModelParameters; α::Real, w::Real, M::Real)
    (; T, n0) = model_parameters
    if α ≤ (M + n0 * w) / (M + T * w)
        n0
    else
        α * T - (1 - α) * M / w
    end
end

function employed_utility(model_parameters::ModelParameters; α::Real, w::Real, M::Real, n::Real)
    (; T) = model_parameters
    α * log(M + n * w) + (1 - α) * log(T - n)
end

function unemployed_utility(model_parameters::ModelParameters; α::Real, M::Real, w::Real)
    (; T, ρ) = model_parameters
    α * log(M + ρ * T * w) + (1 - α) * log(T)
end

function calculate_ζ̄(model_parameters::ModelParameters; α::Real, M::Real, w::Real, n::Real)
    𝒰_U = unemployed_utility(model_parameters; α, M, w)
    𝒰_E = employed_utility(model_parameters; α, M, w, n)
    𝒰_E - 𝒰_U
end

function solve_individual_problem(model_parameters::ModelParameters;
                                  α::Real, M::Real, w::Real)
    n = optimal_employed_hours(model_parameters; α, M, w)
    ζ̄ = calculate_ζ̄(model_parameters; α, M, w, n)
    ζ̄, n
end

function thresholds(::Val{:UU}, model_parameters::ModelParameters; M, α1, α2, w1, w2)
    (; ρ, T) = model_parameters
    ζ̄1, n1 = solve_individual_problem(model_parameters; α = α1, M = M + ρ * T * w2, w = w1)
    ζ̄2, n2 = solve_individual_problem(model_parameters; α = α2, M = M + ρ * T * w1, w = w2)
    (ζ̄1, ζ̄2), (zero(n1), zero(n2))
end

adjust_α(; α, β, γ) = α / (α + (1 - α) * (1 - β + β * γ))


function homeprod_thresholds(kind::K, model_parameters::ModelParameters;
                                     M, α1, α2, β1, β2, w1, w2) where K
    (; γ1, γ2) = model_parameters
    thresholds(kind, model_parameters; M, w1, w2,
                      α1 = adjust_α(; α = α1, β = β1, γ = γ1),
                      α2 = adjust_α(; α = α2, β = β2, γ = γ2))
end

function _ζ̄nh(model_parameters, kind::Val{X}; M, w1, w2, α̂1, α̂2, β̂1, β̂2) where X
    (; T, γ1, γ2) = model_parameters
    α1 = log(α̂1)
    α2 = log(α̂2)
    β1 = log(β̂1)
    β2 = log(β̂2)
    (ζ̄1, ζ̄2), (n1, n2) = homeprod_thresholds(kind, model_parameters; M, w1, w2, α1, α2, β1, β2)
    (; ζ̄1, ζ̄2, n1, n2)
end

f(model_parameters, x) = _ζ̄nh(model_parameters, Val(:UU); M = x, w1 = x, w2 = x, α̂1 = x, α̂2 = x, β̂1 = x, β̂2 = x)

model_parameters = ModelParameters(; T = 16.0, n0 = 4.0, ρ = 0.25, γ1 = 0.5, γ2 = 0.5)
x = ForwardDiff.Dual(1.5, 1.5)
@assert f(model_parameters, x) isa NamedTuple # check that code runs
@code_warntype f(model_parameters, x)
1 Like

There seems to be a problem of inferring the return types of functions optimal_employed_hours, employed_utility and unemployed_utility.
If you make it explicit the type warning vanishes:

function optimal_employed_hours(model_parameters::ModelParameters; α::Real, w::Real, M::Real)::Float64
    (; T, n0) = model_parameters
    if α ≤ (M + n0 * w) / (M + T * w)
        n0
    else
        α * T - (1 - α) * M / w
    end
end

function employed_utility(model_parameters::ModelParameters; α::Real, w::Real, M::Real, n::Real)::Float64
    (; T) = model_parameters
    α * log(M + n * w) + (1 - α) * log(T - n)
end

function unemployed_utility(model_parameters::ModelParameters; α::Real, M::Real, w::Real)::Float64
    (; T, ρ) = model_parameters
    α * log(M + ρ * T * w) + (1 - α) * log(T)
end

My guess is, it is because of the other parameters are of type Real. If you change Real to Float64 the type warning vanishes too:

function optimal_employed_hours(model_parameters::ModelParameters; α::Float64, w::Float64, M::Float64)
    (; T, n0) = model_parameters
    if α ≤ (M + n0 * w) / (M + T * w)
        n0
    else
        α * T - (1.0 - α) * M / w
    end
end

function employed_utility(model_parameters::ModelParameters; α::Float64, w::Float64, M::Float64, n::Float64)
    (; T) = model_parameters
    α * log(M + n * w) + (1.0 - α) * log(T - n)
end

function unemployed_utility(model_parameters::ModelParameters; α::Float64, M::Float64, w::Float64)
    (; T, ρ) = model_parameters
    α * log(M + ρ * T * w) + (1.0 - α) * log(T)
end

I would have to do more experiments on this to understand it and to be really able to explain it, so I can’t tell if this a bug or just a limitation of the type inference.

Perhaps this workaround gives enough hints for type experts to have a better explanation than I can do.

Yes, of course explicitly typing calls resolves inference issues, but I need this code working with generic types (think ForwardDiff).

I think this is a bug (also on master). Will wait for feedback here for a while and then open an issue.

EDIT using Cthulhu gives me something I find hard to interpret on Julia master:

   (ζ̄1, ζ̄2), (n1, n2) = homeprod_thresholds(kind, model_parameters; M, w1, w2, α1, α2, β1, β2)
 • (ζ̄1, ζ̄2), (n1, n2) = homeprod_thresholds(kind::Val{:UU}, model_parameters::ModelParameters; M::Float64, w1::Float64, w2::Float64, α1::Float64, α2::Float64, β1::Float64, β2::Float64::NTuple{7, Float64})::Tu…
   (; ζ̄1::Any, ζ̄2::Any, n1::Any, n2::Any::NTuple{4, Any})
   ↩
[ Info: tracking Base
indexed_iterate(t::Tuple, i::Int64, state) @ Base ~/src/julia-git/base/tuple.jl:92
┌ Warning: Some line information is missing, type-assignment may be incomplete
└ @ Cthulhu ~/.julia/packages/Cthulhu/8wpRb/src/codeview.jl:117
92 indexed_iterate(t::Tuple{Any, Any}::Tuple, i::Int64::Int, state::Int64=1) = (@inline; (getfield(t::Tuple{Any, Any}, i::Int64)::Any, (i::Int64+1)::Int64)::Tuple{Any, Int64})
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
Toggles: [w]arn, [h]ide type-stable statements, [t]ype annotations, [s]yntax highlight for Source/LLVM/Native.
Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code
Actions: [E]dit source code, [R]evise and redisplay
 • ↩
1 Like

Your MWE is not complete: adjust_α is not defined.
This is the reason why it gives the type warning in your MWE :frowning:

1 Like

Thanks! That was a mistake from trying to produce an MWE, too many edits.

I updated the example. Curiously, it works fine with f(1.0), but a ForwardDiff.Dual fails.

1 Like

Ok, I think here is your flaw:

function optimal_employed_hours(model_parameters::ModelParameters; α::Real, w::Real, M::Real)
    (; T, n0) = model_parameters
    if α ≤ (M + n0 * w) / (M + T * w)
        n0 # always Float64
    else
        α * T - (1 - α) * M / w # depends on typef(x), e.g. ForwardDiff.Dual{Float64, Float64, 1}
    end
end

if

julia> typeof(x)
ForwardDiff.Dual{Float64, Float64, 1}

optimal_employed_hours returns either a
Float64 in the case of n0
or a
ForwardDiff.Dual{Nothing, Float64, 1} in the else branch.

If

julia> typeof(x)
Float64

the return value of optimal_employed_hours is a Float64 in both branches.

1 Like

Thanks! I am curious how you found this, just eyeballing or some type inference tool? Cthulhu could not traverse keyword arguments for me. (This is all part of a larger codebase, the MWE above is much simplified).

1 Like

I traversed manually through the hierarchy of functions, setting each parameter on the REPL before calling and checking with @code_warntype to be sure. But the if else end was my suspect from the beginning as it is quite a prototype for type ambiguity. Your first MWE did confuse me a bit, because I needed to make the two other functions explicit too, which I didn’t found out why, until I found that adjust_α wasn’t defined :wink:

I’m not so good in interpreting the output of the available tools like @code_warntype correctly.

1 Like

I just learned that

exists.

1 Like

This is the output with your MWE:

julia> @report_opt f(model_parameters, x)
═════ 2 possible errors found ═════
┌ @ REPL[28]:1 Core.kwcall(NamedTuple{(:M, :w1, :w2, :α̂1, :α̂2, :β̂1, :β̂2)}(tuple(x, x, x, x, x, x, x)), _ζ̄nh, model_parameters, Val(:UU))
│┌ @ REPL[27]:1 #_ζ̄nh#20(M, w1, w2, α̂1, α̂2, β̂1, β̂2, _3, model_parameters, kind)
││┌ @ REPL[27]:7 Core.kwcall(NamedTuple{(:M, :w1, :w2, :α1, :α2, :β1, :β2)}(tuple(M, w1, w2, α1, α2, β1, β2)), homeprod_thresholds, kind, model_parameters)
│││┌ @ REPL[26]:1 #homeprod_thresholds#19(M, α1, α2, β1, β2, w1, w2, _3, kind, model_parameters)
││││┌ @ REPL[26]:4 Core.kwcall(NamedTuple{(:M, :w1, :w2, :α1, :α2)}(tuple(M, w1, w2, Core.kwcall(NamedTuple{(:α, :β, :γ)}(tuple(α1, β1, γ1)), adjust_α), Core.kwcall(NamedTuple{(:α, :β, :γ)}(tuple(α2, β2, γ2)), adjust_α))), thresholds, kind, model_parameters)
│││││┌ @ REPL[24]:1 #thresholds#17(M, α1, α2, w1, w2, _3, _4, model_parameters)
││││││┌ @ REPL[24]:3 Core.kwcall(NamedTuple{(:α, :M, :w)}(tuple(α1, M + *(ρ, T, w2), w1)), solve_individual_problem, model_parameters)
│││││││┌ @ REPL[23]:1 #solve_individual_problem#16(α, M, w, _3, model_parameters)
││││││││┌ @ REPL[23]:4 (%3)
│││││││││ runtime dispatch detected: ::NamedTuple{(:α, :M, :w, :n)}(%3::Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}})::NamedTuple{(:α, :M, :w, :n), _A} where _A<:Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}}
││││││││└──────────────
││┌ @ REPL[27]:8 (%115)
│││ runtime dispatch detected: ::NamedTuple{(:ζ̄1, :ζ̄2, :n1, :n2)}(%115::Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}})::NamedTuple{(:ζ̄1, :ζ̄2, :n1, :n2), _A} where _A<:Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}}
││└──────────────

In this case output is not much better as what we already have.

1 Like

I have been using JET.jl for packages, but it didn’t occur to me to use it for this problem. Thanks!

1 Like