Monte Carlo Tree Search

Nice job. I tried to clarify the possible confusion between states, actions, and indices in the following way:

using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees

# States and actions are plain strings; the two dictionaries map each string
# to the integer index that stateindex/actionindex must return.
struct MyMDP <: MDP{String,String}
    indicesstates::Dict{String, Int}
    indicesactions::Dict{String, Int}
end

indicesstates = Dict("State 0" => 1, "State 1" => 2, "State 2" => 3)
indicesactions = Dict("Action 0" => 1, "Action 1" => 2, "Action 2" => 3)

mdp = MyMDP(indicesstates, indicesactions)

POMDPs.actions(m::MyMDP) = ["Action 0","Action 1","Action 2"]
POMDPs.states(m::MyMDP) = ["State 0","State 1","State 2"]
POMDPs.discount(m::MyMDP) = 0.95
POMDPs.stateindex(m::MyMDP, s) = m.indicesstates[s]
POMDPs.actionindex(m::MyMDP, a) = m.indicesactions[a]
POMDPs.initialstate(m::MyMDP) = Uniform(states(m))
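
# Optional sanity check (an extra illustration, not required by the solver):
# the index functions should agree with the ordering of states(m) and actions(m).
@assert [stateindex(mdp, s) for s in states(mdp)] == [1, 2, 3]
@assert [actionindex(mdp, a) for a in actions(mdp)] == [1, 2, 3]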

function POMDPs.transition(m::MyMDP, s, a)
    if s == "State 0" && a == "Action 0"
        return SparseCat(states(m), [1.0, 0.0, 0.0])
    elseif s == "State 0" && a == "Action 1"
        return SparseCat(states(m), [0.7, 0.3, 0.0])
    elseif s == "State 0" && a == "Action 2"
        return SparseCat(states(m), [0.2, 0.5, 0.3])

    elseif s == "State 1" && a == "Action 0"
        return SparseCat(states(m), [0.7, 0.3, 0.0])
    elseif s == "State 1" && a == "Action 1"
        return SparseCat(states(m), [0.2, 0.5, 0.3])
    elseif s == "State 1" && a == "Action 2"
        return SparseCat(states(m), [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together

    elseif s == "State 2" && a == "Action 0"
        return SparseCat(states(m), [0.2, 0.5, 0.3])
    elseif s == "State 2" && a == "Action 1"
        return SparseCat(states(m), [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together
    elseif s == "State 2" && a == "Action 2"
        return SparseCat(states(m), [0.0, 0.0, 1.0]) # we consider having 1, 2 and 3 left together
    else
        @assert false "unhandled state-action pair"
    end
end
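
# Optional check (an added illustration; it reads SparseCat's probs field):
# every transition distribution should sum to one.
for s in states(mdp), a in actions(mdp)
    @assert sum(transition(mdp, s, a).probs) ≈ 1
end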

function POMDPs.reward(m::MyMDP, s, a)
    if s == "State 0" && a == "Action 0" 
        return -45
    elseif s == "State 0" && a == "Action 1"
        return -40
    elseif s == "State 0" && a == "Action 2"
        return -50
   
    elseif s == "State 1" && a == "Action 0" 
        return -14
    elseif s == "State 1" && a == "Action 1"
        return -34
    elseif s == "State 1" && a == "Action 2"
        return -54
      
    elseif s == "State 2" && a == "Action 0" 
        return -8
    elseif s == "State 2" && a == "Action 1"
        return -38
    elseif s == "State 2" && a == "Action 2"
        return -58
    else
        @assert false
    end        
end 
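
# Optional spot check of the reward (cost) table:
@assert reward(mdp, "State 2", "Action 0") == -8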

n_iter = 100000
depth = 5
ec = 5.0 # In the UCB formula Q + c*sqrt(log(t)/N), c is the exploration constant.
estimate_value = 0 # Function, object, or number used to estimate the value at the leaf nodes.
    
solver = MCTSSolver(n_iterations=n_iter,
    depth=depth,
    exploration_constant=ec,
    estimate_value = estimate_value,
    enable_tree_vis=true
)
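# Illustrative UCB calculation (made-up numbers, not taken from the solver):
# with Q = -40, c = 5, t = 100 total visits and N = 10 visits of one action,
# UCB = -40 + 5*sqrt(log(100)/10) ≈ -36.6, so a larger c keeps rarely tried
# actions competitive for longer before exploitation takes over.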
planner = solve(solver, mdp)
# state = rand(MersenneTwister(8), initialstate(mdp))
for i in 1:3
    state = states(mdp)[i]

    #Visualization
    a, info = action_info(planner, state);
    d3tree = D3Tree(info[:tree], init_expand=1) 
    # inchrome(d3tree)
    inbrowser(d3tree, "firefox")
end
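
# Optional: a rough numerical check of the planner's policy using the
# already-imported POMDPSimulators and Statistics (max_steps and the number
# of rollouts are arbitrary choices here; this is slow with n_iter = 100000):
# rollouts = [simulate(RolloutSimulator(max_steps=20), mdp, planner) for _ in 1:20]
# println(mean(rollouts))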

and now the tree visualization works for all three starting states.
