Nice job. I tried to clarify the possible confusion between states, actions, and indices in the following way:
using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees
# A small tabular MDP with three String-labelled states and three String-labelled
# actions (MDP{String,String} is the POMDPs.jl abstract type). The two dicts map
# each label to the 1-based index used by stateindex/actionindex.
struct MyMDP <: MDP{String,String}
indicesstates::Dict{String, Int} # state label ("State 0", ...) -> 1-based index
indicesactions::Dict{String, Int} # action label ("Action 0", ...) -> 1-based index
end
# Label -> 1-based index maps consumed by stateindex/actionindex below.
indicesstates = Dict("State 0" => 1, "State 1" => 2, "State 2" => 3)
indicesactions = Dict("Action 0" => 1, "Action 1" => 2, "Action 2" => 3)
# The argument type assertions in the original call
# (`MyMDP(indicesstates::Dict{String, Int64}, ...)`) were redundant — the dict
# literals above already have exactly that type — so the constructor is called
# with the plain variables.
mdp = MyMDP(indicesstates, indicesactions)
# --- POMDPs.jl interface: spaces, discount, indexing, initial distribution ---

# Ordered list of action labels; the order matches `indicesactions`.
function POMDPs.actions(m::MyMDP)
    return ["Action 0", "Action 1", "Action 2"]
end

# Ordered list of state labels; the order matches `indicesstates`.
function POMDPs.states(m::MyMDP)
    return ["State 0", "State 1", "State 2"]
end

# Discount factor applied to future rewards.
function POMDPs.discount(m::MyMDP)
    return 0.95
end

# 1-based index of state `s`, looked up in the MDP's own dictionary.
function POMDPs.stateindex(m::MyMDP, s)
    return m.indicesstates[s]
end

# 1-based index of action `a`, looked up in the MDP's own dictionary.
function POMDPs.actionindex(m::MyMDP, a)
    return m.indicesactions[a]
end

# Start uniformly at random over all states.
function POMDPs.initialstate(m::MyMDP)
    return Uniform(states(m))
end
# Transition distribution for taking action `a` in state `s`.
#
# The nine if/elseif branches of the original are tabulated: `dists[si][ai]` is
# the probability vector over `states(m)` for state index `si` and action index
# `ai`. (The table has a diagonal structure — the distribution depends only on
# si + ai — but keeping the full table makes every case explicit.)
#
# Unknown state/action labels raise an ArgumentError instead of the original
# `@assert false`: asserts may be disabled at higher optimization levels and are
# not meant for input validation.
function POMDPs.transition(m::MyMDP, s, a)
    dists = (
        # Action 0            Action 1            Action 2
        ([1.0, 0.0, 0.0], [0.7, 0.3, 0.0], [0.2, 0.5, 0.3]),  # State 0
        ([0.7, 0.3, 0.0], [0.2, 0.5, 0.3], [0.0, 0.2, 0.8]),  # State 1: 2 and 3 left together
        ([0.2, 0.5, 0.3], [0.0, 0.2, 0.8], [0.0, 0.0, 1.0]),  # State 2: 1, 2 and 3 left together
    )
    si = get(m.indicesstates, s, nothing)
    ai = get(m.indicesactions, a, nothing)
    (si === nothing || ai === nothing) &&
        throw(ArgumentError("unknown state/action pair ($s, $a)"))
    return SparseCat(states(m), dists[si][ai])
end
# Immediate reward for taking action `a` in state `s`.
#
# The nine if/elseif branches of the original are tabulated as a matrix with
# rows = states and columns = actions, indexed through the lookup dicts stored
# in the MDP. Unknown labels raise an ArgumentError instead of the original
# `@assert false` (asserts can be disabled and are not input validation).
function POMDPs.reward(m::MyMDP, s, a)
    rewards = [-45 -40 -50;   # State 0
               -14 -34 -54;   # State 1
                -8 -38 -58]   # State 2
    si = get(m.indicesstates, s, nothing)
    ai = get(m.indicesactions, a, nothing)
    (si === nothing || ai === nothing) &&
        throw(ArgumentError("unknown state/action pair ($s, $a)"))
    return rewards[si, ai]
end
# --- MCTS solver configuration ---
n_iter = 100_000           # number of MCTS simulations per action query
depth = 5                  # maximum search depth (planning horizon)
ec = 5.0                   # exploration constant c in UCB: Q + c*sqrt(log(t)/N)
estimate_value = 0         # constant value estimate used at leaf nodes

solver = MCTSSolver(; n_iterations = n_iter,
                      depth = depth,
                      exploration_constant = ec,
                      estimate_value = estimate_value,
                      enable_tree_vis = true)  # record the tree for D3Tree
planner = solve(solver, mdp)
# Run the planner once from every state and open the search-tree visualization.
# Iterating `states(mdp)` directly removes the hard-coded index range `1:3` of
# the original loop, so this stays correct if the state list ever changes.
# To start from a random initial state instead:
#   state = rand(MersenneTwister(8), initialstate(mdp))
for state in states(mdp)
    a, info = action_info(planner, state)    # best action + solver diagnostics
    d3tree = D3Tree(info[:tree], init_expand=1)
    # inchrome(d3tree)                       # alternative: open in Chrome
    inbrowser(d3tree, "firefox")
end
With these changes, the visualization now works for all three initial states.