Greetings:
I am working on solving a heterogeneous-agent economics model in Julia. The most demanding part is solving the household problem using dynamic programming. Households can be either employed or unemployed. The state space of unemployed households consists of just their assets. For employed households, the state space consists of productivity and the wage submarket in addition to their asset position. Productivity evolves according to a Markov process between periods of unemployment whereas the wage submarket is fixed for the duration of the job.
Both unemployed and employed households make a savings decision. However, unemployed households also which wage submarket to search for. The tradeoff to searching in a higher wage submarket is a lower job finding rate.
Since individuals can transition between employed and unemployed, the value function of being employed involves that of being unemployed and vice versa. Accordingly, we iterate on both value functions simultaneously.
The demanding part of the algorithm is the use of numerical maximization for each point of the state space. The key function which implements the dynamic programming is solve_household
which calls on the functions update_value
and update_value_un
. The problem is that, currently, it takes several minutes to obtain convergence of the household problem. Solving general equilibrium implies solving the household problem several times (at different possible interest rates), so this is not good enough.
I checked type stability using @code_warntype
, which seems to be fine. But I suspect I am nevertheless suspect I am committing some major performance goof as comparableβat least as far as I can tellβFortran code runs much faster.
I attach self-contained code below. The initial part of the code consists of an initialization and solving the firm problem, which is irrelevant to the performance issues in the household problem. I welcome all suggestions that would improve performance.
Many thanks.
using PyPlot, BenchmarkTools
using PyFormattedStrings
using LaTeXStrings
using Parameters, CSV, Random
using Dierckx, Distributions, ArgParse
using Optim, Interpolations
function grid_Cons_Grow(n, left, right, g)
"""
Creates n+1 gridpoints with growing distance on interval [a, b]
according to the formula
x_i = a + (b-a)/((1+g)^n-1)*((1+g)^i-1) for i=0,1,...n
"""
x = zeros(n)
for i in 0:(n-1)
x[i+1] = @. left + (right-left)/((1+g)^n-1)*((1+g)^i-1)
end
return x
end
Para = @with_kw ( Ξ³=0.5,
egam= 1.0-1.0/Ξ³,
Ξ² = 0.96,
Ξ± = 0.2,
A_p = 1.0,
Ξ΄= 0.0,
Ο = 0.9,
Ο_eps = 0.04*(1.0-Ο^2),
NS= 9,
a_l= 0.0,
a_u= 50.0,
NA=50,
a = grid_Cons_Grow(NA, a_l, a_u, 0.01),
mc= rouwenhorst(NS, Ο, Ο_eps^(1/2), 0),
P = mc.p,
z = exp.(mc.state_values),
Nomega = 14,
omega_grid = range(0.1, stop=0.9, length=Nomega),
a_bor =-0.0,
Tr = 0.02,
h_un = 0.02,
Pi = stationary_distributions(mc)[1],
sep = 0.05,
ΞΊ = 2.0,
A_m = 0.54,
Ξ·_L = 0.72
)
# Create asset grid
#u::Function = (c, Ξ³=Ξ³) -> c^(1-Ξ³)/(1-Ξ³)
#u_prime::Function = c -> c^(-Ξ³)
#u_prime_inv::Function = c -> c^(-1/Ξ³)
function matching_functions(para)
@unpack A_m, Ξ·_L = para
# Matching functions
jf(ΞΈ) = min(A_m*ΞΈ^(1-Ξ·_L), 1.0)
q(ΞΈ) = min(A_m*ΞΈ^Ξ·_L, 1.0)
return jf, q
end
function initialize(para)
@unpack Ξ³, egam, Ξ±, Ξ², A_p, Ξ΄, Ο, NS, NA, Nomega, z, a, a_bor, Tr, h_un = para
minrate = -Ξ΄
maxrate = (1-Ξ²)/Ξ²
r = 0.8*(minrate+maxrate)
k_opt = @. ((r+Ξ΄)/(z*A_p*Ξ±))^(Ξ±-1)
c_pol = zeros(NA, NS, Nomega)
for is in 1:NS
for iom in 1:Nomega
c_pol[:, is, iom] = @. max(r*(a + a_bor),1e-10)
end
end
V = c_pol.^(egam)/egam
c_pol_un = @. max(r*(a+a_bor), 1e-10)
V_un = c_pol_un.^(egam)/egam
# Initial guess of distribution function
phi = 1.0/(NA*NS*Nomega*2)
return r, k_opt, c_pol, c_pol_un, V, V_un
end
function bellman_firm(r, para;tol=1e-7)
#=
Value of firm at productivity grid (is, iom)
=#
@unpack omega_grid, z, NS, Nomega, Ξ΄, Ξ±, A_p, sep, P = para
J_new = zeros(NS, Nomega)
diff = 1.0
J = zero(J_new)
profit = zero(J_new)
while diff > tol
for (iom, Ο) in enumerate(omega_grid)
for (is, z_i) in enumerate(z)
k_opt = ((r+Ξ΄)/(z_i*A_p*Ξ±))^(Ξ±-1)
profit[is, iom] = (1-Ο)* (z_i*A_p*k_opt^Ξ±-(r+Ξ΄)*k_opt)
for is_p in 1:NS
J[is, iom] += (profit[is, iom] + (1.0-sep)/(1+r)*J_new[is_p, iom])*P[is, is_p]
end
end
end
diff = maximum(abs.(J-J_new))
#print(diff)
J_new .= J
end
return J, profit
end
function ΞΈ_fun(J, r, para)
# Returns mapping between Ο and ΞΈ
@unpack NS, Nomega, A_p, ΞΊ, Ξ·_L, Pi = para
# Expected firm value from stationary distribution
EJ = zeros(Nomega)
for iom in 1:Nomega
EJ[iom] = sum(J[:, iom].*Pi)
end
EJ .= @. max(EJ, 1e-10)
ΞΈ = @. (A_p*EJ/(ΞΊ*(1+r)))^(1/Ξ·_L)
return ΞΈ
end
#@code_warntype(ΞΈ_fun(J, 0.02, para))
#@btime ΞΈ_fun(J, 0.02, para)
function bellman(a_prime, ia, is, iom, V_fun, V_un_fun, r, para)
# Bellman equation of employed HH with state (a,z,Ο) and guess of
# savings a_prime
@unpack Ξ΄, Ξ², A_p, Ξ±, Tr, a, omega_grid, P, sep, egam, z, a_bor, NS =para
a_i, z_i, Ο = a[ia], z[is], omega_grid[iom]
k_opt = ((r+Ξ΄)/(z_i*A_p*Ξ±))^(Ξ±-1)
revenue = z_i*A_p*k_opt^Ξ±-(r+Ξ΄)*k_opt
wage = Ο*revenue
c = max((a_i + a_bor)*(1+r)+wage-(a_prime+a_bor)-Tr, 1e-10)
expect = Ξ²*sep*max(V_un_fun(a_prime)^(egam)/egam, 1e-10)
for is_p in 1:NS
expect += Ξ²*(1-sep)*max(V_fun(a_prime, is_p, iom), 1e-10)^(egam)/egam*P[is, is_p]
end
util = (c^egam/egam + expect)
return util
end
#@code_warntype(bellman(5.0, 3, 2, 5, V_fun, V_un_fun, 0.02, para))
#@btime (bellman(5.0, 3, 2, 5, $V_fun, $V_un_fun, 0.02, $para))
function bellman_un(a_prime, ia, iom, V_fun, V_un_fun, jf_val, r, para)
# Bellman equation of unemployed HH with state (a) and guess of
# savings a_prime
@unpack Ξ΄, Ξ², A_p, Ξ±, Tr, a, omega_grid, Pi, h_un, a_bor, NS, egam =para
jf = jf_val[iom]
a_i = a[ia]
# consumption
c = max(((a_i+a_bor)*(1+r)-(a_prime+a_bor)+h_un), 1e-10)
expect = Ξ²*(1.0-jf)*max(V_un_fun(a_prime), 1e-10)^(egam)/egam
@inbounds for is_p in 1:NS
expect += Ξ²*jf*max(V_fun(a_prime, is_p, iom), 1e-10)^(egam)/egam*Pi[is_p]
end
util = c^egam/egam + expect
return util
end
function update_value_un(V_fun, V_un_fun, jf_vals, r, para)
# Unemployed
@unpack NA, NS, Nomega, a, omega_grid, a_l, a_u, a_bor, A_p, h_un, z, Ξ΄, Ξ±, Tr, egam = para
a_pol_un = zeros(NA)
c_pol_un = zeros(NA)
V_un_new = zeros(NA)
omega_policy = zeros(Int, NA)
temp_fun = zeros(NA, Nomega)
temp_maximizer = zeros(NA, Nomega)
@inbounds Threads.@threads for ia in 1:NA
a_i = a[ia]
for iom = 1:Nomega
# get a_prime
#a_u = max((a[ia] + a_bor)*(1+r) - c_pol_un[ia] + h_un, 1e-10)
a_u = max((a_i + a_bor)*(1+r) + h_un, 1e-10)
objective = x -> bellman_un(x, ia, iom, V_fun, V_un_fun, jf_vals, r, para)
sol = maximize(objective, a_l, a_u)
temp_fun[ia, iom] = Optim.maximum(sol)
temp_maximizer[ia, iom] = Optim.maximizer(sol)
end
# Ο which maximizes value
iom_max = argmax(@view(temp_fun[ia, :]))
# policies and value function
a_pol_un[ia] = temp_maximizer[ia, iom_max]
c_pol_un[ia] =max((a_i+a_bor)*(1+r) +h_un - (a_pol_un[ia] +a_bor) , (1e-10))
omega_policy[ia] =iom_max
V_un_new[ia] = temp_fun[ia, iom_max]
end
return V_un_new, a_pol_un, c_pol_un, omega_policy
end
function update_value(V_fun, V_un_fun, r, para)
@unpack NA, NS, Nomega, a, omega_grid, a_l, a_u, a_bor, A_p, h_un, z, Ξ΄, Ξ±, Tr, egam = para
a_pol = zeros(NA, NS, Nomega)
c_pol = zeros(NA, NS, Nomega)
V_new = zeros(NA, NS, Nomega)
k_opt = @. ((r+Ξ΄)/(z*A_p*Ξ±))^(Ξ±-1)
# Interpolate employed value function
@inbounds Threads.@threads for ia in 1:NA
a_i = a[ia]
for iom in 1:Nomega
for is in 1:NS
z_i = z[is]
wage = omega_grid[iom]*(z_i*A_p*k_opt[is]^Ξ±-(r+Ξ΄)*k_opt[is])
#x_in = max((a[ia] + a_bor)*(1+r) + wage - c_pol[ia, is, iom] - Tr, 1e-10)
a_u = max((a_i + a_bor)*(1+r) + wage - Tr, 1e-10)
objective = x -> bellman(x, ia, is, iom, V_fun, V_un_fun, r, para)
#sol = maximize(objective, a_l, a_u)
sol = maximize(objective, a_l, a_u)
x_in = Optim.maximizer(sol)
a_pol[ia, is, iom] = max(x_in, 0.0)
c_pol[ia, is, iom] = max((a_i+a_bor)*(1+r)+wage-(
a_pol[ia, is, iom] + a_bor) - Tr, 1e-10)
V_new[ia, is, iom] = Optim.maximum(sol)
end
end
end
return V_new, a_pol, c_pol
end
function solve_household(V, V_un, jf_vals, r, para; itermax=5000, error_tol=1e-9, N=10)
@unpack NA, NS, Nomega, a, omega_grid, a_l, a_u, a_bor, A_p, h_un, z, Ξ΄, Ξ±, Tr, egam = para
#initialize policies
a_pol_un = zeros(NA)
c_pol_un = zeros(NA)
V_un_new = similar(V_un)
omega_policy = zeros(Int, NA)
a_pol = zeros(NA, NS, Nomega)
c_pol = zeros(NA, NS, Nomega)
V_new = similar(V)
# Outer loop for value function iteration
err = 1.0
iter_num = 1
#V_fun(x, is, iom) = LinearInterpolation(a, egam.*V[:, is, iom].^(1/egam), extrapolation_bc=Line())(x)
#V_un_fun(x) = LinearInterpolation(a, egam.*V_un.^(1/egam), extrapolation_bc=Line())(x)
while err > error_tol && iter_num < itermax
V_un_fun(x) = LinearInterpolation(a, egam.*V_un.^(1/egam), extrapolation_bc=Line())(x)
V_fun(x, is, iom) = LinearInterpolation(a, egam.*V[:, is, iom].^(1/egam), extrapolation_bc=Line())(x)
# Optimal decision for every gridpoint
# given the interest rate
V_un_new, a_pol_un, c_pol_un, omega_policy = update_value_un(V_fun, V_un_fun, jf_vals, r, para)
V_new, a_pol, c_pol = update_value(V_fun, V_un_fun, r, para)
# update value function for fixed policies
for n in 1:N
V_un_fun(x) = LinearInterpolation(a, egam.*V_un_new.^(1/egam), extrapolation_bc=Line())(x)
#V_fun(x, is, iom) = LinearInterpolation(a, egam.*V_new[:, is, iom].^(1/egam), extrapolation_bc=Line())(x)
for ia in 1:NA
x_u = a_pol_un[ia]
iom_max = omega_policy[ia]
V_un_new[ia] = bellman_un(x_u, ia, iom_max, V_fun, V_un_fun, jf_vals, r, para)
for is in 1:NS
for iom in Nomega
x_e = a_pol[ia, is, iom]
V_new[ia, is, iom] = bellman(x_e, ia, is, iom, V_fun, V_un_fun, r, para)
end
end
end
end
err_un = maximum(abs.(V_un_new - V_un)./maximum(abs.(V_un)))
@show err_un
err_e = maximum(abs.(V_new - V)./maximum(abs.(V)))
@show err_e
err = err_un + err_e
# update unemployed value function
V_un .= V_un_new
V .= V_new
# Interpolate value functions (transformed)
@show iter_num +=1
end
return V, V_un, c_pol_un, c_pol, a_pol_un, a_pol, omega_policy
end
# initialize
const para = Para(NA=50)
@unpack a, z, omega_grid = para
r, k_opt, c_pol, c_pol_un, V, V_un = initialize(para)
# Firm value by value function iteration: (NS, Nomega)
J, profits = bellman_firm(r, para)
# Tightness function: length Nomega
ΞΈ_vals = ΞΈ_fun(J, r, para)
# Job finding rate
jf = matching_functions(para)[1]
jf_vals = jf.(ΞΈ_vals)
# Solve HH problem: value and policy functions
@code_warntype solve_household(V, V_un, jf_vals, r, para)
V, V_un, c_pol_un, c_pol, a_pol_un, a_pol, omega_policy = solve_household(V, V_un, jf_vals, r, para)
# Assign bins to savings policies