I am solving an HJB equation in economics using Julia, but I am experiencing performance issues. My Julia implementation takes about 90 seconds, slower than Python (~76s) and significantly slower than MATLAB (~39s). I’m new to Julia and have already applied several performance optimizations, such as inlining functions and pre-allocating vectors/matrices outside loops. However, the performance is still unsatisfactory.
Profiling shows that most of the time is spent in the sparse linear solve for v_new (the coef \ b step; in the full script below it goes through LinearSolve.jl). I suspect there is a more efficient way to handle this in Julia, ideally at least as fast as the Python version.
Does anyone have suggestions for improving the speed?
For reference, I’m running this on a MacBook Air M3.
Thanks in advance for your help!
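To make the bottleneck concrete, here is the solve step pulled out into a small standalone snippet. The toy matrix is only a stand-in for the real coef, and the explicit factorization choices (UMFPACKFactorization, KLUFactorization) are just names I took from the LinearSolve.jl documentation as examples of what I could try; I have not benchmarked them.

using SparseArrays
import LinearSolve

# Toy stand-in for coef; the real matrix is (I*J) x (I*J) = 1001^2 x 1001^2 and banded.
n = 200_000
coef = spdiagm(0 => fill(2.0, n), -1 => fill(-0.5, n - 1), 1 => fill(-0.5, n - 1))
b = rand(n)

prob = LinearSolve.LinearProblem(coef, b)

# What my script currently does: let LinearSolve pick the algorithm.
v_default = LinearSolve.solve(prob).u

# Untested alternatives (assumptions on my part, taken from the LinearSolve.jl docs):
v_umfpack = LinearSolve.solve(prob, LinearSolve.UMFPACKFactorization()).u
v_klu = LinearSolve.solve(prob, LinearSolve.KLUFactorization()).u

The full Julia script is below: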
using SparseArrays
import LinearAlgebra, LinearSolve
using SuiteSparseGraphBLAS
using Interpolations
using Plots
using Random
using Statistics
using Printf
using Profile
mutable struct RealBusinessCycleModel
alpha::Float64
gamma::Float64
rho::Float64
delta::Float64
zbar::Float64
theta::Float64
sigma::Float64
kbar::Float64
function RealBusinessCycleModel(alpha::Float64, gamma::Float64, rho::Float64, delta::Float64, zbar::Float64, theta::Float64, sigma::Float64)
kbar = ((alpha * zbar) / (rho + delta))^(1 / (1 - alpha))
new(alpha, gamma, rho, delta, zbar, theta, sigma, kbar)
end
end
function get_grid(model::RealBusinessCycleModel, k_min::Float64, k_max::Float64, z_min::Float64, z_max::Float64, I::Int, J::Int)::Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64}
k_grid::Vector{Float64} = range(k_min, k_max, length=I)
z_grid::Vector{Float64} = range(z_min, z_max, length=J)
kk::Matrix{Float64} = [k for k in k_grid, z in z_grid]
zz::Matrix{Float64} = [z for k in k_grid, z in z_grid]
kk_flat::Vector{Float64} = vec(kk)
zz_flat::Vector{Float64} = vec(zz)
dk::Float64 = (k_max - k_min) / (I - 1)
dz::Float64 = (z_max - z_min) / (J - 1)
return (k_grid, z_grid, kk_flat, zz_flat, dk, dz)
end
# Utility function
@inline function u(model::RealBusinessCycleModel, c::Vector{Float64})::Vector{Float64}
if model.gamma != 1.0
return (max.(c, 1e-6) .^ (1 .- model.gamma) .- 1) ./ (1 .- model.gamma)
else
return log.(max.(c, 1e-6))
end
end
# Derivative of the utility function
@inline function u_prime(model::RealBusinessCycleModel, c::Vector{Float64})::Vector{Float64}
return max.(c, 1e-6) .^ (-model.gamma)
end
# Inverse of the derivative of the utility function
@inline function u_prime_inv(model::RealBusinessCycleModel, x::Vector{Float64})::Vector{Float64}
return max.(x, 1e-6) .^ (-1 / model.gamma)
end
# Production function
@inline function f(model::RealBusinessCycleModel, k::Vector{Float64}, z::Vector{Float64})::Vector{Float64}
return max.(z, 0.0) .* max.(k, 0.0) .^ model.alpha
end
# Validate transition matrix (each row should sum to approximately zero)
function validate_transition_matrix(T::AbstractMatrix{Float64})::Bool
err = maximum(abs.(sum(T, dims=2)))
if err > 1e-10
println("Error: maximum absolute row sum = ", err)
display(LinearAlgebra.Matrix(T))
throw(ArgumentError("T is not a proper transition matrix."))
end
return true
end
function construct_S_Dk(S_B::Vector{Float64}, S_F::Vector{Float64}, dk::Float64, I::Int, J::Int)::SparseMatrixCSC{Float64, Int64}
S_B_minus::Vector{Float64} = min.(S_B, 0.0)
S_F_plus::Vector{Float64} = max.(S_F, 0.0)
# println(S_B_minus)
# println(S_F_plus)
lower_diag::Vector{Float64} = -S_B_minus ./ dk
center_diag::Vector{Float64} = (S_B_minus .- S_F_plus) ./ dk
upper_diag::Vector{Float64} = S_F_plus ./ dk
    return spdiagm(
        -1 => lower_diag[2:end],
        0 => center_diag,
        1 => upper_diag[1:end-1],
    )
end
function solve(model::RealBusinessCycleModel; Delta::Float64=100.0, tol::Float64=1e-6, max_iter::Int=1000, I::Int=20, J::Int=20, bounds::Union{Tuple{Float64, Float64, Float64, Float64}, Nothing}=nothing, n_sigma::Union{Tuple{Float64, Float64}, Float64, Nothing}=nothing)
l_sigma::Float64 = r_sigma::Float64 = 0.0
k_min::Float64, k_max::Float64, z_min::Float64, z_max::Float64 = 0.0, 0.0, 0.0, 0.0
if !isnothing(bounds)
k_min, k_max, z_min, z_max = bounds
elseif !isnothing(n_sigma)
if isa(n_sigma, Tuple) && length(n_sigma) == 2
l_sigma, r_sigma = n_sigma
else
l_sigma = r_sigma = n_sigma
end
z_min = model.zbar - l_sigma * model.sigma
z_max = model.zbar + r_sigma * model.sigma
k_min = (1 - l_sigma * model.sigma / model.zbar)^(1 / (1 - model.alpha)) * model.kbar
k_max = (1 + r_sigma * model.sigma / model.zbar)^(1 / (1 - model.alpha)) * model.kbar
else
error("Must specify bounds or n_sigma")
end
@assert k_min > 0 && z_min > 0 && k_max > k_min && z_max > z_min
k_grid::Vector{Float64}, z_grid::Vector{Float64}, kk::Vector{Float64}, zz::Vector{Float64}, dk::Float64, dz::Float64 = get_grid(model, k_min, k_max, z_min, z_max, I, J)
_Dk_F::SparseMatrixCSC{Float64, Int64} = spdiagm(0 => [-ones(I - 1); 0.0], 1 => ones(I - 1))
Dk_F::SparseMatrixCSC{Float64, Int64} = kron(sparse(LinearAlgebra.I, J, J), (1 / dk) * _Dk_F)
_Dk_B::SparseMatrixCSC{Float64, Int64} = spdiagm(-1 => -ones(I - 1), 0 => [0.0; ones(I - 1)])
Dk_B::SparseMatrixCSC{Float64, Int64} = kron(sparse(LinearAlgebra.I, J, J), (1 / dk) * _Dk_B)
mu::Vector{Float64} = model.theta * (model.zbar .- zz)
_Dz_F::SparseMatrixCSC{Float64, Int64} = spdiagm(0 => [-ones(J - 1); 0.0], 1 => ones(J - 1))
Dz_F::SparseMatrixCSC{Float64, Int64} = kron((1 / dz) * _Dz_F, sparse(LinearAlgebra.I, I, I))
_Dz_B::SparseMatrixCSC{Float64, Int64} = spdiagm(-1 => -ones(J - 1), 0 => [0.0; ones(J - 1)])
Dz_B::SparseMatrixCSC{Float64, Int64} = kron((1 / dz) * _Dz_B, sparse(LinearAlgebra.I, I, I))
_Dzz::SparseMatrixCSC{Float64, Int64} = spdiagm(-1 => ones(J - 1), 0 => [-1.0; fill(-2.0, J - 2); -1.0], 1 => ones(J - 1))
Dzz::SparseMatrixCSC{Float64, Int64} = kron((1 / dz^2) * _Dzz, sparse(LinearAlgebra.I, I, I))
Dz_U::SparseMatrixCSC{Float64, Int64} = spdiagm(mu .< 0) * Dz_B + spdiagm(mu .>= 0) * Dz_F
mu_Dz::SparseMatrixCSC{Float64, Int64} = spdiagm(mu) * Dz_U
half_sigma2_Dzz::SparseMatrixCSC{Float64, Int64} = 0.5 .* model.sigma.^2 .* Dzz
v0::Vector{Float64} = u(model, f(model, kk, zz)) / model.rho
v::Vector{Float64} = copy(v0)
c::Vector{Float64} = zeros(I * J)
status::Int = 0
    # Pre-allocate arrays outside the loop (intended to avoid reallocation; note that most of these are rebound rather than filled in-place inside the loop)
vk_F::Vector{Float64} = zeros(I * J)
vk_F_2d::Matrix{Float64} = zeros(I, J)
vk_B::Vector{Float64} = zeros(I * J)
vk_B_2d::Matrix{Float64} = zeros(I, J)
vk_bar::Vector{Float64} = zeros(I * J)
vk_U::Vector{Float64} = zeros(I * J)
S_F::Vector{Float64} = zeros(I * J)
S_B::Vector{Float64} = zeros(I * J)
S_Dk::SparseMatrixCSC{Float64, Int64} = spzeros(I * J, I * J)
A::SparseMatrixCSC{Float64, Int64} = spzeros(I * J, I * J)
coef::SparseMatrixCSC{Float64, Int64} = spzeros(I * J, I * J)
b::Vector{Float64} = zeros(I * J)
v_new::Vector{Float64} = zeros(I * J)
for iter::Int in 1:max_iter
vk_F = Dk_F * v
vk_F_2d = reshape(vk_F, (I, J))
vk_F_2d[I, :] = u_prime(model, f(model, k_grid[end] .* ones(J), z_grid) .- model.delta .* k_grid[end])
vk_F = vec(vk_F_2d)
vk_B = Dk_B * v
vk_B_2d = reshape(vk_B, (I, J))
vk_B_2d[1, :] = u_prime(model, f(model, k_grid[1] .* ones(J), z_grid) .- model.delta .* k_grid[1])
vk_B = vec(vk_B_2d)
vk_bar = u_prime(model, f(model, kk, zz) .- model.delta .* kk)
S_F = f(model, kk, zz) .- model.delta .* kk .- u_prime_inv(model, vk_F)
S_B = f(model, kk, zz) .- model.delta .* kk .- u_prime_inv(model, vk_B)
indicator_F::Vector{Float64} = S_F .> 0
indicator_B::Vector{Float64} = S_B .< 0
indicator_bar::Vector{Float64} = 1 .- indicator_F .- indicator_B
vk_U = indicator_F .* vk_F + indicator_B .* vk_B + indicator_bar .* vk_bar
c = u_prime_inv(model, vk_U)
S_Dk = construct_S_Dk(S_B, S_F, dk, I, J)
# validate_transition_matrix(S_Dk) || error("Transition matrix is invalid")
# validate_transition_matrix(mu_Dz) || error("Transition matrix is invalid")
# validate_transition_matrix(half_sigma2_Dzz) || error("Transition matrix is invalid")
# println("Iteration $iter")
# display(Matrix(S_Dk))
# display(Matrix(mu_Dz))
# display(Matrix(half_sigma2_Dzz))
# display(v)
# display(c)
A = S_Dk + mu_Dz + half_sigma2_Dzz
coef = (model.rho + 1 / Delta) * sparse(LinearAlgebra.I, I * J, I * J) - A
b = u(model, c) .+ v ./ Delta
# SparseArrays.ldiv!(v_new, coef, b)
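        # Note: a new LinearProblem is built and coef is factorized from scratch on
        # every iteration, even though only coef and b change between iterations.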
prob = LinearSolve.LinearProblem(coef, b)
v_new = LinearSolve.solve(prob).u
err::Float64 = LinearAlgebra.norm(v_new - v)
if err < tol
println("Converged in $iter iterations")
status = 1
v .= v_new
break
end
v .= v_new
end
if status == 0
println("Failed to converge")
end
return v, c, k_grid, z_grid
end
function main()
alpha::Float64 = 0.3
gamma::Float64 = 2.0
rho::Float64 = 0.05
delta::Float64 = 0.05
zbar::Float64 = 1.0
theta::Float64 = 0.5
sigma::Float64 = 0.025
rbc::RealBusinessCycleModel = RealBusinessCycleModel(alpha, gamma, rho, delta, zbar, theta, sigma)
I::Int, J::Int = 1001, 1001
t_start::Float64 = time()
v::Vector{Float64}, c::Vector{Float64}, k::Vector{Float64}, z::Vector{Float64} = solve(rbc; I=I, J=J, n_sigma=(30.0, 30.0))
t_end::Float64 = time()
println(@sprintf("Time taken: %.6f seconds", t_end - t_start))
end
main()
Here is the Python code:
import numpy as np
import scipy.sparse
import scipy.sparse.linalg
from scipy.interpolate import RegularGridInterpolator
import matplotlib.pyplot as plt
from typing import List, Tuple, Callable, Dict, Optional, Union
import time
class RealBusinessCycleModel:
def __init__(
self,
alpha: float,
gamma: float,
rho: float,
delta: float,
zbar: float,
theta: float,
sigma: float,
):
self.alpha = alpha
self.gamma = gamma
self.rho = rho
self.delta = delta
self.theta = theta
self.sigma = sigma
self.zbar = zbar
self.kbar = self.get_steady_state_k()
def get_steady_state_k(self) -> float:
# Solve for the steady state capital stock without the technology shock
k_ss = ((self.alpha * self.zbar) / (self.rho + self.delta)) ** (
1 / (1 - self.alpha)
)
return k_ss
def get_grid(
self, k_min: float, k_max: float, z_min: float, z_max: float, I: int, J: int
):
"""
Generates I equidistant points on the interval [k_min, k_max] and J equidistant points on the interval [z_min, z_max] for capital (k) and productivity (z) variables.
Parameters
----------
k_min : float
Minimum value of capital grid.
k_max : float
Maximum value of capital grid.
z_min : float
Minimum value of productivity grid.
z_max : float
Maximum value of productivity grid.
I : int
Number of points in the capital grid.
J : int
Number of points in the productivity grid.
Returns
-------
k_grid : np.ndarray
1D array of capital grid points.
z_grid : np.ndarray
1D array of productivity grid points.
kk_grid : np.ndarray
2D array of capital grid points, flattened in Fortran (column-major) order. Each column represents the capital grid.
zz_grid : np.ndarray
2D array of productivity grid points, flattened in Fortran (column-major) order. Each row represents the productivity grid.
dk : float
Step size in the capital grid.
dz : float
Step size in the productivity grid.
"""
k_grid = np.linspace(k_min, k_max, I)
z_grid = np.linspace(z_min, z_max, J)
kk_grid, zz_grid = np.meshgrid(k_grid, z_grid, indexing="ij")
dk = (k_max - k_min) / (I - 1)
dz = (z_max - z_min) / (J - 1)
return (
k_grid,
z_grid,
kk_grid.flatten(order="F"),
zz_grid.flatten(order="F"),
dk,
dz,
)
def u(self, c: np.ndarray) -> np.ndarray:
"""
CRRA utility function with consumption per capita c
"""
return (
(np.maximum(c, 1e-6) ** (1 - self.gamma) - 1) / (1 - self.gamma)
if self.gamma != 1
else np.log(np.maximum(c, 1e-6))
)
def u_prime(self, c: np.ndarray) -> np.ndarray:
"""
Derivative of the CRRA utility function with consumption per capita c
"""
return np.maximum(c, 1e-6) ** (-self.gamma)
def u_prime_inv(self, x: np.ndarray) -> np.ndarray:
"""
Inverse of the derivative of the CRRA utility function
"""
return np.maximum(x, 1e-6) ** (-1 / self.gamma)
def f(self, k: np.ndarray, z: np.ndarray) -> np.ndarray:
"""
Production function
"""
return np.maximum(z, 0) * np.maximum(k, 0) ** self.alpha
def _plot(
self,
k: np.ndarray,
z: np.ndarray,
v: np.ndarray,
vk: np.ndarray,
vz: np.ndarray,
vzz: np.ndarray,
c: np.ndarray,
title: str = "",
cmap: str = "viridis",
levels: Optional[Union[int, List[float], np.ndarray]] = None,
out=None,
) -> None:
# Ensure all input arrays are 1D and have the same length
assert (
k.ndim == 1
and z.ndim == 1
and v.ndim == 1
and vk.ndim == 1
and vz.ndim == 1
and vzz.ndim == 1
and c.ndim == 1
)
assert (
k.shape[0]
== z.shape[0]
== v.shape[0]
== vk.shape[0]
== vz.shape[0]
== vzz.shape[0]
== c.shape[0]
)
if levels is None:
levels = max(10, int(k.shape[0] ** (1 / 3)))
# Create the figure object
fig = plt.figure(figsize=(8, 12), dpi=100)
# 3D plot of the value function
ax1 = fig.add_subplot(
3, 2, 1, projection="3d", proj_type="ortho"
) # Specify 3D projection
ax1.plot_trisurf(k, z, v, cmap=cmap)
ax1.set_xlabel("k")
ax1.set_ylabel("z")
ax1.set_title("Value function $v$")
# Tricontourf of the value function
ax2 = fig.add_subplot(3, 2, 2)
contour = ax2.tricontourf(k, z, v, cmap=cmap, levels=levels)
fig.colorbar(contour, ax=ax2, format="%.2f")
ax2.set_title("Value function $v$")
ax2.set_xlabel("k")
ax2.set_ylabel("z")
# Tricontourf of the derivative of the value function vk
ax3 = fig.add_subplot(3, 2, 3)
contour = ax3.tricontourf(k, z, vk, cmap=cmap, levels=levels)
fig.colorbar(contour, ax=ax3, format="%.2f")
ax3.set_title("Derivative $v_k$")
ax3.set_xlabel("k")
ax3.set_ylabel("z")
# Tricontourf of the derivative of the value function vz
ax4 = fig.add_subplot(3, 2, 4)
contour = ax4.tricontourf(k, z, vz, cmap=cmap, levels=levels)
fig.colorbar(contour, ax=ax4, format="%.2f")
ax4.set_title("Derivative $v_z$")
ax4.set_xlabel("k")
ax4.set_ylabel("z")
# Tricontourf of the second derivative of the value function vzz
ax5 = fig.add_subplot(3, 2, 5)
contour = ax5.tricontourf(k, z, vzz, cmap=cmap, levels=levels)
fig.colorbar(contour, ax=ax5, format="%.2f")
ax5.set_title("Second derivative $v_{zz}$")
ax5.set_xlabel("k")
ax5.set_ylabel("z")
# Tricontourf of the consumption function c
ax6 = fig.add_subplot(3, 2, 6)
contour = ax6.tricontourf(k, z, c, cmap=cmap, levels=levels)
fig.colorbar(contour, ax=ax6, format="%.2f")
ax6.set_title("Consumption function $c$")
ax6.set_xlabel("k")
ax6.set_ylabel("z")
plt.suptitle(title)
plt.tight_layout()
if out is None:
plt.show()
else:
plt.savefig(out)
return
def validate_transition_matrix(
self, T: Union[np.ndarray, scipy.sparse.csr_matrix]
) -> bool:
error = np.max(np.abs(T.sum(axis=1)))
if error > 1e-10:
print(f"{error=}")
if not isinstance(T, np.ndarray):
T = T.todense()
print(f"{T=}")
raise ValueError("T is not a proper transition matrix.")
return True
def construct_S_Dk(
self, S_B: np.ndarray, S_F: np.ndarray, dk: float, I: int, J: int
) -> scipy.sparse.csr_matrix:
S_B_minus = np.minimum(S_B, 0)
S_F_plus = np.maximum(S_F, 0)
X = -S_B_minus / dk
Y = S_B_minus / dk - S_F_plus / dk
Z = S_F_plus / dk
# Construct matrix An
# Y0 Z0 00 00
# X1 Y1 Z1 00
# 00 X2 Y2 Z2
        # 00 00 X3 Y3
return scipy.sparse.diags([Y, X[1:], Z[:-1]], [0, -1, 1], shape=(I * J, I * J))
def solve(
self,
Delta: float = 100.0,
tolerance: float = 1e-6,
max_iter: int = 1000,
I: int = 20,
J: int = 20,
bounds: Optional[Tuple[float, float, float, float]] = None,
n_sigma: Union[float, Tuple[float, float], None] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if bounds is not None:
k_min, k_max, z_min, z_max = bounds
elif n_sigma is not None:
if isinstance(n_sigma, tuple) and len(n_sigma) == 2:
l_sigma, r_sigma = n_sigma
elif isinstance(n_sigma, float) or isinstance(n_sigma, int):
l_sigma = r_sigma = n_sigma
else:
raise ValueError("n_sigma must be a float or a tuple of two floats.")
z_min, z_max = (
self.zbar - l_sigma * self.sigma,
self.zbar + r_sigma * self.sigma,
)
k_min = (1 - l_sigma * self.sigma / self.zbar) ** (
1 / (1 - self.alpha)
) * self.kbar
k_max = (1 + r_sigma * self.sigma / self.zbar) ** (
1 / (1 - self.alpha)
) * self.kbar
else:
raise ValueError("Either bounds or n_sigma must be specified.")
assert k_min > 0, "k_min must be positive."
assert z_min > 0, "z_min must be positive."
assert k_max > k_min, "k_max must be greater than k_min."
assert z_max > z_min, "z_max must be greater than z_min."
k, z, kk, zz, dk, dz = self.get_grid(k_min, k_max, z_min, z_max, I, J)
eye_J = scipy.sparse.eye(J)
eye_I = scipy.sparse.eye(I)
eye_IJ = scipy.sparse.eye(I * J)
# Construct Dk_F and Dk_B, mu @ Dz, and 0.5 * sigma^2 @ Dzz
_Dk_F = scipy.sparse.diags(
diagonals=[np.hstack([-np.ones(I - 1), 0]), np.ones(I - 1)],
offsets=[0, 1],
shape=(I, I),
format="csr",
)
_Dk_B = scipy.sparse.diags(
diagonals=[-np.ones(I - 1), np.hstack([0, np.ones(I - 1)])],
offsets=[-1, 0],
shape=(I, I),
format="csr",
)
Dk_F = scipy.sparse.kron(eye_J, (dk) ** (-1) * _Dk_F)
Dk_B = scipy.sparse.kron(eye_J, (dk) ** (-1) * _Dk_B)
mu = self.theta * (self.zbar - zz)
indicator_F = (mu >= 0).astype(np.float64)
indicator_B = (mu < 0).astype(np.float64)
indicator_M = 1 - indicator_F - indicator_B
_Dz_F = scipy.sparse.diags(
diagonals=[np.hstack([-np.ones(J - 1), 0]), np.ones(J - 1)],
offsets=[0, 1],
shape=(J, J),
format="csr",
)
_Dz_B = scipy.sparse.diags(
diagonals=[-np.ones(J - 1), np.hstack([0, np.ones(J - 1)])],
offsets=[-1, 0],
shape=(J, J),
format="csr",
)
Dz_F = scipy.sparse.kron((dz) ** (-1) * _Dz_F, eye_I)
Dz_B = scipy.sparse.kron((dz) ** (-1) * _Dz_B, eye_I)
_Dzz = scipy.sparse.diags(
diagonals=[
np.ones(J - 1),
np.hstack([-1, -2 * np.ones(J - 2), -1]),
np.ones(J - 1),
],
offsets=[-1, 0, 1],
shape=(J, J),
format="csr",
)
# self.validate_transition_matrix(_Dzz)
Dz_U = (
scipy.sparse.diags(indicator_B) @ Dz_B
+ scipy.sparse.diags(indicator_F) @ Dz_F
)
Dzz = scipy.sparse.kron(dz ** (-2) * _Dzz, eye_I)
mu_Dz = scipy.sparse.diags(mu) @ Dz_U
half_sigma2_Dzz = 0.5 * self.sigma**2 * Dzz
# self.validate_transition_matrix(mu_Dz)
# self.validate_transition_matrix(Dzz)
# Initialize v0
v0 = self.u(self.f(kk, zz)) / self.rho
v = v0.copy()
v_new = v0.copy()
status = 0
for iter in range(max_iter):
vk_F = Dk_F @ v
vk_F.reshape(I, J, order="F")[I - 1, :] = self.u_prime(
self.f(k_max * np.ones(J), z) - self.delta * k_max
)
vk_F = vk_F.flatten(order="F")
vk_B = Dk_B @ v
vk_B.reshape(I, J, order="F")[0, :] = self.u_prime(
self.f(k_min * np.ones(J), z) - self.delta * k_min
)
vk_B = vk_B.flatten(order="F")
vk_bar = self.u_prime(self.f(kk, zz) - self.delta * kk)
S_F = self.f(kk, zz) - self.delta * kk - self.u_prime_inv(vk_F)
S_B = self.f(kk, zz) - self.delta * kk - self.u_prime_inv(vk_B)
indicator_F = (S_F > 0).astype(int)
indicator_B = (S_B < 0).astype(int)
indicator_bar = 1 - indicator_F - indicator_B
vk_U = indicator_F * vk_F + indicator_B * vk_B + indicator_bar * vk_bar
c = self.u_prime_inv(vk_U)
S_Dk = self.construct_S_Dk(S_B, S_F, dk, I, J)
# print("iteration", iter)
# with np.printoptions(precision=5, suppress=True, linewidth=220):
# print(f"S_Dk:\n{S_Dk.todense()}")
# print(f"mu_Dz:\n{mu_Dz.todense()}")
# print(f"half_sigma2_Dzz:\n{half_sigma2_Dzz.todense()}")
# print(f"v:\n{v}")
# print(f"c:\n{c}")
A = S_Dk + mu_Dz + half_sigma2_Dzz
coef = (self.rho + 1 / Delta) * eye_IJ - A
b = self.u(c) + v / Delta
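            # spsolve factorizes coef from scratch on every iteration (SuperLU by
            # default, or UMFPACK when scikit-umfpack is installed).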
v_new = scipy.sparse.linalg.spsolve(coef, b)
absolute_error = np.linalg.norm(v_new - v)
if absolute_error < tolerance:
print(f"Converged in {iter} iterations")
vz = Dz_U @ v_new
vzz = Dzz @ v_new
# self._plot(
# kk,
# zz,
# v_new,
# vk_U,
# vz,
# vzz,
# c,
# title="Converged solution",
# out="converged_solution.pdf",
# )
status = 1
break
v = v_new
if status == 0:
print(f"Failed to converge in {max_iter} iterations")
return v_new, c, k, z
if __name__ == "__main__":
# Example usage
alpha = 0.3
gamma = 2.0
rho = 0.05
delta = 0.05
zbar = 1.0
theta = 0.5
sigma = 0.025
rbc = RealBusinessCycleModel(
alpha=alpha,
gamma=gamma,
rho=rho,
delta=delta,
zbar=zbar,
theta=theta,
sigma=sigma,
)
# Solve the model
I, J = 1001, 1001
start_time = time.time()
v_new, c, k, z = rbc.solve(Delta=100, I=I, J=J, n_sigma=(30, 30))
end_time = time.time()
print(f"Time taken: {end_time - start_time:.6f} seconds", flush=True)