In some code for a structural model, I use keywords to pass around arguments, to avoid mistakes. But it causes type inference to fail. Below is an MWE (sorry, could not make it any smaller, I tried many ways but then the bug disappears).
Should I open an issue, or is there an existing one? Is there a workaround?
EDIT the original example was missing a piece; it also takes ForwardDiff or similar to stress the type inference sufficiently, with Float64
it works fine
Output
MethodInstance for f(::ModelParameters, ::ForwardDiff.Dual{Nothing, Float64, 1})
from f(model_parameters, x) @ Main /tmp/tmp.jl:71
Arguments
#self#::Core.Const(f)
model_parameters::ModelParameters
x::ForwardDiff.Dual{Nothing, Float64, 1}
Body::NamedTuple{(:ζ̄1, :ζ̄2, :n1, :n2), <:Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}}}
1 ─ %1 = Main.Val(:UU)::Core.Const(Val{:UU}())
│ %2 = (:M, :w1, :w2, :α̂1, :α̂2, :β̂1, :β̂2)::Core.Const((:M, :w1, :w2, :α̂1, :α̂2, :β̂1, :β̂2))
│ %3 = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:M, :w1, :w2, :α̂1, :α̂2, :β̂1, :β̂2)})
│ %4 = Core.tuple(x, x, x, x, x, x, x)::NTuple{7, ForwardDiff.Dual{Nothing, Float64, 1}}
│ %5 = (%3)(%4)::@NamedTuple{M::ForwardDiff.Dual{Nothing, Float64, 1}, w1::ForwardDiff.Dual{Nothing, Float64, 1}, w2::ForwardDiff.Dual{Nothing, Float64, 1}, α̂1::ForwardDiff.Dual{Nothing, Float64, 1}, α̂2::ForwardDiff.Dual{Nothing, Float64, 1}, β̂1::ForwardDiff.Dual{Nothing, Float64, 1}, β̂2::ForwardDiff.Dual{Nothing, Float64, 1}}
│ %6 = Core.kwcall(%5, Main._ζ̄nh, model_parameters, %1)::NamedTuple{(:ζ̄1, :ζ̄2, :n1, :n2), <:Tuple{ForwardDiff.Dual{Nothing, Float64, 1}, ForwardDiff.Dual{Nothing, Float64, 1}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}, Union{Float64, ForwardDiff.Dual{Nothing, Float64, 1}}}}
└── return %6
julia> versioninfo()
Julia Version 1.10.0-DEV.1340
Commit c99d8393f5* (2023-05-18 19:18 UTC)
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 8 × 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, tigerlake)
Threads: 11 on 8 virtual cores
same with
julia> versioninfo()
Julia Version 1.9.0
Commit 8e63055292* (2023-05-07 11:25 UTC)
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 8 × 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
Threads: 8 on 8 virtual cores
MWE
using ForwardDiff
Base.@kwdef struct ModelParameters
T::Float64
n0::Float64
ρ::Float64
γ1::Float64
γ2::Float64
end
function optimal_employed_hours(model_parameters::ModelParameters; α::Real, w::Real, M::Real)
(; T, n0) = model_parameters
if α ≤ (M + n0 * w) / (M + T * w)
n0
else
α * T - (1 - α) * M / w
end
end
function employed_utility(model_parameters::ModelParameters; α::Real, w::Real, M::Real, n::Real)
(; T) = model_parameters
α * log(M + n * w) + (1 - α) * log(T - n)
end
function unemployed_utility(model_parameters::ModelParameters; α::Real, M::Real, w::Real)
(; T, ρ) = model_parameters
α * log(M + ρ * T * w) + (1 - α) * log(T)
end
function calculate_ζ̄(model_parameters::ModelParameters; α::Real, M::Real, w::Real, n::Real)
𝒰_U = unemployed_utility(model_parameters; α, M, w)
𝒰_E = employed_utility(model_parameters; α, M, w, n)
𝒰_E - 𝒰_U
end
function solve_individual_problem(model_parameters::ModelParameters;
α::Real, M::Real, w::Real)
n = optimal_employed_hours(model_parameters; α, M, w)
ζ̄ = calculate_ζ̄(model_parameters; α, M, w, n)
ζ̄, n
end
function thresholds(::Val{:UU}, model_parameters::ModelParameters; M, α1, α2, w1, w2)
(; ρ, T) = model_parameters
ζ̄1, n1 = solve_individual_problem(model_parameters; α = α1, M = M + ρ * T * w2, w = w1)
ζ̄2, n2 = solve_individual_problem(model_parameters; α = α2, M = M + ρ * T * w1, w = w2)
(ζ̄1, ζ̄2), (zero(n1), zero(n2))
end
adjust_α(; α, β, γ) = α / (α + (1 - α) * (1 - β + β * γ))
function homeprod_thresholds(kind::K, model_parameters::ModelParameters;
M, α1, α2, β1, β2, w1, w2) where K
(; γ1, γ2) = model_parameters
thresholds(kind, model_parameters; M, w1, w2,
α1 = adjust_α(; α = α1, β = β1, γ = γ1),
α2 = adjust_α(; α = α2, β = β2, γ = γ2))
end
function _ζ̄nh(model_parameters, kind::Val{X}; M, w1, w2, α̂1, α̂2, β̂1, β̂2) where X
(; T, γ1, γ2) = model_parameters
α1 = log(α̂1)
α2 = log(α̂2)
β1 = log(β̂1)
β2 = log(β̂2)
(ζ̄1, ζ̄2), (n1, n2) = homeprod_thresholds(kind, model_parameters; M, w1, w2, α1, α2, β1, β2)
(; ζ̄1, ζ̄2, n1, n2)
end
f(model_parameters, x) = _ζ̄nh(model_parameters, Val(:UU); M = x, w1 = x, w2 = x, α̂1 = x, α̂2 = x, β̂1 = x, β̂2 = x)
model_parameters = ModelParameters(; T = 16.0, n0 = 4.0, ρ = 0.25, γ1 = 0.5, γ2 = 0.5)
x = ForwardDiff.Dual(1.5, 1.5)
@assert f(model_parameters, x) isa NamedTuple # check that code runs
@code_warntype f(model_parameters, x)