# Monte Carlo Tree Search

Hi @clrescobar,

Your question seems too abstract to answer as it stands. Could you provide a minimal working example (MWE) to make it easier for us to help?

All the best…

Sure!

``````julia
using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees

"""
    MyMDP

Three-state, three-action MDP whose states are the integers `0`, `1`, `2` and whose
actions are the strings `"Send 0"`, `"Send 1"`, `"Send 2"`.

# Fields
- `indicesstates`: 1-based position of each state (returned by `stateindex`).
- `indicesactions`: 1-based position of each action (returned by `actionindex`).

NOTE(review): this `Int`-state version is the one whose visualization fails when the
literal state `0` is passed to `action_info`; the `String`-state version later in the
thread works. Possibly MCTS confuses an `Int` state with an internal tree index — unconfirmed.
"""
struct MyMDP <: MDP{Int,String}
    indicesstates::Dict{Int,Int}
    indicesactions::Dict{String,Int}
end

# 1-based lookup tables: state k and action "Send k" both map to k + 1.
indicesstates = Dict(k => k + 1 for k in 0:2)
indicesactions = Dict("Send $k" => k + 1 for k in 0:2)

mdp = MyMDP(indicesstates, indicesactions)

# Problem definition: action/state spaces, discount factor, index maps and
# the initial-state distribution (uniform over all states).
POMDPs.actions(m::MyMDP) = ["Send $k" for k in 0:2]
POMDPs.states(m::MyMDP) = collect(0:2)
POMDPs.discount(m::MyMDP) = 0.95
POMDPs.stateindex(m::MyMDP, s) = m.indicesstates[s]
POMDPs.actionindex(m::MyMDP, a) = m.indicesactions[a]
POMDPs.initialstate(m::MyMDP) = Uniform(states(m))

"""
    POMDPs.transition(m::MyMDP, s, a)

Return the distribution of the next state after taking action `a` in state `s`,
as a `SparseCat` over the states `0`, `1`, `2`.

Throws an `ArgumentError` for any `(s, a)` pair not handled below, instead of
silently returning `nothing`.
"""
function POMDPs.transition(m::MyMDP, s, a)
    support = [0, 1, 2]
    # Every probability vector is Float64 so all branches return the same
    # SparseCat type (two branches used to build integer vectors).
    if s == 0 && a == "Send 0"
        return SparseCat(support, [1.0, 0.0, 0.0])
    elseif s == 0 && a == "Send 1"
        return SparseCat(support, [0.7, 0.3, 0.0])
    elseif s == 0 && a == "Send 2"
        return SparseCat(support, [0.2, 0.5, 0.3])

    elseif s == 1 && a == "Send 0"
        return SparseCat(support, [0.7, 0.3, 0.0])
    elseif s == 1 && a == "Send 1"
        return SparseCat(support, [0.2, 0.5, 0.3])
    elseif s == 1 && a == "Send 2"
        return SparseCat(support, [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together

    elseif s == 2 && a == "Send 0"
        return SparseCat(support, [0.2, 0.5, 0.3])
    elseif s == 2 && a == "Send 1"
        return SparseCat(support, [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together
    elseif s == 2 && a == "Send 2"
        return SparseCat(support, [0.0, 0.0, 1.0]) # we consider having 1, 2 and 3 left together
    else
        throw(ArgumentError("no transition defined for state $(repr(s)) and action $(repr(a))"))
    end
end

"""
    POMDPs.reward(m::MyMDP, s, a)

Return the immediate reward for taking action `a` in state `s`. All rewards in
this model are negative, i.e. they act as costs.

Throws an `ArgumentError` for any `(s, a)` pair not handled below, instead of
silently returning `nothing`.
"""
function POMDPs.reward(m::MyMDP, s, a)
    if s == 0 && a == "Send 0"
        return -45
    elseif s == 0 && a == "Send 1"
        return -40
    elseif s == 0 && a == "Send 2"
        return -50

    elseif s == 1 && a == "Send 0"
        return -14
    elseif s == 1 && a == "Send 1"
        return -34
    elseif s == 1 && a == "Send 2"
        return -54

    elseif s == 2 && a == "Send 0"
        return -8
    elseif s == 2 && a == "Send 1"
        return -38
    elseif s == 2 && a == "Send 2"
        return -58
    else
        throw(ArgumentError("no reward defined for state $(repr(s)) and action $(repr(a))"))
    end
end

# MCTS solver settings.
n_iter = 100000
depth = 5
ec = 5.0 # Exploration constant c in the UCB1 score Q + c*sqrt(log(t)/N) (t: parent visits, N: child visits).
estimate_value = 0 # Function, object, or number used to estimate the value at the leaf nodes.

solver = MCTSSolver(n_iterations=n_iter,
                    depth=depth,
                    exploration_constant=ec,
                    estimate_value=estimate_value,
                    enable_tree_vis=true) # needed for the D3Tree visualization below
planner = solve(solver, mdp)
# Reproducible initial state drawn from the uniform initial-state distribution.
state = rand(MersenneTwister(8), initialstate(mdp))

# Visualization
a, info = action_info(planner, state);
D3Tree(info[:tree], init_expand=1)
inchrome(D3Tree(planner, state))
``````

Now it works, but in this visualization code

``````julia
#Visualization
# Plan from the sampled `state`, then render the search tree (this version works).
a, info = action_info(planner, state);
D3Tree(info[:tree], init_expand=1)
inchrome(D3Tree(planner, state))
``````

if I put the state `0` directly instead:

``````julia
#Visualization
# Same as the working excerpt, but with the literal state 0 instead of `state`;
# with the Int-state model this variant produces the error reported below.
a, info = action_info(planner,0);
D3Tree(info[:tree], init_expand=1)
inchrome(D3Tree(planner, 0))
``````

It says:

Error: attempt to access 3-element Vector{Vector{Int64}} at index 

Nice job. I tried to clear up possible confusion between states, actions, and indices as follows:

``````julia
using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees

"""
    MyMDP

Three-state, three-action MDP with string-labelled states (`"State 0"` to
`"State 2"`) and actions (`"Action 0"` to `"Action 2"`).

# Fields
- `indicesstates`: 1-based position of each state label (returned by `stateindex`).
- `indicesactions`: 1-based position of each action label (returned by `actionindex`).
"""
struct MyMDP <: MDP{String,String}
    indicesstates::Dict{String,Int}
    indicesactions::Dict{String,Int}
end

# 1-based lookup tables: "State k" and "Action k" both map to k + 1.
indicesstates = Dict("State $k" => k + 1 for k in 0:2)
indicesactions = Dict("Action $k" => k + 1 for k in 0:2)

mdp = MyMDP(indicesstates, indicesactions)

# Problem definition: action/state spaces, discount factor, index maps and
# the initial-state distribution (uniform over all states).
POMDPs.actions(m::MyMDP) = ["Action $k" for k in 0:2]
POMDPs.states(m::MyMDP) = ["State $k" for k in 0:2]
POMDPs.discount(m::MyMDP) = 0.95
POMDPs.stateindex(m::MyMDP, s) = m.indicesstates[s]
POMDPs.actionindex(m::MyMDP, a) = m.indicesactions[a]
POMDPs.initialstate(m::MyMDP) = Uniform(states(m))

"""
    POMDPs.transition(m::MyMDP, s, a)

Return the distribution of the next state after taking action `a` in state `s`,
as a `SparseCat` over `states(m)`.

Throws an `ArgumentError` for any `(s, a)` pair not handled below. This replaces
`@assert false`, which is meant for internal invariants and may be disabled.
"""
function POMDPs.transition(m::MyMDP, s, a)
    support = states(m)
    # Every probability vector is Float64 so all branches return the same
    # SparseCat type (two branches used to build integer vectors).
    if s == "State 0" && a == "Action 0"
        return SparseCat(support, [1.0, 0.0, 0.0])
    elseif s == "State 0" && a == "Action 1"
        return SparseCat(support, [0.7, 0.3, 0.0])
    elseif s == "State 0" && a == "Action 2"
        return SparseCat(support, [0.2, 0.5, 0.3])

    elseif s == "State 1" && a == "Action 0"
        return SparseCat(support, [0.7, 0.3, 0.0])
    elseif s == "State 1" && a == "Action 1"
        return SparseCat(support, [0.2, 0.5, 0.3])
    elseif s == "State 1" && a == "Action 2"
        return SparseCat(support, [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together

    elseif s == "State 2" && a == "Action 0"
        return SparseCat(support, [0.2, 0.5, 0.3])
    elseif s == "State 2" && a == "Action 1"
        return SparseCat(support, [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together
    elseif s == "State 2" && a == "Action 2"
        return SparseCat(support, [0.0, 0.0, 1.0]) # we consider having 1, 2 and 3 left together
    else
        throw(ArgumentError("no transition defined for state $(repr(s)) and action $(repr(a))"))
    end
end

"""
    POMDPs.reward(m::MyMDP, s, a)

Return the immediate reward for taking action `a` in state `s`. All rewards in
this model are negative, i.e. they act as costs.

Throws an `ArgumentError` for any `(s, a)` pair not handled below. This replaces
`@assert false`, which is meant for internal invariants and may be disabled.
"""
function POMDPs.reward(m::MyMDP, s, a)
    if s == "State 0" && a == "Action 0"
        return -45
    elseif s == "State 0" && a == "Action 1"
        return -40
    elseif s == "State 0" && a == "Action 2"
        return -50

    elseif s == "State 1" && a == "Action 0"
        return -14
    elseif s == "State 1" && a == "Action 1"
        return -34
    elseif s == "State 1" && a == "Action 2"
        return -54

    elseif s == "State 2" && a == "Action 0"
        return -8
    elseif s == "State 2" && a == "Action 1"
        return -38
    elseif s == "State 2" && a == "Action 2"
        return -58
    else
        throw(ArgumentError("no reward defined for state $(repr(s)) and action $(repr(a))"))
    end
end

# MCTS solver settings.
n_iter = 100000
depth = 5
ec = 5.0 # Exploration constant c in the UCB1 score Q + c*sqrt(log(t)/N) (t: parent visits, N: child visits).
estimate_value = 0 # Function, object, or number used to estimate the value at the leaf nodes.

solver = MCTSSolver(n_iterations=n_iter,
                    depth=depth,
                    exploration_constant=ec,
                    estimate_value=estimate_value,
                    enable_tree_vis=true) # needed for the D3Tree visualization below
planner = solve(solver, mdp)

# Plan from every state in turn and open each resulting search tree.
for state in states(mdp)
    # Visualization
    a, info = action_info(planner, state)
    d3tree = D3Tree(info[:tree], init_expand=1)
    # Use `inchrome(d3tree)` instead to open the tree in Chrome.
    inbrowser(d3tree, "firefox")
end
``````

and now the visualization works for all three cases.

Thank you so much!! You really helped me! Where was the error?

Good question! I’m not even sure the problem is in your code. Could it be that `POMDPs` has special conventions for states of type `Int` (for example, requiring them to be the indices `1:n`)?

Hello,

You’re right, thank you so much!