I have a simple script that solves an economic model. It solves my problem quickly but isn’t very readable. Making a small change to the code to improve readability, essentially defining the simple function
rhs(i,j,value_func)
leads to a runtime that is almost four times slower when benchmarked. So I’m obviously doing something stupid, and I would really like to know how I’ve managed to affect the performance so drastically. Any other advice on improving the code’s performance or readability is always welcome. I’m afraid economists don’t make good programmers!
using Parameters
using Plots
using BenchmarkTools
# Parameter bundle for the model. `Model()` builds a named tuple with the defaults below
# (see the Parameters.jl `@with_kw` documentation for keyword overrides).
# Later defaults are computed from earlier fields, so the field order matters.
Model = @with_kw (
    # Exponent on capital in the `k^α` term.
    α = 0.3,
    # Weight on the continuation value. Alternative calibration: (1.04)^(-1/4).
    β = 0.99,
    # Curvature parameter read by `u`; γ == 1 selects its `log` branch.
    γ = 1.0,
    # `(1 - δ)` scales the capital carried over within the resource term.
    δ = 0.02,
    # Computed from α, β and δ; presumably the steady-state capital stock — confirm.
    k_ss = ((1 / β + δ - 1) * (1 / α))^(1 / (α - 1)),
    # Number of points in `k_grid`.
    grid_points = 1000,
    # Evenly spaced grid running from 0.9 * k_ss to 1.1 * k_ss.
    k_grid = range(0.9 * k_ss, 1.1 * k_ss; length = grid_points),
)
"""
    u(c; γ = m.γ)

Return the utility of `c` for curvature parameter `γ`.

- `c > 0` and `γ == 1`: `log(c)`
- `c > 0` and `γ != 1`: `c^(1 - γ) / (1 - γ)`
- otherwise (including `c <= 0` and `c = NaN`): `-Inf`

When `γ` is not supplied, its default is read from the global `m`; pass `γ` explicitly
to make the call independent of global state.
"""
function u(c; γ = m.γ)
    # Guard clause: anything that is not strictly positive gets the worst possible value.
    c > 0 || return -Inf
    return γ == 1 ? log(c) : c^(1 - γ) / (1 - γ)
end
"""
    rhs(i, j, value_func, model = m)

Return the objective for moving from grid index `i` to grid index `j`:

    u(k_grid[i]^α + (1 - δ) * k_grid[i] - k_grid[j]; γ) + β * value_func[j]

# Arguments
- `i`, `j`: indices into `model.k_grid` (current state and candidate choice).
- `value_func`: continuation values, indexed by grid index.
- `model`: parameter bundle providing `α`, `β`, `γ`, `δ` and `k_grid`. Defaults to the
  global `m`, so existing three-argument calls keep working.

# Performance note
Unpacking a non-const global directly inside this function makes every parameter
untyped, so each arithmetic operation is dispatched at run time. Taking the model as an
argument lets Julia compile this body for the model's concrete type. Callers inside hot
loops should pass their own model (`rhs(i, j, value_func, m)`) so that not even the
default-argument lookup touches the global.
"""
function rhs(i, j, value_func, model = m)
    @unpack α, β, γ, δ, k_grid = model
    # Depends on `i` only.
    resources = k_grid[i]^α + (1 - δ) * k_grid[i]
    # Forward γ from the same model rather than relying on `u`'s global default.
    return u(resources - k_grid[j]; γ = γ) + β * value_func[j]
end
# Objective for choosing grid index `j` when `resources` are available today: utility of
# the implied remainder `resources - k_grid[j]` plus `β` times the continuation value.
# Every parameter is passed in explicitly so nothing is read from global state.
function _choice_value(resources, j, k_grid, value_func, β, γ)
    return u(resources - k_grid[j]; γ = γ) + β * value_func[j]
end

"""
    faster_bellman(m, value_func, policy_func)

Apply one Bellman update on the grid `m.k_grid` and return `(value, policy)`.

For every grid index `i` the objective over the choice index `j` is

    u(k_grid[i]^α + (1 - δ) * k_grid[i] - k_grid[j]; γ) + β * value_func[j]

The maximiser is located by bisection on adjacent index pairs instead of scanning every
`j`. Two assumptions are built into the search and are not checked:
the objective is single-peaked in `j`, and the maximising index is non-decreasing in `i`
(the search for row `i` starts at the index chosen for row `i - 1`).

# Arguments
- `m`: parameter bundle providing `α`, `β`, `γ`, `δ` and `k_grid`. `γ` is taken from this
  argument and passed to `u`, so the result does not depend on any global model.
- `value_func`: current continuation values, indexed by grid index.
- `policy_func`: not used; kept so existing callers keep working.

# Returns
- `value::Vector{Float64}`: maximised objective for each grid index.
- `policy::Vector{Int64}`: grid index attaining it.
"""
function faster_bellman(m, value_func, policy_func)
    @unpack α, β, γ, δ, k_grid = m
    n = length(k_grid)
    value = zeros(n)
    policy = zeros(Int64, n)
    w = zeros(3)  # scratch buffer for the (up to) three final candidates
    js = 1        # lower bound of the search; carried over from the previous row
    for i in 1:n
        # Depends on `i` only, so compute it once per row instead of once per candidate.
        resources = k_grid[i]^α + (1 - δ) * k_grid[i]
        jmin = js
        jmax = n
        # Bisection: compare two neighbouring candidates around the midpoint and keep
        # the half that contains the higher one.
        while jmax - jmin > 2
            jl = (jmin + jmax) ÷ 2
            ju = jl + 1
            lower = _choice_value(resources, jl, k_grid, value_func, β, γ)
            upper = _choice_value(resources, ju, k_grid, value_func, β, γ)
            if upper > lower
                jmin = jl
            else
                jmax = ju
            end
        end
        # At most three candidates remain (jmin, jmin + 1, jmax); evaluate them all.
        w[1] = _choice_value(resources, jmin, k_grid, value_func, β, γ)
        if jmax > jmin
            w[2] = _choice_value(resources, jmin + 1, k_grid, value_func, β, γ)
        else
            w[2] = w[1]
        end
        w[3] = _choice_value(resources, jmax, k_grid, value_func, β, γ)
        # `findmax` returns the first maximiser, so ties resolve to the lowest index.
        best, offset = findmax(w)
        value[i] = best
        js = jmin + offset - 1
        policy[i] = js
    end
    return value, policy
end
"""
    solve(m; tol = 1e-6, max_iter = 1500, verbose = true)

Iterate `faster_bellman` starting from a zero value function until the largest absolute
change between successive value functions is at most `tol`, or until `max_iter` updates
have been performed.

# Arguments
- `m`: parameter bundle; `m.k_grid` sets the problem size and `m` is forwarded to
  `faster_bellman`.

# Keywords
- `tol`: convergence threshold on the sup-norm change of the value function.
- `max_iter`: maximum number of Bellman updates.
- `verbose`: when `true` (the default), print one progress line per update. Set it to
  `false` when benchmarking, otherwise console output is included in the timing.

# Returns
`(value_func, policy_func)` from the last update performed. `policy_func` holds grid
indices (`Vector{Int64}`).
"""
function solve(m; tol = 1e-6, max_iter = 1500, verbose = true)
    @unpack k_grid = m
    n = length(k_grid)
    value_func = zeros(n)
    # Integer from the start: `faster_bellman` returns integer indices, and rebinding a
    # Float64 vector to an Int64 one would make this variable type-unstable.
    policy_func = zeros(Int64, n)
    err = Inf  # named `err` so that `Base.error` is not shadowed
    iter = 0
    while err > tol && iter < max_iter
        # Count first, so the first update is reported as iteration 1.
        iter += 1
        new_value, new_policy = faster_bellman(m, value_func, policy_func)
        # Sup-norm distance computed lazily; no temporary arrays are allocated.
        err = maximum(abs(vnew - vold) for (vnew, vold) in zip(new_value, value_func))
        value_func = new_value
        policy_func = new_policy
        verbose && println("Iteration $iter, error = $err")
    end
    return value_func, policy_func
end
# Declared `const` so the global has a fixed, inferable type. Functions that fall back
# to this global (the default `γ` in `u`, and `rhs`) otherwise see an untyped value and
# dispatch every operation at run time.
const m = Model()
# Interpolate the global with `$` so the benchmark measures `solve` itself rather than
# the cost of looking up a global variable.
@benchmark solve($m)