My Julia code is slower than Python and Matlab

I am solving an HJB equation in economics using Julia, but I am experiencing performance issues. My Julia implementation takes about 90 seconds, slower than Python (~76s) and significantly slower than MATLAB (~39s). I’m new to Julia and have already applied several performance optimizations, such as inlining functions and pre-allocating vectors/matrices outside loops. However, the performance is still unsatisfactory.

Profiling shows that most of the time is spent solving the linear system coef \ b. I suspect there is a more efficient way to handle this in Julia, one that is at least as fast as the Python version.

Does anyone have suggestions for improving the speed?

For reference, I’m running this on a MacBook Air M3.

Thanks in advance for your help!

using SparseArrays
import LinearAlgebra, LinearSolve
using SuiteSparseGraphBLAS
using Interpolations
using Plots
using Random
using Statistics
using Printf
using Profile

mutable struct RealBusinessCycleModel
    alpha::Float64
    gamma::Float64
    rho::Float64
    delta::Float64
    zbar::Float64
    theta::Float64
    sigma::Float64
    kbar::Float64

    function RealBusinessCycleModel(alpha::Float64, gamma::Float64, rho::Float64, delta::Float64, zbar::Float64, theta::Float64, sigma::Float64)
        kbar = ((alpha * zbar) / (rho + delta))^(1 / (1 - alpha))
        new(alpha, gamma, rho, delta, zbar, theta, sigma, kbar)
    end
end

function get_grid(model::RealBusinessCycleModel, k_min::Float64, k_max::Float64, z_min::Float64, z_max::Float64, I::Int, J::Int)::Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64}
    k_grid::Vector{Float64} = range(k_min, k_max, length=I)
    z_grid::Vector{Float64} = range(z_min, z_max, length=J)
    kk::Matrix{Float64} = [k for k in k_grid, z in z_grid]
    zz::Matrix{Float64} = [z for k in k_grid, z in z_grid]
    kk_flat::Vector{Float64} = vec(kk)
    zz_flat::Vector{Float64} = vec(zz)
    dk::Float64 = (k_max - k_min) / (I - 1)
    dz::Float64 = (z_max - z_min) / (J - 1)
    return (k_grid, z_grid, kk_flat, zz_flat, dk, dz)
end

# Utility function
@inline function u(model::RealBusinessCycleModel, c::Vector{Float64})::Vector{Float64}
    if model.gamma != 1.0
        return (max.(c, 1e-6) .^ (1 .- model.gamma) .- 1) ./ (1 .- model.gamma)
    else
        return log.(max.(c, 1e-6))
    end
end

# Derivative of the utility function
@inline function u_prime(model::RealBusinessCycleModel, c::Vector{Float64})::Vector{Float64}
    return max.(c, 1e-6) .^ (-model.gamma)
end

# Inverse of the derivative of the utility function
@inline function u_prime_inv(model::RealBusinessCycleModel, x::Vector{Float64})::Vector{Float64}
    return max.(x, 1e-6) .^ (-1 / model.gamma)
end

# Production function
@inline function f(model::RealBusinessCycleModel, k::Vector{Float64}, z::Vector{Float64})::Vector{Float64}
    return max.(z, 0.0) .* max.(k, 0.0) .^ model.alpha
end


# Validate transition matrix (each row should sum nearly zero)
function validate_transition_matrix(T::AbstractMatrix{Float64})::Bool
    err = maximum(abs.(sum(T, dims=2)))
    if err > 1e-10
        println("Error: maximum absolute row sum = ", err)
        display(LinearAlgebra.Matrix(T))
        throw(ArgumentError("T is not a proper transition matrix."))
    end
    return true
end

function construct_S_Dk(S_B::Vector{Float64}, S_F::Vector{Float64}, dk::Float64, I::Int, J::Int)::SparseMatrixCSC{Float64, Int64}
    S_B_minus::Vector{Float64} = min.(S_B, 0.0)
    S_F_plus::Vector{Float64} = max.(S_F, 0.0)
    # println(S_B_minus)
    # println(S_F_plus)
    lower_diag::Vector{Float64} = -S_B_minus ./ dk
    center_diag::Vector{Float64} = (S_B_minus .- S_F_plus) ./ dk
    upper_diag::Vector{Float64} = S_F_plus ./ dk
    diags = Dict(
        -1 => lower_diag[2:end],
        0 => center_diag,
        1 => upper_diag[1:end-1]
    )
    return spdiagm(diags...)
end

function solve(model::RealBusinessCycleModel; Delta::Float64=100.0, tol::Float64=1e-6, max_iter::Int=1000, I::Int=20, J::Int=20, bounds::Union{Tuple{Float64, Float64, Float64, Float64}, Nothing}=nothing, n_sigma::Union{Tuple{Float64, Float64}, Float64, Nothing}=nothing)
    l_sigma::Float64 = r_sigma::Float64 = 0.0
    k_min::Float64, k_max::Float64, z_min::Float64, z_max::Float64 = 0.0, 0.0, 0.0, 0.0
    if !isnothing(bounds)
        k_min, k_max, z_min, z_max = bounds
    elseif !isnothing(n_sigma)
        if isa(n_sigma, Tuple) && length(n_sigma) == 2
            l_sigma, r_sigma = n_sigma
        else
            l_sigma = r_sigma = n_sigma
        end
        z_min = model.zbar - l_sigma * model.sigma
        z_max = model.zbar + r_sigma * model.sigma
        k_min = (1 - l_sigma * model.sigma / model.zbar)^(1 / (1 - model.alpha)) * model.kbar
        k_max = (1 + r_sigma * model.sigma / model.zbar)^(1 / (1 - model.alpha)) * model.kbar
    else
        error("Must specify bounds or n_sigma")
    end

    @assert k_min > 0 && z_min > 0 && k_max > k_min && z_max > z_min

    k_grid::Vector{Float64}, z_grid::Vector{Float64}, kk::Vector{Float64}, zz::Vector{Float64}, dk::Float64, dz::Float64 = get_grid(model, k_min, k_max, z_min, z_max, I, J)

    _Dk_F::SparseMatrixCSC{Float64, Int64} = spdiagm(0 => [-ones(I - 1); 0.0], 1 => ones(I - 1))
    Dk_F::SparseMatrixCSC{Float64, Int64} = kron(sparse(LinearAlgebra.I, J, J), (1 / dk) * _Dk_F)

    _Dk_B::SparseMatrixCSC{Float64, Int64} = spdiagm(-1 => -ones(I - 1), 0 => [0.0; ones(I - 1)])
    Dk_B::SparseMatrixCSC{Float64, Int64} = kron(sparse(LinearAlgebra.I, J, J), (1 / dk) * _Dk_B)

    mu::Vector{Float64} = model.theta * (model.zbar .- zz)

    _Dz_F::SparseMatrixCSC{Float64, Int64} = spdiagm(0 => [-ones(J - 1); 0.0], 1 => ones(J - 1))
    Dz_F::SparseMatrixCSC{Float64, Int64} = kron((1 / dz) * _Dz_F, sparse(LinearAlgebra.I, I, I))
    _Dz_B::SparseMatrixCSC{Float64, Int64} = spdiagm(-1 => -ones(J - 1), 0 => [0.0; ones(J - 1)])
    Dz_B::SparseMatrixCSC{Float64, Int64} = kron((1 / dz) * _Dz_B, sparse(LinearAlgebra.I, I, I))

    _Dzz::SparseMatrixCSC{Float64, Int64} = spdiagm(-1 => ones(J - 1), 0 => [-1.0; fill(-2.0, J - 2); -1.0], 1 => ones(J - 1))
    Dzz::SparseMatrixCSC{Float64, Int64} = kron((1 / dz^2) * _Dzz, sparse(LinearAlgebra.I, I, I))
    Dz_U::SparseMatrixCSC{Float64, Int64} = spdiagm(mu .< 0) * Dz_B + spdiagm(mu .>= 0) * Dz_F
    mu_Dz::SparseMatrixCSC{Float64, Int64} = spdiagm(mu) * Dz_U
    half_sigma2_Dzz::SparseMatrixCSC{Float64, Int64} = 0.5 .* model.sigma.^2 .* Dzz

    v0::Vector{Float64} = u(model, f(model, kk, zz)) / model.rho
    v::Vector{Float64} = copy(v0)
    c::Vector{Float64} = zeros(I * J)
    status::Int = 0

    # Allocate memory outside the loop to avoid reallocation and improve performance
    vk_F::Vector{Float64} = zeros(I * J)
    vk_F_2d::Matrix{Float64} = zeros(I, J)
    vk_B::Vector{Float64} = zeros(I * J)
    vk_B_2d::Matrix{Float64} = zeros(I, J)
    vk_bar::Vector{Float64} = zeros(I * J)
    vk_U::Vector{Float64} = zeros(I * J)
    S_F::Vector{Float64} = zeros(I * J)
    S_B::Vector{Float64} = zeros(I * J)
    S_Dk::SparseMatrixCSC{Float64, Int64} = spzeros(I * J, I * J)
    A::SparseMatrixCSC{Float64, Int64} = spzeros(I * J, I * J)
    coef::SparseMatrixCSC{Float64, Int64} = spzeros(I * J, I * J)
    b::Vector{Float64} = zeros(I * J)
    v_new::Vector{Float64} = zeros(I * J)
    

    for iter::Int in 1:max_iter
        vk_F = Dk_F * v
        vk_F_2d = reshape(vk_F, (I, J))
        vk_F_2d[I, :] = u_prime(model, f(model, k_grid[end] .* ones(J), z_grid) .- model.delta .* k_grid[end])
        vk_F = vec(vk_F_2d)

        vk_B = Dk_B * v
        vk_B_2d = reshape(vk_B, (I, J))
        vk_B_2d[1, :] = u_prime(model, f(model, k_grid[1] .* ones(J), z_grid) .- model.delta .* k_grid[1])
        vk_B = vec(vk_B_2d)

        vk_bar = u_prime(model, f(model, kk, zz) .- model.delta .* kk)
        S_F = f(model, kk, zz) .- model.delta .* kk .- u_prime_inv(model, vk_F)
        S_B = f(model, kk, zz) .- model.delta .* kk .- u_prime_inv(model, vk_B)

        indicator_F::Vector{Float64} = S_F .> 0
        indicator_B::Vector{Float64} = S_B .< 0
        indicator_bar::Vector{Float64} = 1 .- indicator_F .- indicator_B
        vk_U = indicator_F .* vk_F + indicator_B .* vk_B + indicator_bar .* vk_bar
        c = u_prime_inv(model, vk_U)

        S_Dk = construct_S_Dk(S_B, S_F, dk, I, J)


        # validate_transition_matrix(S_Dk) || error("Transition matrix is invalid")
        # validate_transition_matrix(mu_Dz) || error("Transition matrix is invalid")
        # validate_transition_matrix(half_sigma2_Dzz) || error("Transition matrix is invalid")

        # println("Iteration $iter")
        # display(Matrix(S_Dk))
        # display(Matrix(mu_Dz))
        # display(Matrix(half_sigma2_Dzz))
        # display(v)
        # display(c)


        A = S_Dk + mu_Dz + half_sigma2_Dzz
        coef = (model.rho + 1 / Delta) * sparse(LinearAlgebra.I, I * J, I * J) - A
        b = u(model, c) .+ v ./ Delta

        # SparseArrays.ldiv!(v_new, coef, b)
        prob = LinearSolve.LinearProblem(coef, b)
        v_new = LinearSolve.solve(prob).u
        err::Float64 = LinearAlgebra.norm(v_new - v)
        if err < tol
            println("Converged in $iter iterations")
            status = 1
            v .= v_new
            break
        end
        v .= v_new
    end

    if status == 0
        println("Failed to converge")
    end
    return v, c, k_grid, z_grid
end

function main()
    alpha::Float64 = 0.3
    gamma::Float64 = 2.0
    rho::Float64  = 0.05
    delta::Float64 = 0.05
    zbar::Float64 = 1.0
    theta::Float64 = 0.5
    sigma::Float64 = 0.025

    rbc::RealBusinessCycleModel = RealBusinessCycleModel(alpha, gamma, rho, delta, zbar, theta, sigma)
    I::Int, J::Int  = 1001, 1001
    t_start::Float64  = time()
    v::Vector{Float64}, c::Vector{Float64}, k::Vector{Float64}, z::Vector{Float64} = solve(rbc; I=I, J=J, n_sigma=(30.0, 30.0))
    t_end::Float64  = time()
    println(@sprintf("Time taken: %.6f seconds", t_end - t_start))

end

main()

Here is the Python code:

import numpy as np
import scipy.sparse
import scipy.sparse.linalg  # spsolve (used in solve below) lives in this submodule
from scipy.interpolate import RegularGridInterpolator
import matplotlib.pyplot as plt
from typing import List, Tuple, Callable, Dict, Optional, Union

import time


class RealBusinessCycleModel:
    def __init__(
        self,
        alpha: float,
        gamma: float,
        rho: float,
        delta: float,
        zbar: float,
        theta: float,
        sigma: float,
    ):
        self.alpha = alpha
        self.gamma = gamma
        self.rho = rho
        self.delta = delta
        self.theta = theta
        self.sigma = sigma
        self.zbar = zbar
        self.kbar = self.get_steady_state_k()

    def get_steady_state_k(self) -> float:
        # Solve for the steady state capital stock without the technology shock
        k_ss = ((self.alpha * self.zbar) / (self.rho + self.delta)) ** (
            1 / (1 - self.alpha)
        )
        return k_ss

    def get_grid(
        self, k_min: float, k_max: float, z_min: float, z_max: float, I: int, J: int
    ):
        """
        Generates I equidistant points on the interval [k_min, k_max] and J equidistant points on the interval [z_min, z_max] for capital (k) and productivity (z) variables.

        Parameters
        ----------
        k_min : float
            Minimum value of capital grid.
        k_max : float
            Maximum value of capital grid.
        z_min : float
            Minimum value of productivity grid.
        z_max : float
            Maximum value of productivity grid.
        I : int
            Number of points in the capital grid.
        J : int
            Number of points in the productivity grid.

        Returns
        -------
        k_grid : np.ndarray
            1D array of capital grid points.
        z_grid : np.ndarray
            1D array of productivity grid points.
        kk_grid : np.ndarray
            2D array of capital grid points, flattened in Fortran (column-major) order.  Each column represents the capital grid.
        zz_grid : np.ndarray
            2D array of productivity grid points, flattened in Fortran (column-major) order. Each row represents the productivity grid.
        dk : float
            Step size in the capital grid.
        dz : float
            Step size in the productivity grid.
        """
        k_grid = np.linspace(k_min, k_max, I)
        z_grid = np.linspace(z_min, z_max, J)
        kk_grid, zz_grid = np.meshgrid(k_grid, z_grid, indexing="ij")
        dk = (k_max - k_min) / (I - 1)
        dz = (z_max - z_min) / (J - 1)
        return (
            k_grid,
            z_grid,
            kk_grid.flatten(order="F"),
            zz_grid.flatten(order="F"),
            dk,
            dz,
        )

    def u(self, c: np.ndarray) -> np.ndarray:
        """
        CRRA utility function with consumption per capita c
        """
        return (
            (np.maximum(c, 1e-6) ** (1 - self.gamma) - 1) / (1 - self.gamma)
            if self.gamma != 1
            else np.log(np.maximum(c, 1e-6))
        )

    def u_prime(self, c: np.ndarray) -> np.ndarray:
        """
        Derivative of the CRRA utility function with consumption per capita c
        """
        return np.maximum(c, 1e-6) ** (-self.gamma)

    def u_prime_inv(self, x: np.ndarray) -> np.ndarray:
        """
        Inverse of the derivative of the CRRA utility function
        """
        return np.maximum(x, 1e-6) ** (-1 / self.gamma)

    def f(self, k: np.ndarray, z: np.ndarray) -> np.ndarray:
        """
        Production function
        """
        return np.maximum(z, 0) * np.maximum(k, 0) ** self.alpha

    def _plot(
        self,
        k: np.ndarray,
        z: np.ndarray,
        v: np.ndarray,
        vk: np.ndarray,
        vz: np.ndarray,
        vzz: np.ndarray,
        c: np.ndarray,
        title: str = "",
        cmap: str = "viridis",
        levels: Optional[Union[int, List[float], np.ndarray]] = None,
        out=None,
    ) -> None:

        # Ensure all input arrays are 1D and have the same length
        assert (
            k.ndim == 1
            and z.ndim == 1
            and v.ndim == 1
            and vk.ndim == 1
            and vz.ndim == 1
            and vzz.ndim == 1
            and c.ndim == 1
        )
        assert (
            k.shape[0]
            == z.shape[0]
            == v.shape[0]
            == vk.shape[0]
            == vz.shape[0]
            == vzz.shape[0]
            == c.shape[0]
        )

        if levels is None:
            levels = max(10, int(k.shape[0] ** (1 / 3)))

        # Create the figure object
        fig = plt.figure(figsize=(8, 12), dpi=100)

        # 3D plot of the value function
        ax1 = fig.add_subplot(
            3, 2, 1, projection="3d", proj_type="ortho"
        )  # Specify 3D projection
        ax1.plot_trisurf(k, z, v, cmap=cmap)
        ax1.set_xlabel("k")
        ax1.set_ylabel("z")
        ax1.set_title("Value function $v$")

        # Tricontourf of the value function
        ax2 = fig.add_subplot(3, 2, 2)
        contour = ax2.tricontourf(k, z, v, cmap=cmap, levels=levels)
        fig.colorbar(contour, ax=ax2, format="%.2f")
        ax2.set_title("Value function $v$")
        ax2.set_xlabel("k")
        ax2.set_ylabel("z")

        # Tricontourf of the derivative of the value function vk
        ax3 = fig.add_subplot(3, 2, 3)
        contour = ax3.tricontourf(k, z, vk, cmap=cmap, levels=levels)
        fig.colorbar(contour, ax=ax3, format="%.2f")
        ax3.set_title("Derivative $v_k$")
        ax3.set_xlabel("k")
        ax3.set_ylabel("z")

        # Tricontourf of the derivative of the value function vz
        ax4 = fig.add_subplot(3, 2, 4)
        contour = ax4.tricontourf(k, z, vz, cmap=cmap, levels=levels)
        fig.colorbar(contour, ax=ax4, format="%.2f")
        ax4.set_title("Derivative $v_z$")
        ax4.set_xlabel("k")
        ax4.set_ylabel("z")

        # Tricontourf of the second derivative of the value function vzz
        ax5 = fig.add_subplot(3, 2, 5)
        contour = ax5.tricontourf(k, z, vzz, cmap=cmap, levels=levels)
        fig.colorbar(contour, ax=ax5, format="%.2f")
        ax5.set_title("Second derivative $v_{zz}$")
        ax5.set_xlabel("k")
        ax5.set_ylabel("z")

        # Tricontourf of the consumption function c
        ax6 = fig.add_subplot(3, 2, 6)
        contour = ax6.tricontourf(k, z, c, cmap=cmap, levels=levels)
        fig.colorbar(contour, ax=ax6, format="%.2f")
        ax6.set_title("Consumption function $c$")
        ax6.set_xlabel("k")
        ax6.set_ylabel("z")

        plt.suptitle(title)
        plt.tight_layout()
        if out is None:
            plt.show()
        else:
            plt.savefig(out)
        return

    def validate_transition_matrix(
        self, T: Union[np.ndarray, scipy.sparse.csr_matrix]
    ) -> bool:
        error = np.max(np.abs(T.sum(axis=1)))
        if error > 1e-10:
            print(f"{error=}")
            if not isinstance(T, np.ndarray):
                T = T.todense()
            print(f"{T=}")
            raise ValueError("T is not a proper transition matrix.")
        return True

    def construct_S_Dk(
        self, S_B: np.ndarray, S_F: np.ndarray, dk: float, I: int, J: int
    ) -> scipy.sparse.csr_matrix:
        S_B_minus = np.minimum(S_B, 0)
        S_F_plus = np.maximum(S_F, 0)

        X = -S_B_minus / dk
        Y = S_B_minus / dk - S_F_plus / dk
        Z = S_F_plus / dk
        # Construct matrix An
        # Y0 Z0 00 00
        # X1 Y1 Z1 00
        # 00 X2 Y2 Z2
        # 00 00 X3 Y3
        return scipy.sparse.diags([Y, X[1:], Z[:-1]], [0, -1, 1], shape=(I * J, I * J))

    def solve(
        self,
        Delta: float = 100.0,
        tolerance: float = 1e-6,
        max_iter: int = 1000,
        I: int = 20,
        J: int = 20,
        bounds: Optional[Tuple[float, float, float, float]] = None,
        n_sigma: Union[float, Tuple[float, float], None] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if bounds is not None:
            k_min, k_max, z_min, z_max = bounds
        elif n_sigma is not None:
            if isinstance(n_sigma, tuple) and len(n_sigma) == 2:
                l_sigma, r_sigma = n_sigma
            elif isinstance(n_sigma, float) or isinstance(n_sigma, int):
                l_sigma = r_sigma = n_sigma
            else:
                raise ValueError("n_sigma must be a float or a tuple of two floats.")
            z_min, z_max = (
                self.zbar - l_sigma * self.sigma,
                self.zbar + r_sigma * self.sigma,
            )
            k_min = (1 - l_sigma * self.sigma / self.zbar) ** (
                1 / (1 - self.alpha)
            ) * self.kbar
            k_max = (1 + r_sigma * self.sigma / self.zbar) ** (
                1 / (1 - self.alpha)
            ) * self.kbar
        else:
            raise ValueError("Either bounds or n_sigma must be specified.")

        assert k_min > 0, "k_min must be positive."
        assert z_min > 0, "z_min must be positive."
        assert k_max > k_min, "k_max must be greater than k_min."
        assert z_max > z_min, "z_max must be greater than z_min."

        k, z, kk, zz, dk, dz = self.get_grid(k_min, k_max, z_min, z_max, I, J)

        eye_J = scipy.sparse.eye(J)
        eye_I = scipy.sparse.eye(I)
        eye_IJ = scipy.sparse.eye(I * J)

        # Construct Dk_F and Dk_B, mu @ Dz, and 0.5 * sigma^2 @ Dzz

        _Dk_F = scipy.sparse.diags(
            diagonals=[np.hstack([-np.ones(I - 1), 0]), np.ones(I - 1)],
            offsets=[0, 1],
            shape=(I, I),
            format="csr",
        )

        _Dk_B = scipy.sparse.diags(
            diagonals=[-np.ones(I - 1), np.hstack([0, np.ones(I - 1)])],
            offsets=[-1, 0],
            shape=(I, I),
            format="csr",
        )

        Dk_F = scipy.sparse.kron(eye_J, (dk) ** (-1) * _Dk_F)
        Dk_B = scipy.sparse.kron(eye_J, (dk) ** (-1) * _Dk_B)

        mu = self.theta * (self.zbar - zz)

        indicator_F = (mu >= 0).astype(np.float64)
        indicator_B = (mu < 0).astype(np.float64)
        indicator_M = 1 - indicator_F - indicator_B

        _Dz_F = scipy.sparse.diags(
            diagonals=[np.hstack([-np.ones(J - 1), 0]), np.ones(J - 1)],
            offsets=[0, 1],
            shape=(J, J),
            format="csr",
        )
        _Dz_B = scipy.sparse.diags(
            diagonals=[-np.ones(J - 1), np.hstack([0, np.ones(J - 1)])],
            offsets=[-1, 0],
            shape=(J, J),
            format="csr",
        )

        Dz_F = scipy.sparse.kron((dz) ** (-1) * _Dz_F, eye_I)
        Dz_B = scipy.sparse.kron((dz) ** (-1) * _Dz_B, eye_I)

        _Dzz = scipy.sparse.diags(
            diagonals=[
                np.ones(J - 1),
                np.hstack([-1, -2 * np.ones(J - 2), -1]),
                np.ones(J - 1),
            ],
            offsets=[-1, 0, 1],
            shape=(J, J),
            format="csr",
        )

        # self.validate_transition_matrix(_Dzz)

        Dz_U = (
            scipy.sparse.diags(indicator_B) @ Dz_B
            + scipy.sparse.diags(indicator_F) @ Dz_F
        )
        Dzz = scipy.sparse.kron(dz ** (-2) * _Dzz, eye_I)

        mu_Dz = scipy.sparse.diags(mu) @ Dz_U
        half_sigma2_Dzz = 0.5 * self.sigma**2 * Dzz

        # self.validate_transition_matrix(mu_Dz)
        # self.validate_transition_matrix(Dzz)

        # Initialize v0
        v0 = self.u(self.f(kk, zz)) / self.rho
        v = v0.copy()
        v_new = v0.copy()
        status = 0

        for iter in range(max_iter):
            vk_F = Dk_F @ v
            vk_F.reshape(I, J, order="F")[I - 1, :] = self.u_prime(
                self.f(k_max * np.ones(J), z) - self.delta * k_max
            )
            vk_F = vk_F.flatten(order="F")

            vk_B = Dk_B @ v
            vk_B.reshape(I, J, order="F")[0, :] = self.u_prime(
                self.f(k_min * np.ones(J), z) - self.delta * k_min
            )
            vk_B = vk_B.flatten(order="F")

            vk_bar = self.u_prime(self.f(kk, zz) - self.delta * kk)

            S_F = self.f(kk, zz) - self.delta * kk - self.u_prime_inv(vk_F)
            S_B = self.f(kk, zz) - self.delta * kk - self.u_prime_inv(vk_B)

            indicator_F = (S_F > 0).astype(int)
            indicator_B = (S_B < 0).astype(int)
            indicator_bar = 1 - indicator_F - indicator_B
            vk_U = indicator_F * vk_F + indicator_B * vk_B + indicator_bar * vk_bar
            c = self.u_prime_inv(vk_U)

            S_Dk = self.construct_S_Dk(S_B, S_F, dk, I, J)

            # print("iteration", iter)
            # with np.printoptions(precision=5, suppress=True, linewidth=220):
            #     print(f"S_Dk:\n{S_Dk.todense()}")
            #     print(f"mu_Dz:\n{mu_Dz.todense()}")
            #     print(f"half_sigma2_Dzz:\n{half_sigma2_Dzz.todense()}")
            #     print(f"v:\n{v}")
            #     print(f"c:\n{c}")

            A = S_Dk + mu_Dz + half_sigma2_Dzz
            coef = (self.rho + 1 / Delta) * eye_IJ - A
            b = self.u(c) + v / Delta

            v_new = scipy.sparse.linalg.spsolve(coef, b)
            absolute_error = np.linalg.norm(v_new - v)
            if absolute_error < tolerance:
                print(f"Converged in {iter} iterations")
                vz = Dz_U @ v_new
                vzz = Dzz @ v_new
                # self._plot(
                #     kk,
                #     zz,
                #     v_new,
                #     vk_U,
                #     vz,
                #     vzz,
                #     c,
                #     title="Converged solution",
                #     out="converged_solution.pdf",
                # )
                status = 1
                break
            v = v_new

        if status == 0:
            print(f"Failed to converge in {max_iter} iterations")

        return v_new, c, k, z


if __name__ == "__main__":
    # Example usage
    alpha = 0.3
    gamma = 2.0
    rho = 0.05
    delta = 0.05
    zbar = 1.0
    theta = 0.5
    sigma = 0.025

    rbc = RealBusinessCycleModel(
        alpha=alpha,
        gamma=gamma,
        rho=rho,
        delta=delta,
        zbar=zbar,
        theta=theta,
        sigma=sigma,
    )

    # Solve the model
    I, J = 1001, 1001
    start_time = time.time()
    v_new, c, k, z = rbc.solve(Delta=100, I=I, J=J, n_sigma=(30, 30))
    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.6f} seconds", flush=True)

There is a lot going on in the iter loop that possibly never changes between iterations.
Edit: I am wrong, v is updated.

Did you try to add:

using AppleAccelerate

as the first line?

On Intel or AMD you could try:

using MKL

Thank you for your advice. I use a Mac, so I tried using AppleAccelerate, but it didn’t provide much of a performance boost. Runtime is now around 85 seconds. Still slower than Python (~76s).


I tried it on my Linux laptop. Not as fast as your Mac:

julia> @time include("test.jl")
Converged in 11 iterations
Time taken: 133.213247 seconds
140.785799 seconds (104.44 M allocations: 54.964 GiB, 6.36% gc time, 4.19% compilation time: 4% of which was recompilation)

Somewhat faster with Julia 1.11, but not much:

julia> @time include("test.jl")
Converged in 11 iterations
Time taken: 121.605767 seconds
130.179831 seconds (19.58 M allocations: 53.607 GiB, 2.52% gc time, 4.94% compilation time: 12% of which was recompilation)

What is interesting, though, is the high number of allocations. You should try to reduce that, starting with a test case for the innermost loop.

I think AppleAccelerate will not help us with sparse array operations.
Three random ideas that may not work:

  1. Quick printing reveals that the matrix A looks banded. Is that true? If so, you may be able to speed up the linear solve by encoding it with BandedMatrices.jl.

  2. Kronecker products are sometimes easier to represent lazily, as operators rather than fully materialized matrices. If you go down that road, you may get speedups by switching to iterative linear solvers in LinearSolve.jl, provided your problem is well-conditioned (see the rough sketch after this list).

  3. Kronecker products with a diagonal matrix are equivalent to BlockDiagonals.jl. EDIT: this is only true for one direction of the Kronecker product, which is non-commutative. What we need here is probably BlockBandedMatrices.jl.
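A rough, untested sketch of idea 2 with LinearSolve.jl's Krylov wrappers (KrylovJL_GMRES is one of its iterative options; whether it converges quickly depends on the conditioning of coef):

using LinearSolve

# Same system as inside the loop, solved iteratively instead of factorized.
prob = LinearProblem(coef, b)
v_new = solve(prob, KrylovJL_GMRES()).u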


Another remark: we don't need SparseArrays at all here, at least for the basic matrices (before the Kronecker products): every one of them can be represented with Diagonal, Bidiagonal or Tridiagonal from LinearAlgebra.
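For example, an untested sketch reusing J from the original code; the entries match _Dzz and _Dz_F above:

using LinearAlgebra

_Dzz = Tridiagonal(ones(J - 1), [-1.0; fill(-2.0, J - 2); -1.0], ones(J - 1))
_Dz_F = Bidiagonal([-ones(J - 1); 0.0], ones(J - 1), :U)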


Your allocation of memory outside of the loop does not work: inside the loop you never actually use the allocated memory. An example:

vk_F = Dk_F * v

This does not reuse vk_F. For that you would have to broadcast into it, with a .=:

vk_F .= Dk_F * v

But this is actually worse, since it still allocates a new vector anyway and then copies it over.

To reuse the memory, use mul!:

mul!(vk_F, Dk_F, v)

There is definitely potential for lots of small optimizations around memory. However, I think these shouldn't be the main issue, considering that Python has those as well, and OP wrote that most of the time is spent solving the linear system.

LinearSolve.jl allows you to choose a suitable algorithm for the linear solve. Maybe the default is just not a good fit for your problem? Have a look at the solver list in the LinearSolve.jl documentation and try one of the other algorithms.
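For instance, an untested sketch; UMFPACKFactorization is just one of the documented options:

prob = LinearSolve.LinearProblem(coef, b)
v_new = LinearSolve.solve(prob, LinearSolve.UMFPACKFactorization()).u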


Note that pre-allocation issues, while a favorite topic of the Julia community, are not very relevant here ^^ The vast majority of the time is spent in the linear solve, so the real payoff lies in speeding up that part. My hunch is that block arrays are the way to go.


Can you try the PardisoJL() solver? I get 4x the performance with it. I am not sure whether it is only compatible with MKL, though; you will need ]add Pardiso.

It’s not available on Apple M-series chips, IIRC.

Are you using a commercial version of Pardiso?

Thanks for commenting. Indeed, I found that solving the sparse linear system takes ~95% of the runtime; the rest of the code takes only around 2 seconds. So now the question is why the default solver is so slow and how to make it faster.

You should first find out which solvers are being used in your benchmarks, across your code in all the languages you tried. The simplest outcome would be that the fastest solver is an available option you just need to specify.

Set it up with LinearSolve.jl and try KLUFactorization, UMFPACKFactorization, PardisoFactorization (if it gives you a binary), and SparspakFactorization. Try this on a matrix with the same sparsity pattern outside of your main loop.
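An untested sketch of such a comparison (PardisoFactorization is left out here because it needs the Pardiso binary):

using LinearSolve

# coef and b taken from one iteration; the sparsity pattern never changes.
for alg in (KLUFactorization(), UMFPACKFactorization(), SparspakFactorization())
    prob = LinearProblem(coef, b)
    @time solve(prob, alg)
end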


It’s good that you have identified the main bottleneck, and that needs work. But you have a programming style that will cause performance issues in almost any code you write, namely a staggeringly massive number of allocations, as well as unneeded and redundant type assertions that make things even worse.

When you have identified a better solver, you should go back, remove every type assertion, and then start reducing the allocations.

For example:

k_grid::Vector{Float64} = range(k_min, k_max, length=I)

converts a lightweight, zero-allocation range into a heavy Vector, for no apparent reason.

I’ve never seen a case where it makes sense to fully allocate a grid.

It looks to me like almost none of the pre-allocated arrays in solve are ever reused; they are only used to create yet another array.

There’s far too much to analyze from a phone, but I strongly suspect that you could benefit from a less array-centric style, pre-allocating less and calculating more on-the-fly. And if you pre-allocate, make sure to actually reuse that memory.
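As an untested illustration of that last point, two lines from the loop rewritten to actually reuse their buffers (u(model, c) still allocates internally, so this is only a partial fix):

using LinearAlgebra: mul!

mul!(vk_F, Dk_F, v)              # overwrites vk_F instead of binding a new vector
b .= u(model, c) .+ v ./ Delta   # broadcasts into the pre-allocated b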

How much this will benefit your current code, I cannot say, but it’s a general approach that will probably help you in the future.

All this after you’ve solved bottleneck number 1, of course :wink:


I tried all these. SparspakFactorization is fast. Thanks.

Thank you for your suggestions. But I don’t get why I should remove the type assertions. I have used Python (which does not need typing at all) and C++ (where typing is mandatory). For Julia, I guessed typing is optional yet informs the compiler and helps performance? Very confused here.

For kk and zz, I actually use them for computing some things. But maybe a better implementation is to multiply by a vector of ones. Hence it becomes

kk::Matrix{Float64} = reshape(k_grid, (I, 1)) * ones(1, J)
zz::Matrix{Float64} = ones(I, 1) * reshape(z_grid, (1, J))

and then I flatten them to one dimension.
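For example, since I only ever use the flattened versions, maybe repeat can build them directly without any intermediate matrix (untested):

kk_flat = repeat(k_grid, J)           # k varies fastest, same as vec(kk)
zz_flat = repeat(z_grid, inner=I)     # each z value repeated I times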

I don’t have much practical programming experience, but I know some of the theory. Since these operations are O(I) and O(I^2) in time and memory and run only once, how much effect can they have on performance?