Monte Carlo Tree Search

Hi @clrescobar,

it seems your question is too abstract. Could you provide a minimal working example (MWE) to make it easier for us to help?

All the best…

Sure!

using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees

# NOTE(review): later in this thread, passing the Int state `0` to `action_info`/`D3Tree`
# raised a BoundsError at index [0]. Int states seem to get confused with 1-based indices
# somewhere downstream; switching to String states avoided it. Confirm against the
# MCTS.jl source before relying on Int states here.
"""
    MyMDP <: MDP{Int,String}

Three-state MDP with integer states `0`, `1`, `2` and string actions
`"Send 0"`, `"Send 1"`, `"Send 2"`.

# Fields
- `indicesstates::Dict{Int,Int}`: state => 1-based index, looked up by `stateindex`.
- `indicesactions::Dict{String,Int}`: action => 1-based index, looked up by `actionindex`.
"""
struct MyMDP <: MDP{Int,String}
    indicesstates::Dict{Int,Int}
    indicesactions::Dict{String,Int}
end

indicesstates = Dict(0 => 1, 1 => 2, 2 => 3)
indicesactions = Dict("Send 0" => 1, "Send 1" => 2, "Send 2" => 3)

# The dict literals already have the field types, so no `::Dict{...}` assertions are needed.
mdp = MyMDP(indicesstates, indicesactions)

POMDPs.actions(m::MyMDP) = ["Send 0", "Send 1", "Send 2"]
POMDPs.states(m::MyMDP) = [0, 1, 2]
POMDPs.discount(m::MyMDP) = 0.95
POMDPs.stateindex(m::MyMDP, s) = m.indicesstates[s]
POMDPs.actionindex(m::MyMDP, a) = m.indicesactions[a]
# Derive the uniform initial distribution from `states(m)` so the two lists cannot drift.
POMDPs.initialstate(m::MyMDP) = Uniform(states(m))

"""
    POMDPs.transition(m::MyMDP, s, a)

Return a `SparseCat` over `states(m)` with the next-state probabilities for taking
action `a` in state `s`.

Throws an `ArgumentError` when `(s, a)` is not one of the nine pairs defined below.
"""
function POMDPs.transition(m::MyMDP, s, a)
    # Float64 probabilities in every branch, so all pairs return the same SparseCat type.
    probs = if s == 0 && a == "Send 0"
        [1.0, 0.0, 0.0]
    elseif s == 0 && a == "Send 1"
        [0.7, 0.3, 0.0]
    elseif s == 0 && a == "Send 2"
        [0.2, 0.5, 0.3]

    elseif s == 1 && a == "Send 0"
        [0.7, 0.3, 0.0]
    elseif s == 1 && a == "Send 1"
        [0.2, 0.5, 0.3]
    elseif s == 1 && a == "Send 2"
        [0.0, 0.2, 0.8] # we consider having 2 and 3 left together

    elseif s == 2 && a == "Send 0"
        [0.2, 0.5, 0.3]
    elseif s == 2 && a == "Send 1"
        [0.0, 0.2, 0.8] # we consider having 2 and 3 left together
    elseif s == 2 && a == "Send 2"
        [0.0, 0.0, 1.0] # we consider having 1, 2 and 3 left together
    else
        # Without this branch an undefined pair would silently return `nothing`.
        throw(ArgumentError("MyMDP has no transition for state $(repr(s)) " *
                            "and action $(repr(a))"))
    end
    return SparseCat(states(m), probs)
end

"""
    POMDPs.reward(m::MyMDP, s, a)

Return the reward for taking action `a` in state `s`. All defined rewards are negative,
so they act as costs.

Throws an `ArgumentError` when `(s, a)` is not one of the nine pairs defined below.
"""
function POMDPs.reward(m::MyMDP, s, a)
    if s == 0 && a == "Send 0"
        return -45
    elseif s == 0 && a == "Send 1"
        return -40
    elseif s == 0 && a == "Send 2"
        return -50

    elseif s == 1 && a == "Send 0"
        return -14
    elseif s == 1 && a == "Send 1"
        return -34
    elseif s == 1 && a == "Send 2"
        return -54

    elseif s == 2 && a == "Send 0"
        return -8
    elseif s == 2 && a == "Send 1"
        return -38
    elseif s == 2 && a == "Send 2"
        return -58
    else
        # Without this branch an undefined pair would silently return `nothing`.
        throw(ArgumentError("MyMDP has no reward for state $(repr(s)) " *
                            "and action $(repr(a))"))
    end
end

n_iter = 100000
depth = 5
# Exploration constant c in the UCB1 rule used to choose actions inside the tree:
#     Q(s,a) + c * sqrt(log(N(s)) / N(s,a))
# where N(s) counts visits to the state node and N(s,a) visits to that action.
ec = 5.0
# Function, object, or number used to estimate the value at the leaf nodes.
estimate_value = 0

solver = MCTSSolver(n_iterations=n_iter,
    depth=depth,
    exploration_constant=ec,
    estimate_value=estimate_value,
    enable_tree_vis=true  # record the extra data needed to visualize the search tree
)
planner = solve(solver, mdp)
state = rand(MersenneTwister(8), initialstate(mdp))

# Visualization: show the tree from this search. This replaces the previous version,
# which built a D3Tree from `info[:tree]`, discarded it, and then built a separate one
# via `D3Tree(planner, state)`.
a, info = action_info(planner, state)
d3tree = D3Tree(info[:tree], init_expand=1)
inchrome(d3tree)

Now it works, but when, instead of

#Visualization
# Same visualization lines as in the script above; `state` is the random initial state
# drawn there with `rand(MersenneTwister(8), initialstate(mdp))`.
a, info = action_info(planner, state);
D3Tree(info[:tree], init_expand=1) # NOTE(review): this tree is built but never stored or opened
inchrome(D3Tree(planner, state))

I put this (passing the state `0` directly):

#Visualization
# NOTE(review): the literal Int state `0` is passed here. The error quoted below
# ("index [0]" on a 3-element vector) suggests this Int state ends up being used as a
# 1-based index inside MCTS/D3Trees; confirm against the MCTS.jl source.
a, info = action_info(planner,0);
D3Tree(info[:tree], init_expand=1) 
inchrome(D3Tree(planner, 0))

It says:

Error: attempt to access 3-element Vector{Vector{Int64}} at index [0]

Nice job. I tried to clear up possible confusion between states, actions, and indices as follows:

using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees

"""
    MyMDP <: MDP{String,String}

MDP over three string-labelled states and three string-labelled actions.

# Fields
- `indicesstates`: state label => 1-based index, read by `stateindex`.
- `indicesactions`: action label => 1-based index, read by `actionindex`.
"""
struct MyMDP <: MDP{String,String}
    indicesstates::Dict{String,Int}
    indicesactions::Dict{String,Int}
end

# Number each label by its position in the list (1, 2, 3).
indicesstates =
    Dict(label => i for (i, label) in enumerate(["State 0", "State 1", "State 2"]))
indicesactions =
    Dict(label => i for (i, label) in enumerate(["Action 0", "Action 1", "Action 2"]))

mdp = MyMDP(indicesstates, indicesactions)

function POMDPs.actions(m::MyMDP)
    return ["Action 0", "Action 1", "Action 2"]
end

function POMDPs.states(m::MyMDP)
    return ["State 0", "State 1", "State 2"]
end

function POMDPs.discount(m::MyMDP)
    return 0.95
end

function POMDPs.stateindex(m::MyMDP, s)
    return m.indicesstates[s]
end

function POMDPs.actionindex(m::MyMDP, a)
    return m.indicesactions[a]
end

# Every state is equally likely at the start.
function POMDPs.initialstate(m::MyMDP)
    support = states(m)
    return Uniform(support)
end

"""
    POMDPs.transition(m::MyMDP, s, a)

Return a `SparseCat` over `states(m)` with the next-state probabilities for taking
action `a` in state `s`.

Throws an `ArgumentError` when `(s, a)` is not one of the nine pairs defined below.
"""
function POMDPs.transition(m::MyMDP, s, a)
    # Float64 probabilities in every branch, so all pairs return the same SparseCat type.
    probs = if s == "State 0" && a == "Action 0"
        [1.0, 0.0, 0.0]
    elseif s == "State 0" && a == "Action 1"
        [0.7, 0.3, 0.0]
    elseif s == "State 0" && a == "Action 2"
        [0.2, 0.5, 0.3]

    elseif s == "State 1" && a == "Action 0"
        [0.7, 0.3, 0.0]
    elseif s == "State 1" && a == "Action 1"
        [0.2, 0.5, 0.3]
    elseif s == "State 1" && a == "Action 2"
        [0.0, 0.2, 0.8] # we consider having 2 and 3 left together

    elseif s == "State 2" && a == "Action 0"
        [0.2, 0.5, 0.3]
    elseif s == "State 2" && a == "Action 1"
        [0.0, 0.2, 0.8] # we consider having 2 and 3 left together
    elseif s == "State 2" && a == "Action 2"
        [0.0, 0.0, 1.0] # we consider having 1, 2 and 3 left together
    else
        # `@assert` can be disabled at some optimization levels and names no inputs;
        # an ArgumentError always fires and reports the offending pair.
        throw(ArgumentError("MyMDP has no transition for state $(repr(s)) " *
                            "and action $(repr(a))"))
    end
    return SparseCat(states(m), probs)
end

"""
    POMDPs.reward(m::MyMDP, s, a)

Return the reward for taking action `a` in state `s`. All defined rewards are negative,
so they act as costs.

Throws an `ArgumentError` when `(s, a)` is not one of the nine pairs defined below.
"""
function POMDPs.reward(m::MyMDP, s, a)
    if s == "State 0" && a == "Action 0"
        return -45
    elseif s == "State 0" && a == "Action 1"
        return -40
    elseif s == "State 0" && a == "Action 2"
        return -50

    elseif s == "State 1" && a == "Action 0"
        return -14
    elseif s == "State 1" && a == "Action 1"
        return -34
    elseif s == "State 1" && a == "Action 2"
        return -54

    elseif s == "State 2" && a == "Action 0"
        return -8
    elseif s == "State 2" && a == "Action 1"
        return -38
    elseif s == "State 2" && a == "Action 2"
        return -58
    else
        # `@assert` can be disabled at some optimization levels and names no inputs;
        # an ArgumentError always fires and reports the offending pair.
        throw(ArgumentError("MyMDP has no reward for state $(repr(s)) " *
                            "and action $(repr(a))"))
    end
end

n_iter = 100000
depth = 5
# Exploration constant c in the UCB1 rule used to choose actions inside the tree:
#     Q(s,a) + c * sqrt(log(N(s)) / N(s,a))
# where N(s) counts visits to the state node and N(s,a) visits to that action.
ec = 5.0
# Function, object, or number used to estimate the value at the leaf nodes.
estimate_value = 0

solver = MCTSSolver(n_iterations=n_iter,
    depth=depth,
    exploration_constant=ec,
    estimate_value=estimate_value,
    enable_tree_vis=true  # record the extra data needed to visualize the search tree
)
planner = solve(solver, mdp)

# Plan from every state of the model and open each resulting search tree.
for state in states(mdp)
    _, info = action_info(planner, state)  # only the search tree is needed here
    d3tree = D3Tree(info[:tree], init_expand=1)
    inbrowser(d3tree, "firefox")  # use `inchrome(d3tree)` instead to open it in Chrome
end

and now the visualization works for all three states.

2 Likes

Thank you so much! You really helped me! Where was the error?

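Good question! I’m not even sure the confusion is in your code. Could it be that POMDPs has special conventions for states of type `Int` (for example, requiring them to be indices `1:n`)? The error message hints at this: accessing a 3-element vector at index `[0]` suggests that the state `0` was used directly as a 1-based index.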

Hello,

You’re right, thank you so much!