Seemingly innocuous change causes code to run 4 times slower

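I think the problem is that you need to pass your model `m` to `rhs`; otherwise it is treated as a non-constant global variable, whose type Julia cannot infer. The default keyword argument in `u` (`γ = m.γ`) might be problematic too, since it also reads the global `m` whenever `γ` is not passed explicitly.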

using Parameters
using Plots
using BenchmarkTools



# Model parameters as a keyword constructor (`Model(; kwargs...)`) built by
# Parameters.@with_kw from a NamedTuple of defaults.
# `k_ss` solves α k^(α-1) = 1/β - 1 + δ, i.e. the steady-state capital implied by
# the formula below; `k_grid` spans 90%–110% of it with `grid_points` points.
Model = @with_kw (
    α = 0.3,
    β = 0.99,   # alternative calibration: (1.04)^(-1/4)
    γ = 1.0,
    δ = 0.02,
    k_ss = ((1/β + δ-1)*(1/α))^(1/(α-1)),
    grid_points = 1000,
    k_grid = range(0.9 * k_ss, 1.1 * k_ss, length = grid_points),
)



"""
    u(c; γ = m.γ)

Utility of consumption `c` with curvature `γ`: `c^(1-γ)/(1-γ)` when `γ != 1`,
`log(c)` when `γ == 1`, and `-Inf` whenever `c > 0` does not hold (so infeasible
consumption is never chosen).

NOTE(review): the default for `γ` reads the *global* `m` at call time, which is slow
(untyped global) and ties this function to the script; callers should pass `γ`.
"""
function u(c; γ = m.γ)
    c > 0 || return -Inf
    if γ == 1
        return log(c)
    end
    return c^(1 - γ) / (1 - γ)
end

"""
    rhs(m, i, j, value_func)

Bellman right-hand side for current capital `k_grid[i]` and next-period capital
`k_grid[j]`: utility of consumption `k^α + (1-δ)k - k'` plus the discounted
continuation value `β * value_func[j]`.

Every parameter, including the utility curvature `γ`, is taken from the model `m`
argument; nothing is read from global scope.
"""
function rhs(m, i, j, value_func)
    @unpack α, β, γ, δ, k_grid = m
    k = k_grid[i]
    knext = k_grid[j]
    c = k^α + (1 - δ) * k - knext
    # Pass γ explicitly: u's default (`γ = m.γ`) would read the *global* `m`, which is
    # both slow (untyped global in the hot loop) and wrong if this `m` is a different model.
    return u(c; γ = γ) + β * value_func[j]
end

"""
    faster_bellman(m, value_func, policy_func)

Apply one Bellman update on `m.k_grid` given the current `value_func`.

For each grid point `i`, the maximizing next-period index is found by bisection on
`j -> rhs(m, i, j, value_func)`. The search window starts at the previous grid
point's optimal index, i.e. the policy is assumed non-decreasing in `i`. The
bisection also assumes `rhs` is single-peaked in `j` — NOTE(review): this holds for
a concave problem; confirm for other calibrations.

`policy_func` is accepted for interface compatibility but is not used.

Returns `(value, policy)`: the updated value at each grid point and the index into
`k_grid` of the maximizing next-period capital.
"""
function faster_bellman(m, value_func, policy_func)
    k_grid = m.k_grid
    n = length(k_grid)

    value = zeros(n)
    policy = zeros(Int64, n)

    js = 1            # lower end of the search window (monotone policy)
    w = zeros(3)      # candidate values at jmin, jmin + 1, jmax

    for i in 1:n
        jmin = js
        jmax = n

        # Compare two adjacent points and discard the half that cannot contain the
        # peak, until at most three candidates (jmin .. jmax) remain.
        while jmax - jmin > 2
            jl = (jmin + jmax) ÷ 2
            ju = jl + 1
            if rhs(m, i, ju, value_func) > rhs(m, i, jl, value_func)
                jmin = jl
            else
                jmax = ju
            end
        end

        w[1] = rhs(m, i, jmin, value_func)
        # When jmax == jmin (e.g. at the top of the grid) slot 2 just repeats slot 1.
        w[2] = jmax > jmin ? rhs(m, i, jmin + 1, value_func) : w[1]
        w[3] = rhs(m, i, jmax, value_func)

        best, loc = findmax(w)    # first maximal slot on ties
        value[i] = best
        js = jmin + loc - 1       # slot 1/2/3 -> grid index jmin/jmin+1/jmin+2
        policy[i] = js
    end
    return value, policy
end


"""
    solve(m; tol = 1e-6, max_iter = 1500)

Value-function iteration for model `m`. It repeatedly applies `faster_bellman`
until the sup-norm change in the value function is at most `tol`, or until
`max_iter` updates have been performed.

Returns `(value_func, policy_func)`, where `policy_func[i]` is the index into
`m.k_grid` of the optimal next-period capital at grid point `i`. If `max_iter` is
reached first, the last (unconverged) iterate is returned without warning.
An empty grid returns two empty vectors.
"""
function solve(m; tol = 1e-6, max_iter = 1500)
    n = length(m.k_grid)

    value_func = zeros(n)
    # Integer from the start: faster_bellman returns Vector{Int64}, so a Float64
    # initial vector would make `policy_func` change type inside the loop.
    policy_func = zeros(Int64, n)
    n == 0 && return value_func, policy_func   # `maximum` below needs a non-empty grid

    err = Inf      # named `err` to avoid shadowing Base.error
    iter = 1

    while err > tol && iter <= max_iter
        value_func_update, policy_func_update = faster_bellman(m, value_func, policy_func)
        err = maximum(abs.(value_func_update .- value_func))

        value_func = value_func_update
        policy_func = policy_func_update
        iter += 1
    end

    return value_func, policy_func
end

m = Model();

# Interpolate `m` with `$` so BenchmarkTools benchmarks `solve` on a concrete local
# value instead of looking up the untyped global `m` on every sample.
@benchmark solve($m)
2 Likes