Hi @clrescobar,
your question seems a bit too abstract. Could you provide a minimal working example (MWE) so that we can help more easily?
All the best…
Hi @clrescobar,
your question seems a bit too abstract. Could you provide a minimal working example (MWE) so that we can help more easily?
All the best…
Sure!
using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees
"""
    MyMDP

MDP with integer states (`0`, `1`, `2`) and string actions (`"Send 0"`, `"Send 1"`,
`"Send 2"`). It stores lookup tables from each state / action to the index that
`stateindex` / `actionindex` return.
"""
struct MyMDP <: MDP{Int,String}
    indicesstates::Dict{Int,Int}      # state  -> index (read by `stateindex`)
    indicesactions::Dict{String,Int}  # action -> index (read by `actionindex`)
end
# Lookup tables from each state / action to its 1-based index. They back
# `stateindex` / `actionindex`, so they should agree with the orderings returned
# by `states` / `actions`.
indicesstates = Dict(0 => 1, 1 => 2, 2 => 3)
indicesactions = Dict("Send 0" => 1, "Send 1" => 2, "Send 2" => 3)
# Pass the dicts directly: the default constructor already converts them to the
# declared field types (`Dict{Int,Int}`, `Dict{String,Int}`). An explicit
# `::Dict{Int64,Int64}` assertion would fail on 32-bit Julia, where these literals
# build `Dict{Int32,Int32}`.
mdp = MyMDP(indicesstates, indicesactions)
# POMDPs interface for MyMDP: state/action spaces, their index lookups, the
# discount factor and the initial-state distribution. The orderings below are
# expected to match the indices stored in `m.indicesstates` / `m.indicesactions`.
POMDPs.states(m::MyMDP) = [0, 1, 2]
POMDPs.actions(m::MyMDP) = ["Send 0", "Send 1", "Send 2"]

POMDPs.stateindex(m::MyMDP, s) = m.indicesstates[s]
POMDPs.actionindex(m::MyMDP, a) = m.indicesactions[a]

POMDPs.discount(m::MyMDP) = 0.95

# Start uniformly at random in any of the three states.
function POMDPs.initialstate(m::MyMDP)
    return Uniform(states(m))
end
# Transition model: distribution over the next state (0, 1 or 2) after taking
# action `a` in state `s`.
# - All probability vectors are Float64. If one branch uses an Int vector
#   (`[1, 0, 0]`), the returned `SparseCat` type depends on the branch taken.
# - Unknown (state, action) pairs throw an `ArgumentError` instead of falling
#   through the `if` chain and returning `nothing`.
function POMDPs.transition(m::MyMDP, s, a)
    if s == 0 && a == "Send 0"
        return SparseCat([0, 1, 2], [1.0, 0.0, 0.0])
    elseif s == 0 && a == "Send 1"
        return SparseCat([0, 1, 2], [0.7, 0.3, 0.0])
    elseif s == 0 && a == "Send 2"
        return SparseCat([0, 1, 2], [0.2, 0.5, 0.3])
    elseif s == 1 && a == "Send 0"
        return SparseCat([0, 1, 2], [0.7, 0.3, 0.0])
    elseif s == 1 && a == "Send 1"
        return SparseCat([0, 1, 2], [0.2, 0.5, 0.3])
    elseif s == 1 && a == "Send 2"
        return SparseCat([0, 1, 2], [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together
    elseif s == 2 && a == "Send 0"
        return SparseCat([0, 1, 2], [0.2, 0.5, 0.3])
    elseif s == 2 && a == "Send 1"
        return SparseCat([0, 1, 2], [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together
    elseif s == 2 && a == "Send 2"
        return SparseCat([0, 1, 2], [0.0, 0.0, 1.0]) # we consider having 1, 2 and 3 left together
    else
        throw(ArgumentError("no transition defined for state $(repr(s)) and action $(repr(a))"))
    end
end
# Immediate reward for taking action `a` in state `s`. Every defined value is
# negative, i.e. each action has a cost.
# Unknown (state, action) pairs throw an `ArgumentError` instead of falling
# through the `if` chain and returning `nothing`.
function POMDPs.reward(m::MyMDP, s, a)
    if s == 0 && a == "Send 0"
        return -45
    elseif s == 0 && a == "Send 1"
        return -40
    elseif s == 0 && a == "Send 2"
        return -50
    elseif s == 1 && a == "Send 0"
        return -14
    elseif s == 1 && a == "Send 1"
        return -34
    elseif s == 1 && a == "Send 2"
        return -54
    elseif s == 2 && a == "Send 0"
        return -8
    elseif s == 2 && a == "Send 1"
        return -38
    elseif s == 2 && a == "Send 2"
        return -58
    else
        throw(ArgumentError("no reward defined for state $(repr(s)) and action $(repr(a))"))
    end
end
# MCTS hyper-parameters.
n_iter = 100000      # number of MCTS iterations per planning call
depth = 5            # search depth passed to MCTSSolver
ec = 5.0             # exploration constant c in the UCB score Q + c*sqrt(log(t)/N)
estimate_value = 0   # leaf-value estimate: a function, object or number (here the constant 0)

solver = MCTSSolver(;
    n_iterations = n_iter,
    depth = depth,
    exploration_constant = ec,
    estimate_value = estimate_value,
    enable_tree_vis = true,  # record the search tree so it can be drawn with D3Trees
)
planner = solve(solver, mdp)
# Draw a reproducible initial state, plan from it, and visualize the search tree.
state = rand(MersenneTwister(8), initialstate(mdp))
a, info = action_info(planner, state);
# Draw the tree that `action_info` just recorded (`info[:tree]`, available because
# the solver was built with `enable_tree_vis=true`).
# `D3Tree(planner, state)` is not used: with the Int state 0 it failed with an
# index-0 access error. NOTE(review): it apparently treats an Int state as a tree
# index — confirm against MCTS.jl.
d3tree = D3Tree(info[:tree], init_expand=1)
inchrome(d3tree)
Now it works. However, when I change the visualization code from
#Visualization (working version: `state` is drawn from `initialstate(mdp)`)
a, info = action_info(planner, state);
D3Tree(info[:tree], init_expand=1)
inchrome(D3Tree(planner, state))
to
#Visualization with the state hard-coded to the Int 0 (this version fails; see the error below)
a, info = action_info(planner,0);
D3Tree(info[:tree], init_expand=1)
# NOTE(review): presumably this call raises the index-0 error, with the Int state 0 taken as a tree index
inchrome(D3Tree(planner, 0))
I get the following error:
Error: attempt to access 3-element Vector{Vector{Int64}} at index [0]
Nice job. I tried to clear up possible confusion between states, actions and their indices as follows (states and actions are now string labels, so they cannot be mistaken for integer indices):
using POMDPs
using POMDPModelTools
using TabularTDLearning
using POMDPPolicies
using POMDPModels
using Parameters, Random
using POMDPLinter
using POMDPSimulators
using Statistics
using MCTS
using D3Trees
"""
    MyMDP

MDP whose states (`"State 0"` … `"State 2"`) and actions (`"Action 0"` … `"Action 2"`)
are string labels. It stores lookup tables from each label to the index that
`stateindex` / `actionindex` return.
"""
struct MyMDP <: MDP{String,String}
    indicesstates::Dict{String,Int}   # state label  -> index (read by `stateindex`)
    indicesactions::Dict{String,Int}  # action label -> index (read by `actionindex`)
end
# Lookup tables from each state / action label to its 1-based index. They back
# `stateindex` / `actionindex`, so they should agree with the orderings returned
# by `states` / `actions`.
indicesstates = Dict("State 0" => 1, "State 1" => 2, "State 2" => 3)
indicesactions = Dict("Action 0" => 1, "Action 1" => 2, "Action 2" => 3)
# Pass the dicts directly: the default constructor already converts them to the
# declared field type (`Dict{String,Int}`). An explicit `::Dict{String,Int64}`
# assertion would fail on 32-bit Julia, where these literals build `Dict{String,Int32}`.
mdp = MyMDP(indicesstates, indicesactions)
# POMDPs interface for MyMDP: state/action spaces, their index lookups, the
# discount factor and the initial-state distribution.
POMDPs.states(m::MyMDP) = ["State 0", "State 1", "State 2"]
POMDPs.actions(m::MyMDP) = ["Action 0", "Action 1", "Action 2"]

function POMDPs.stateindex(m::MyMDP, s)
    return m.indicesstates[s]
end

function POMDPs.actionindex(m::MyMDP, a)
    return m.indicesactions[a]
end

POMDPs.discount(m::MyMDP) = 0.95

# Start uniformly at random in any state of the state space.
POMDPs.initialstate(m::MyMDP) = Uniform(states(m))
# Transition model: categorical distribution over `states(m)` (in that order) after
# taking action `a` in state `s`.
# - All probability vectors are Float64, so every branch returns the same `SparseCat` type.
# - Unknown (state, action) pairs throw an `ArgumentError` that names them.
#   `@assert` is meant for internal invariants, may be disabled, and
#   `@assert false` would not report which pair was missing.
function POMDPs.transition(m::MyMDP, s, a)
    if s == "State 0" && a == "Action 0"
        return SparseCat(states(m), [1.0, 0.0, 0.0])
    elseif s == "State 0" && a == "Action 1"
        return SparseCat(states(m), [0.7, 0.3, 0.0])
    elseif s == "State 0" && a == "Action 2"
        return SparseCat(states(m), [0.2, 0.5, 0.3])
    elseif s == "State 1" && a == "Action 0"
        return SparseCat(states(m), [0.7, 0.3, 0.0])
    elseif s == "State 1" && a == "Action 1"
        return SparseCat(states(m), [0.2, 0.5, 0.3])
    elseif s == "State 1" && a == "Action 2"
        return SparseCat(states(m), [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together
    elseif s == "State 2" && a == "Action 0"
        return SparseCat(states(m), [0.2, 0.5, 0.3])
    elseif s == "State 2" && a == "Action 1"
        return SparseCat(states(m), [0.0, 0.2, 0.8]) # we consider having 2 and 3 left together
    elseif s == "State 2" && a == "Action 2"
        return SparseCat(states(m), [0.0, 0.0, 1.0]) # we consider having 1, 2 and 3 left together
    else
        throw(ArgumentError("no transition defined for state $(repr(s)) and action $(repr(a))"))
    end
end
# Immediate reward for taking action `a` in state `s`. Every defined value is
# negative, i.e. each action has a cost.
# Unknown (state, action) pairs throw an `ArgumentError` that names them, rather
# than an `@assert false`: asserts may be disabled and carry no context.
function POMDPs.reward(m::MyMDP, s, a)
    if s == "State 0" && a == "Action 0"
        return -45
    elseif s == "State 0" && a == "Action 1"
        return -40
    elseif s == "State 0" && a == "Action 2"
        return -50
    elseif s == "State 1" && a == "Action 0"
        return -14
    elseif s == "State 1" && a == "Action 1"
        return -34
    elseif s == "State 1" && a == "Action 2"
        return -54
    elseif s == "State 2" && a == "Action 0"
        return -8
    elseif s == "State 2" && a == "Action 1"
        return -38
    elseif s == "State 2" && a == "Action 2"
        return -58
    else
        throw(ArgumentError("no reward defined for state $(repr(s)) and action $(repr(a))"))
    end
end
# MCTS hyper-parameters.
n_iter = 100000      # number of MCTS iterations per planning call
depth = 5            # search depth passed to MCTSSolver
ec = 5.0             # exploration constant c in the UCB score Q + c*sqrt(log(t)/N)
estimate_value = 0   # leaf-value estimate: a function, object or number (here the constant 0)

solver = MCTSSolver(;
    n_iterations = n_iter,
    depth = depth,
    exploration_constant = ec,
    estimate_value = estimate_value,
    enable_tree_vis = true,  # record the search tree so it can be drawn with D3Trees
)
planner = solve(solver, mdp)
# Plan from every state in turn and open each recorded MCTS tree in Firefox.
# - Iterating `states(mdp)` directly, rather than `1:3`, keeps the loop in sync
#   with the state space.
# - The loop variable is local to the loop, so no top-level `state` / `a` is
#   assigned in soft scope.
for state in states(mdp)
    _, info = action_info(planner, state)  # only the recorded tree is needed here
    d3tree = D3Tree(info[:tree], init_expand=1)
    inbrowser(d3tree, "firefox")  # or `inchrome(d3tree)` to use Chrome instead
end
and now the visualization works for all three initial states.
Thank you so much, you really helped me! Where was the error?
Good question! I’m not even sure the problem was in your code. Could it be that POMDPs has special conventions for states of type `Int` (for example, requiring them to be the indices `1:n`)?
Hello,
You’re right, thank you so much!