The following variation
using POMDPs
using POMDPModelTools
using QuickPOMDPs: QuickPOMDP
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
# Minimal 4-state, 4-action MDP; `indices` maps the state/action values 0..3
# to the 1-based indices that the tabular solver expects.
struct MyMDP <: MDP{Int,Int}
    indices::Dict{Int, Int}
    MyMDP() = new(Dict{Int, Int}(0 => 1, 1 => 2, 2 => 3, 3 => 4))
end
mdp = MyMDP()
POMDPs.actions(m::MyMDP) = [0,1,2,3]
POMDPs.states(m::MyMDP) = [0,1,2,3]
POMDPs.discount(m::MyMDP) = 0.95
POMDPs.stateindex(m::MyMDP, s) = m.indices[s]
POMDPs.actionindex(m::MyMDP, a) = m.indices[a]
POMDPs.initialstate(m::MyMDP) = Uniform([0,1,2,3])
# Transition model: each (state, action) pair maps to a categorical
# distribution over the four states 0..3.
function POMDPs.transition(m::MyMDP, s, a)
    if s == 0 && a == 0
        return SparseCat([0, 1, 2, 3], [1, 0, 0, 0])
    elseif s == 0 && a == 1
        return SparseCat([0, 1, 2, 3], [0.7, 0.3, 0, 0])
    elseif s == 0 && a == 2
        return SparseCat([0, 1, 2, 3], [0.2, 0.5, 0.3, 0])
    elseif s == 0 && a == 3
        return SparseCat([0, 1, 2, 3], [0.2, 0.5, 0.3, 0])
    elseif s == 1 && a == 0
        return SparseCat([0, 1, 2, 3], [0.7, 0.3, 0, 0])
    elseif s == 1 && a == 1
        return SparseCat([0, 1, 2, 3], [0.2, 0.5, 0.3, 0])
    elseif s == 1 && a == 2
        return SparseCat([0, 1, 2, 3], [0.2, 0.5, 0.3, 0])
    elseif s == 1 && a == 3
        return SparseCat([0, 1, 2, 3], [0.2, 0.5, 0.3, 0])
    elseif s == 2 && a == 0
        return SparseCat([0, 1, 2, 3], [0.2, 0.5, 0.3, 0])
    elseif s == 2 && a == 1
        return SparseCat([0, 1, 2, 3], [0, 0.2, 0.5, 0.3])
    elseif s == 2 && a == 2
        return SparseCat([0, 1, 2, 3], [0, 0.2, 0.5, 0.3])
    elseif s == 2 && a == 3
        return SparseCat([0, 1, 2, 3], [0, 0.2, 0.5, 0.3])
    elseif s == 3 && a == 0
        return SparseCat([0, 1, 2, 3], [0, 0.2, 0.5, 0.3])
    elseif s == 3 && a == 1
        return SparseCat([0, 1, 2, 3], [0, 0.2, 0.5, 0.3])
    elseif s == 3 && a == 2
        return SparseCat([0, 1, 2, 3], [0, 0.2, 0.5, 0.3])
    elseif s == 3 && a == 3
        return SparseCat([0, 1, 2, 3], [0, 0.2, 0.5, 0.3])
    else
        @show "transition", s, a
        return Uniform([0, 1, 2, 3])
    end
end
# Immediate reward for each (state, action) pair; pairs not listed fall
# through to the else branch and return 0.
function POMDPs.reward(m::MyMDP, s, a)
    if s == 0 && a == 0
        return -45
    elseif s == 0 && a == 1
        return -40
    elseif s == 0 && a == 2
        return -50
    elseif s == 0 && a == 3
        return -70
    elseif s == 1 && a == 0
        return -14
    elseif s == 1 && a == 1
        return -44
    elseif s == 1 && a == 2
        return -54
    elseif s == 1 && a == 3
        return -74
    elseif s == 2 && a == 0
        return -8
    elseif s == 2 && a == 1
        return -38
    elseif s == 3 && a == 0
        return -12
    else
        @show "reward", s, a
        return 0
    end
end
q_learning_solver = QLearningSolver(n_episodes=10,
                                    learning_rate=0.8,
                                    exploration_policy=EpsGreedyPolicy(mdp, 0.5),
                                    verbose=false);
q_learning_policy = solve(q_learning_solver, mdp);
does something for me without showing errors. I don’t know whether it does what you intended, though.
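One quick way to see whether it is doing something sensible is to look at the greedy action the learned Q-table picks in each state and run a short rollout. This is just a sketch: it assumes you also have POMDPSimulators installed; everything else uses only what is already imported above.

```julia
using POMDPSimulators: RolloutSimulator

# Greedy action recommended by the learned policy in each state
for s in states(mdp)
    println("state $s -> action ", action(q_learning_policy, s))
end

# Discounted return of one rollout, starting from a state drawn from initialstate(mdp)
sim = RolloutSimulator(max_steps=50)
r = simulate(sim, mdp, q_learning_policy)
println("discounted return of one rollout: $r")
```

With only 10 episodes and an exploration rate of 0.5 the Q-values will still be quite noisy, so the recommended actions may change between runs; increasing `n_episodes` should make them more stable.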