Solving HJB PDE using deep learning

Dear All,

Sorry to bother you with a possibly stupid question.

I would like to ask for some advice about solving the Hamilton-Jacobi-Bellman equation using deep learning. I am trying to solve an HJB equation of a simple neoclassical growth model from macroeconomics. Its HJB equation is

\rho V(k_t) = \frac{c_t^{1-\gamma}}{1-\gamma} + \frac{\partial V}{\partial k} \{A k_t^\alpha - \delta k_t - c_t\}
The control variable in this problem is consumption c_t \geq 0, characterized by the following first-order condition:
c_t^{-\gamma} = \frac{\partial V}{\partial k}
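Since marginal utility u'(c) = c^{-\gamma} is strictly decreasing, this FOC can be inverted to recover consumption directly from the slope of the value function:

c_t = \left(\frac{\partial V}{\partial k}\right)^{-1/\gamma}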

Following the paper by Fernández-Villaverde et al. (2020), I am trying to solve this HJB problem using deep learning. I parameterize the value function V(k_t) and the consumption function C(k_t) with feed-forward networks designed to guarantee non-negativity of consumption, and use ADAM to minimize a loss function composed of the squared HJB and FOC errors over a reasonable interval of the state space.
\mathcal{E}_{HJB} = \rho V(k_t) - \frac{c_t^{1-\gamma}}{1-\gamma} - \frac{\partial V}{\partial k} \{A k_t^\alpha - \delta k_t - c_t\}
\mathcal{E}_{FOC} = c_t^{-\gamma} - \frac{\partial V}{\partial k}
\mathcal{L} = \sum_i {\mathcal{E}_{HJB}}_i^2 + \sum_i {\mathcal{E}_{FOC}}_i^2

While the minimization procedure converges to a loss of approximately zero, instead of a concave and increasing value function I got something totally weird.

I checked the result for many different choices of activation function in my networks, but the problem persists: in all these cases I still get this shape of the value function.

In the paper (page 19), the authors show that minimizing this loss function, composed of the HJB and FOC errors, should converge to the solution of an accurate upwind finite-difference scheme. Unfortunately, I am getting something totally different. Should I impose boundary conditions? I don't think this problem has natural boundary conditions, other than the transversality condition.
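Since the paper's benchmark is an upwind finite-difference scheme, one way to debug is to solve the same HJB with a standard implicit upwind scheme (in the spirit of Achdou et al.) and compare the network output against it. Below is a minimal, self-contained sketch; the function name and internals are my own, and the parameter values are copied from the MWE further down.

```julia
using LinearAlgebra

# Implicit upwind finite-difference solver for the neoclassical growth HJB,
# used as a reference solution. Hypothetical helper, not from the paper's code;
# parameters match the MWE (A=0.5, α=0.36, γ=2, δ=0.05, ρ=0.04).
function solve_hjb_fd(; A=0.5, α=0.36, γ=2.0, δ=0.05, ρ=0.04,
                        kl=1e-4, ku=10.0, n=1000, Δ=1000.0, maxit=200, tol=1e-6)
    k  = collect(range(kl, ku, length=n))
    dk = k[2] - k[1]
    y  = A .* k .^ α .- δ .* k          # net output A k^α − δ k
    c0 = copy(y)                        # "stay-put" consumption: k̇ = 0
    u  = c -> c^(1 - γ) / (1 - γ)
    v  = u.(c0) ./ ρ                    # standard initial guess V₀(k) = u(c₀(k))/ρ
    c  = copy(c0)
    for it in 1:maxit
        Δv  = (v[2:end] .- v[1:end-1]) ./ dk
        dVf = [Δv; c0[end]^(-γ)]        # forward difference, state constraint at ku
        dVb = [c0[1]^(-γ); Δv]          # backward difference, state constraint at kl
        cf  = max.(dVf, 1e-10) .^ (-1 / γ);  μf = y .- cf
        cb  = max.(dVb, 1e-10) .^ (-1 / γ);  μb = y .- cb
        If  = μf .> 0                   # use forward difference where drift is positive
        Ib  = (μb .< 0) .& .!If         # backward difference where drift is negative
        I0  = .!(If .| Ib)              # zero drift: consume c₀
        c   = If .* cf .+ Ib .* cb .+ I0 .* c0
        # Tridiagonal generator of the upwind scheme
        hi  = max.(μf, 0.0) ./ dk
        lo  = .-min.(μb, 0.0) ./ dk
        G   = Tridiagonal(lo[2:end], .-(hi .+ lo), hi[1:end-1])
        # Implicit update: ((ρ + 1/Δ) I − G) v⁺ = u(c) + v/Δ
        vnew = ((ρ + 1 / Δ) * I - G) \ (u.(c) .+ v ./ Δ)
        maximum(abs.(vnew .- v)) < tol && return (k, vnew, c)
        v = vnew
    end
    return (k, v, c)
end
```

Plotting v and c from this solver against the network's value and policy functions should show immediately where the deep-learning solution goes wrong.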

I would be really glad for any insight, especially from
@Tamas_Papp and @jlperla

Best,
Honza

Here is an MWE of my code.

#*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
#*-*-* Solve non-stochastic neoclassical growth model in continuous time *-*
#*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*

#(1) Import libraries
using Pkg
Pkg.add("Plots")
Pkg.add("Parameters")
Pkg.add("LinearAlgebra")
Pkg.add("Flux")
Pkg.add("Random")
Pkg.add("Distributions")
Pkg.add("ForwardDiff")
using Plots
using Parameters
using LinearAlgebra
using Flux
using Random
using Distributions
using ForwardDiff
include("FastADAM.jl")
Pkg.add("ZygoteRules")
using ZygoteRules
ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T
  @assert length(ẋ) == 1
  ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
end
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T =
  d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T =
  d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),)

#(2) Build model structure
function uCRRA(c,γ)
    if(γ==1.0)
        return log(c)
    else
        return (c^(1-γ))/(1-γ)
    end
end
@with_kw struct NCG
    #(A) Structural parameters
    Α::Float32 = 0.5
    α::Float32 = 0.36
    β::Float32 = 1/(1+0.04)
    ρ::Float32 = 1/β-1
    γ::Float32 = 2.0
    δ::Float32 = 0.05
    #(B) Steady-state
    kss::Float32 = (α/(1/β-(1-δ)))^(1/(1-α))
    kl::Float32 = 1f-4
    ku::Float32 = 10.0
    #(C) Grid
    ϰ::Int64 = 2500
    𝓚::Array{Float32} = collect(range(kl,ku,length=ϰ))'
    #(D) Approximator
    T::Int64 = 8
    g::Function = swish
end
@with_kw struct Optimizer
    #(C) Networks and ADAM setting
    S1::Float32 = 0.0
    S2::Float32 = 1.0
    ϰ::Int64 = 2500
    ϑ::Int64 = 150
    T::Int64 = 32
    Γ::Float32 = 0.001
    β1::Float32 = 0.9
    β2::Float32 = 0.99
    ϵ::Float32 = 1f-8
    ν::Float32 = 1f-5
    𝚰::Int64 = 5000000
    𝛀::Float32 = 0.1
    𝒯::Int64 = 1500
    𝓇::Float32 = 0.9
end

#Try swish,gelu,elu,mish,leakyrelu => all lead to more or less the same result
NC1 = NCG(g=swish)
@unpack Α,T,α,β,ρ,γ,δ,kl,ku,𝓚,g = NC1

#(3) Define approximation networks
#(3.1) Policy network
𝓒 = Chain(Dense(1,T,g),Dense(T,T,g),
Dense(T,T,g),Dense(T,1,softplus))

#(3.2) Value network
𝓥 = Chain(Dense(1,T,g),Dense(T,T,g),
Dense(T,T,g),Dense(T,T,g),Dense(T,1,identity))
d𝓥(t) = ForwardDiff.derivative(x->𝓥([x])[1],t)
𝚹 = Flux.params(𝓒,𝓥)

#(4) Build production function (net of depreciation)
f(x) = Α*x^α - δ*x

#(5) Build a loss function
function 𝓑(x)
    #HJB error
    ϵ_hjb = sum((ρ*𝓥(x) - uCRRA.(𝓒(x),γ) - d𝓥.(x).*(f.(x)-𝓒(x))).^2)
    #FOC error
    ϵ_foc = sum((𝓒(x).^(-γ) - d𝓥.(x)).^2)
    return ϵ_hjb+ϵ_foc
end

Ad1 = Optimizer(Γ=0.001)
@time 𝝝 = Adam(𝓑,𝚹,𝓚,Ad1,1,"YES","NO")

plot(𝓚',𝓥(𝓚)',title="Value Function",xlabel="K_t",ylabel="V(K)",legend=false)
plot(𝓚',d𝓥.(𝓚)',title="V_k(K)",xlabel="K_t",ylabel="V_k(K)",legend=false)
plot(𝓚',𝓒(𝓚)',title="Policy Function",xlabel="K_t",ylabel="C(K)",legend=false)

For replication, here is my implementation of the ADAM optimizer (FastADAM.jl file). There shouldn't be any problem with it; I have used it successfully for many other problems.

function Adam(𝕱,𝚯,Data,Par,per,show,loc="NO")
    let
        #Initialize momentum estimates
        @unpack ϑ,Γ,β1,β2,ϵ,𝚰,𝛀,𝒯,𝓇 = Par
        n = length(𝚯)
        m = Array{Array{Float32}}(undef,n)
        v = Array{Array{Float32}}(undef,n)
        ∇ = Array{Array{Float32}}(undef,n)
        hm = Array{Array{Float32}}(undef,n)
        hv = Array{Array{Float32}}(undef,n)
        u = Array{Array{Float32}}(undef,n)
        if(loc=="YES")
            𝚴 = Array{Array{Float32}}(undef,n)
            𝕻 = [Array{Array{Float32}}(undef,n)]
            𝕸 = Array{Float32}(undef,1)
            𝖈 = Array{Int64}(undef,1)
            𝕸[1] = 𝕱(Data)
            𝖈[1] = 0
        end
        for j in 1:n
            m[j] = zeros(Float32,length(𝚯[j]))
            v[j] = zeros(Float32,length(𝚯[j]))
            ∇[j] = zeros(Float32,length(𝚯[j]))
        end
        #Main training loop
        for i in 1:𝚰
            #Sample data for stochastic gradient
            𝓓,𝓜 = size(Data)
            if(𝓓==1)
                xg = reshape(sample(Data,ϑ),1,ϑ)
            else
                𝓘 = sample(1:𝓜,ϑ)
                xg = Data[:,𝓘]
            end
            #Take gradient
            ∇𝕱 = Flux.gradient(()->𝕱(xg),𝚯)
            #Update gradient
            for j in 1:n
                #Check duality
                𝖙 = ForwardDiff.partials(∇𝕱[𝚯[j]][1])
                if(length(𝖙)==1)
                    λ(t) = ForwardDiff.value(ForwardDiff.partials(∇𝕱[𝚯[j]][t]))[1]
                    ∇[j] = λ.(1:length(𝚯[j]))
                else
                    ∇[j] = Array{Float32,1}(∇𝕱[𝚯[j]][:])
                end
                #Update momentum estimates
                m[j] = β1.*m[j] .+ (1-β1).*∇[j]
                v[j] = β2.*v[j] .+ (1-β2).*∇[j].^2
                hm[j] = m[j]./(1-β1^(i))
                hv[j] = v[j]./(1-β2^(i))
                #Update parameters
                u[j] = hm[j]./(sqrt.(hv[j]) .+ ϵ)
                𝚯[j] .= 𝚯[j] .- reshape(Γ.*u[j],size(𝚯[j]))
                if(loc=="YES")
                    𝚴[j] = copy(𝚯[j] .- reshape(Γ.*u[j],size(𝚯[j])))
                end
            end
            #Compute loss
            if(i%per==0)
                𝕷 = 𝕱(Data)
                if(show=="YES")
                    #println(i)
                    println(𝕷)
                end
                #Secondary convergence check
                if(loc=="YES")
                    if(𝕷<𝓇*𝕸[1])
                        𝕸[1] = 𝕷
                        𝕻[1] = 𝚴
                        𝖈[1] = 0
                    else
                        𝖈[1] = 𝖈[1] + 1
                    end
                    if(𝖈[1]>=𝒯)
                        for j in 1:length(𝚯)
                            𝚯[j] .= 𝚴[j]
                        end
                        println("Secondary convergence")
                        return 𝚴
                    end
                end
                if(𝕷<𝛀)
                    println("Convergence ",i)
                    break
                end
            end
        end
        return 𝚯
    end
end

Sorry, I am not familiar with deep learning. As you mention, there aren't natural boundary conditions for this kind of problem.

In practice, I find that a good initial guess matters a lot. I usually find

@article{den2015exact,
  title={Exact present solution with consistent future approximation: A gridless algorithm to solve stochastic dynamic models},
  author={Den Haan, Wouter J and Kobielarz, Michal L and Rendahl, Pontus},
  year={2015},
  publisher={Centre For Macroeconomics}
}

useful for that purpose. I would start from this and then monitor the solver.


@Tamas_Papp Thank you, I will try that! So, in your experience, does Chebyshev/spline collocation tend to converge to the right solution of the HJB even without imposing boundary conditions?

Yes. Again, with a good initial guess. That’s something I would invest in.
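Building on the initial-guess advice: one cheap starting point (my suggestion, not from the thread) is the "stay-put" value V₀(k) = u(Ak^α − δk)/ρ, i.e., the value of consuming net output forever with zero capital drift, which is the standard initialization in finite-difference HJB codes. The value network could be pretrained on (k, V₀(k)) pairs with a plain MSE loss before switching to the HJB/FOC objective. Parameter values below are copied from the MWE; the names are hypothetical.

```julia
# Supervised pretraining targets for the value network (hypothetical sketch;
# parameters taken from the MWE above: A=0.5, α=0.36, γ=2, δ=0.05, ρ=0.04).
A, α, γ, δ, ρ = 0.5, 0.36, 2.0, 0.05, 0.04
u(c) = c^(1 - γ) / (1 - γ)          # CRRA utility, γ ≠ 1
V0(k) = u(A * k^α - δ * k) / ρ      # value of holding k̇ = 0 forever
ks = collect(range(1e-4, 10.0, length=256))
targets = V0.(ks)                   # pretrain the value net on (ks, targets)
```

After pretraining on these targets, the HJB loss starts from a function with roughly the right level and slope rather than from a random initialization.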
