Porting an example from QuantEcon.jl to POMDPs.jl

Just for fun (and to learn about POMDPs.jl), I am porting a dynamic programming example from QuantEcon.jl to POMDPs.jl. The example is available here: https://lectures.quantecon.org/jl/discrete_dp.html#example-a-growth-model
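For reference, here is the growth model from that lecture as I understand it. B bounds the uniform income shock, M is the storage capacity, s is the amount of goods on hand, a is the amount stored, and c = s - a is consumed. The code below encodes infeasible actions (a > s) with a reward of -Inf.

```latex
\begin{aligned}
&S = \{0, 1, \dots, M + B\}, \qquad A = \{0, 1, \dots, M\}, \qquad \text{feasible iff } a \le s, \\
&r(s, a) = u(s - a) = (s - a)^{\alpha}, \\
&Q(s, a, s') = \begin{cases} \dfrac{1}{B + 1} & \text{if } a \le s' \le a + B, \\ 0 & \text{otherwise,} \end{cases} \\
&v(s) = \max_{0 \le a \le \min(s, M)} \Big\{ u(s - a) + \beta \sum_{s' \in S} v(s')\, Q(s, a, s') \Big\}.
\end{aligned}
```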

My code for the implementation is:

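using POMDPs
using POMDPModels
using POMDPModelTools         # provides SparseCat
using DiscreteValueIteration  # provides ValueIterationSolver

# States (goods on hand) and actions (goods stored) are integers, so this is an MDP{TI, TI}
struct SimpleOG{TI <: Integer, T <: Real, TR <: AbstractArray{T}, TQ <: AbstractArray{T}} <: MDP{TI, TI}
    B :: TI   # upper bound of the uniform income shock
    M :: TI   # storage capacity
    α :: T    # utility exponent, u(c) = c^α
    β :: T    # discount factor
    R :: TR   # rewards, R[s + 1, a + 1]
    Q :: TQ   # transition probabilities, Q[s + 1, a + 1, s' + 1]
end

function SimpleOG(; B::Integer = 10, M::Integer = 5, α::T = 0.5, β::T = 0.90) where {T <: Real}

    u(c) = c^α # utility function
    n = B + M + 1
    m = M + 1

    R = Matrix{T}(undef, n, m)
    Q = zeros(T, n, m, n)   # same element type as R, as the struct requires

    for a in 0:M
        # storing a leads to s' = a + U, with U uniform on 0:B
        Q[:, a + 1, (a:(a + B)) .+ 1] .= 1 / (B + 1)
        for s in 0:(B + M)
            # storing more than is on hand is infeasible
            R[s + 1, a + 1] = a <= s ? u(s - a) : -Inf
        end
    end

    return SimpleOG(B, M, α, β, R, Q)
end

POMDPs.states(simpleog::SimpleOG) = collect(0:(simpleog.M + simpleog.B))
POMDPs.n_states(simpleog::SimpleOG) = simpleog.B + simpleog.M + 1
POMDPs.stateindex(simpleog::SimpleOG, s) = Int(s) + 1

POMDPs.actions(simpleog::SimpleOG) = collect(0:simpleog.M)
POMDPs.n_actions(simpleog::SimpleOG) = simpleog.M + 1
POMDPs.actionindex(simpleog::SimpleOG, a) = Int(a) + 1

# transition must return a distribution over next states, not a raw vector;
# Q is indexed [s, a, s'], so the state index comes first
POMDPs.transition(simpleog::SimpleOG, s, a) =
    SparseCat(states(simpleog), simpleog.Q[Int(s) + 1, Int(a) + 1, :])

# states and actions start at 0, while Julia arrays are 1-based
POMDPs.reward(simpleog::SimpleOG, s, a) = simpleog.R[Int(s) + 1, Int(a) + 1]
POMDPs.reward(simpleog::SimpleOG, s, a, sp) = reward(simpleog, s, a)

POMDPs.discount(simpleog::SimpleOG) = simpleog.β

g = SimpleOG()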

When I try to solve the model, I get:

julia> POMDPs.solve(g)
MethodError: no method matching solve(::SimpleOG{Int64,Float64,Array{Float64,2},Array{Float64,3}})
Closest candidates are:
  solve(!Matched::POMDPPolicies.FunctionSolver, !Matched::Union{MDP, POMDP}) at /Users/amrods/.julia/packages/POMDPPolicies/oW6ud/src/function.jl:23
  solve(!Matched::POMDPPolicies.RandomSolver, !Matched::Union{MDP, POMDP}) at /Users/amrods/.julia/packages/POMDPPolicies/oW6ud/src/random.jl:36
  solve(!Matched::POMDPPolicies.VectorSolver{A}, !Matched::MDP{S,A}) where {S, A} at /Users/amrods/.julia/packages/POMDPPolicies/oW6ud/src/vector.jl:23
  ...

Stacktrace:
 [1] top-level scope at In[14]:1

Can you help?

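Hi @amrods, I see this question is now quite old; I apologize for missing it.

solve takes two arguments: a solver and an MDP or POMDP. That is why the error's closest candidates all have two-argument signatures, and why calling solve with only the model raises a MethodError. Since DiscreteValueIteration is already loaded, you can use something like solve(ValueIterationSolver(), g).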
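To cross-check the policy that the solver returns, a minimal Base-Julia sketch of value iteration over the same R and Q layout might look like the following. It does not reproduce DiscreteValueIteration's internals; the helper names build_growth_model and value_iteration are hypothetical, and the tolerance and iteration cap are arbitrary choices.

```julia
# Hypothetical helper: builds R[s+1, a+1] and Q[s+1, a+1, s'+1] as in SimpleOG.
function build_growth_model(; B::Int = 10, M::Int = 5, α::Float64 = 0.5)
    n, m = B + M + 1, M + 1
    R = Matrix{Float64}(undef, n, m)
    Q = zeros(Float64, n, m, n)
    for a in 0:M
        # storing a leads to s' = a + U, with U uniform on 0:B
        Q[:, a + 1, (a:(a + B)) .+ 1] .= 1 / (B + 1)
        for s in 0:(B + M)
            # infeasible actions (storing more than is on hand) get -Inf reward
            R[s + 1, a + 1] = a <= s ? (s - a)^α : -Inf
        end
    end
    return R, Q
end

# Hypothetical helper: plain value iteration; returns values and a 0-based policy.
function value_iteration(R, Q, β; tol = 1e-8, maxiter = 10_000)
    n, m = size(R)
    # action value: immediate reward plus discounted expected value of the next state
    qval(v, s, a) = R[s, a] + β * sum(Q[s, a, sp] * v[sp] for sp in 1:n)
    v = zeros(n)
    for _ in 1:maxiter
        # Bellman update: best action value in every state
        v_new = [maximum(qval(v, s, a) for a in 1:m) for s in 1:n]
        converged = maximum(abs.(v_new .- v)) < tol
        v = v_new
        converged && break
    end
    # greedy policy, shifted back to the 0-based "amount stored"
    σ = [argmax([qval(v, s, a) for a in 1:m]) - 1 for s in 1:n]
    return v, σ
end

R, Q = build_growth_model()
@assert all(isapprox.(sum(Q, dims = 3), 1.0))  # every Q[s, a, :] is a distribution
v, σ = value_iteration(R, Q, 0.9)
println(σ)
```

Action 0 is always feasible, so every value stays finite even though infeasible actions carry -Inf rewards.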
