My PPO implementation learns on LunarLander-v2 but can't solve it (the average reward never reaches +200). It works well on other envs like CartPole, Acrobot, and also a custom Snake env.
using Flux
using Flux: huber_loss, ADAM, gradient, params
using Zygote
using Distributions
using Gym
using Statistics: mean, std
using GameZero
using Colors
using BSON: @save, @load
# ------------------------ GAME STUFF -------------------------
HEIGHT = 320
WIDTH = 240
BACKGROUND = colorant"white"
BLOCK_SIZE = 20
DIRECTIONS = ["right", "left", "up", "down"]
# ----------------------- ENV ------------------------------
Base.@kwdef mutable struct SnakeGameEnv
snake::Vector{Rect} = [Rect(200, 200, BLOCK_SIZE, BLOCK_SIZE),
Rect(180, 200, BLOCK_SIZE, BLOCK_SIZE),
Rect(160, 200, BLOCK_SIZE, BLOCK_SIZE)]
food::Rect = Rect(rand(1:WIDTH - BLOCK_SIZE), rand(1:HEIGHT - BLOCK_SIZE), BLOCK_SIZE, BLOCK_SIZE)
record::Int = 0
score::Int = 0
done::Bool = false
reward::Float32 = 0.0
step::Int = 0
direction::String = "right"
end
function place_food(env::SnakeGameEnv)
    x = rand(1:WIDTH - BLOCK_SIZE)
    y = rand(1:HEIGHT - BLOCK_SIZE)
    food = Rect(x, y, BLOCK_SIZE, BLOCK_SIZE)
    # retry until the food no longer overlaps the snake
    if any(seg -> collide(seg, food), env.snake)
        return place_food(env)
    else
        return food
    end
end
function check_game_over(head::Rect, env::SnakeGameEnv, check_danger::Bool)
if head.x > WIDTH - BLOCK_SIZE || head.x < 0 || head.y < 0 || head.y > HEIGHT - BLOCK_SIZE
if !check_danger
env.reward = -10
env.done = true
end
return true
end
if any(map(x -> collide(x, head), env.snake[3:end]))
if !check_danger
env.reward = -100
env.done = true
end
return true
end
if !check_danger && env.step > 100 * length(env.snake)
env.reward = -10
env.done = true
end
return false
end
function check_danger(head, size::Int, env)
    # probe the cell `size` blocks ahead of the head along the current direction
    dx = env.direction == "right" ? size * BLOCK_SIZE :
         env.direction == "left"  ? -size * BLOCK_SIZE : 0
    dy = env.direction == "down"  ? size * BLOCK_SIZE :
         env.direction == "up"    ? -size * BLOCK_SIZE : 0
    new_head = Rect(head.x + dx, head.y + dy, BLOCK_SIZE, BLOCK_SIZE)
    danger = check_game_over(new_head, env, true) ? 1 : 0
    danger_right = env.direction == "right" ? danger : 0
    danger_left = env.direction == "left" ? danger : 0
    danger_up = env.direction == "up" ? danger : 0
    danger_down = env.direction == "down" ? danger : 0
    danger_right, danger_left, danger_up, danger_down
end
function get_state_1(env::SnakeGameEnv)
head = env.snake[1]
mid = env.snake[length(env.snake) ÷ 2]
tail = env.snake[end]
food = env.food
danger_right_1, danger_left_1, danger_up_1, danger_down_1 = check_danger(head, 1, env)
danger_right_2, danger_left_2, danger_up_2, danger_down_2 = check_danger(head, 2, env)
danger_right_3, danger_left_3, danger_up_3, danger_down_3 = check_danger(head, 3, env)
danger_right_4, danger_left_4, danger_up_4, danger_down_4 = check_danger(head, 4, env)
state = [danger_right_1, danger_left_1, danger_up_1, danger_down_1,
danger_right_2, danger_left_2, danger_up_2, danger_down_2,
danger_right_3, danger_left_3, danger_up_3, danger_down_3,
danger_right_4, danger_left_4, danger_up_4, danger_down_4]
append!(state, Int[head.x > food.x, head.y > food.y,
                   tail.x > food.x, tail.y > food.y,
                   mid.x > food.x, mid.y > food.y,
                   head.x > tail.x, head.y > tail.y,
                   head.x > mid.x, head.y > mid.y,
                   mid.x > tail.x, mid.y > tail.y])
return Float32.(state)
end
function get_state_2(env::SnakeGameEnv)
head = env.snake[1]
mid = env.snake[length(env.snake) ÷ 2]
tail = env.snake[end]
food = env.food
danger_right_1, danger_left_1, danger_up_1, danger_down_1 = check_danger(head, 1, env)
danger_right_2, danger_left_2, danger_up_2, danger_down_2 = check_danger(head, 2, env)
danger_right_3, danger_left_3, danger_up_3, danger_down_3 = check_danger(head, 3, env)
danger_right_4, danger_left_4, danger_up_4, danger_down_4 = check_danger(head, 4, env)
state = [danger_right_1, danger_left_1, danger_up_1, danger_down_1,
danger_right_2, danger_left_2, danger_up_2, danger_down_2,
danger_right_3, danger_left_3, danger_up_3, danger_down_3,
danger_right_4, danger_left_4, danger_up_4, danger_down_4,
head.x / WIDTH, head.y / HEIGHT,
food.x / WIDTH, food.y / HEIGHT,
tail.x / WIDTH, tail.y / HEIGHT,
mid.x / WIDTH, mid.y / HEIGHT,
findfirst(==(env.direction), DIRECTIONS) / 4,
length(env.snake) / WIDTH]
append!(state, Int[head.x > food.x, head.y > food.y,
                   tail.x > food.x, tail.y > food.y,
                   mid.x > food.x, mid.y > food.y,
                   head.x > tail.x, head.y > tail.y,
                   head.x > mid.x, head.y > mid.y,
                   mid.x > tail.x, mid.y > tail.y])
Float32.(state)
end
get_state(env::SnakeGameEnv) = get_state_2(env)
function reset_env(env::SnakeGameEnv)
env.snake = [Rect(200, 200, BLOCK_SIZE, BLOCK_SIZE),
Rect(180, 200, BLOCK_SIZE, BLOCK_SIZE),
Rect(160, 200, BLOCK_SIZE, BLOCK_SIZE)]
env.food = place_food(env)
env.score = 0
env.done = false
env.reward = 0
env.step = 0
env.direction = "right"
return get_state(env)
end
function move(env::SnakeGameEnv)
head = env.snake[1]
if env.direction == "right"
insert!(env.snake, 1, Rect(head.x + BLOCK_SIZE, head.y, BLOCK_SIZE, BLOCK_SIZE))
elseif env.direction == "left"
insert!(env.snake, 1, Rect(head.x - BLOCK_SIZE, head.y, BLOCK_SIZE, BLOCK_SIZE))
elseif env.direction == "up"
insert!(env.snake, 1, Rect(head.x, head.y - BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))
else
insert!(env.snake, 1, Rect(head.x, head.y + BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))
end
head = env.snake[1]
if collide(head, env.food)
env.score += 1
env.reward = 10
env.food = place_food(env)
else
pop!(env.snake)
end
end
function (env::SnakeGameEnv)(action)
env.step += 1
env.reward = 0
env.done = false
if action == 1 && env.direction != "left"
env.direction = "right"
elseif action == 2 && env.direction != "right"
env.direction = "left"
elseif action == 3 && env.direction != "down"
env.direction = "up"
elseif action == 4 && env.direction != "up"
env.direction = "down"
end
# update snake
move(env)
check_game_over(env.snake[1], env, false)
end
# --------------------- MEMORY -----------------------
mutable struct Memory
states::Vector{Vector{Float32}}
actions::Vector{Int}
rewards::Vector{Float32}
next_states::Vector{Vector{Float32}}
probs::Vector{Float32}
values::Vector{Float32}
dones::Vector{Bool}
end
# stack a batch of state vectors into a (state_dim, batch) matrix, one state per column
vec2mat(input_vec::Vector{Vector{T}}) where T = reduce(hcat, input_vec)
function clear_memory(m::Memory)
    empty!(m.states)
    empty!(m.actions)
    empty!(m.rewards)
    empty!(m.next_states)
    empty!(m.probs)
    empty!(m.values)
    empty!(m.dones)
end
# ------------------------ model -----------------------------
function build_model(input_dim, output_dim)
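    # actor outputs a softmax policy over actions; critic outputs a scalar state value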
actor = Chain(Dense(input_dim, 128, tanh), Dense(128, output_dim), softmax)
critic = Chain(Dense(input_dim, 128, tanh), Dense(128, 1))
actor, critic
end
mutable struct Model
actor::Chain
critic::Chain
end
Flux.@functor Model
# -------------------- agent --------------------------------------
mutable struct PPO
policy::Model
γ::Float32
λ::Float32
clip::Float32
update_every::Int
update_step::Int
epochs::Int
opt
memory::Memory
end
function PPOAgent(state_dim, action_dim, lr, gamma, lmbda, clip, update_every, epochs)
update_step = 0
actor, critic = build_model(state_dim, action_dim)
policy = Model(actor, critic)
opt = ADAM(lr)
memory = Memory([], [], [], [], [], [], [])
PPO(
policy,
gamma,
lmbda,
clip,
update_every,
update_step,
epochs,
opt,
memory
)
end
function select_action(agent::PPO, state::Vector{Float32})
    probs = agent.policy.actor(state)  # action probabilities (softmax output), not raw logits
    dist = Categorical(probs)
    value = agent.policy.critic(state)[1]
    action = rand(dist)
    action, value, probs
end
function take_step(agent::PPO, state, action, reward, next_state, done, value, prob)
push!(agent.memory.states, state)
push!(agent.memory.actions, action)
push!(agent.memory.rewards, reward)
push!(agent.memory.next_states, next_state)
push!(agent.memory.dones, done)
push!(agent.memory.values, value)
push!(agent.memory.probs, prob)
agent.update_step += 1
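    # run a PPO update and flush the buffer every update_every environment steps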
if agent.update_step % agent.update_every == 0
train(agent)
clear_memory(agent.memory)
end
end
function train(agent::PPO)
states = vec2mat(agent.memory.states)
next_states = vec2mat(agent.memory.next_states)
values = agent.memory.values
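    # CartesianIndex pairs (action, batch column) so the taken actions' probabilities can be gathered from the (action_dim, batch) actor output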
actions = CartesianIndex.(agent.memory.actions, 1:length(agent.memory.actions))
for _ ∈ 1:agent.epochs
ps = params(agent.policy)
gs = gradient(ps) do
v = agent.policy.critic(states) |> vec
v_next = agent.policy.critic(next_states) |> vec
v_target = agent.memory.rewards .+ agent.γ .* (1 .- agent.memory.dones) .* v_next
delta = v_target .- values
gae = 0.0f0
T = length(agent.memory.rewards)
advantages = zeros(Float32, T)
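    # GAE backward recursion, run outside the gradient tape so the advantages act as constants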
ignore() do
for i ∈ T:-1:1
gae = delta[i] + agent.γ * agent.λ * gae * (1 - agent.memory.dones[i])
advantages[i] = gae
end
#advantages = (advantages .- mean(advantages)) ./ (std(advantages) .+ 1e-5)
end
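    # critic target: advantages plus the old value estimates, detached from the graph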
returns = advantages .+ values |> dropgrad
logits = agent.policy.actor(states)
log_probs = log.(logits[actions])
old_log_probs = log.(agent.memory.probs)
ratio = exp.(log_probs .- old_log_probs)
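    # PPO clipped surrogate: take the pessimistic minimum of the unclipped and clipped terms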
surr1 = ratio .* advantages
surr2 = clamp.(ratio, 1 - agent.clip, 1 + agent.clip) .* advantages
loss = -mean(min.(surr1, surr2)) + 0.5 * Flux.mse(returns, v)
loss
end
Flux.Optimise.update!(agent.opt, ps, gs)
end
end
# ------------------------------- train ----------------------------------
function train_loop_snake()
env = SnakeGameEnv()
state_dim = length(get_state(env))
action_dim = 4
lr = 0.0002
gamma = 0.99
lmbda = 0.95
clip = 0.1
update_every = 20
epochs = 10
agent = PPOAgent(state_dim, action_dim, lr, gamma, lmbda, clip, update_every, epochs)
total_reward = 0
for episode ∈ 1:100000
state = reset_env(env)
while !env.done
action, value, prob = select_action(agent, state)
env(action)
next_state = get_state(env)
total_reward += env.reward
take_step(agent, state, action, env.reward, next_state, env.done, value, prob[action])
state = next_state
if env.done
if env.score > env.record
env.record = env.score
actor_model = agent.policy.actor
@save "actor_model_snake.bson" actor_model
println("model improved, saved model!")
end
break
end
end
if episode % 20 == 0
avg_reward = total_reward / 20
total_reward = 0
@info "Episode : $episode | avg_reward : $avg_reward | record : $(env.record)"
end
end
agent.policy
end
function train_loop_gym()
env = GymEnv("CartPole-v1")
state_dim = env.observation_space.shape[1]
action_dim = env.action_space.n
lr = 0.0002
gamma = 0.99
lmbda = 0.95
clip = 0.1
update_every = 20
epochs = 10
agent = PPOAgent(state_dim, action_dim, lr, gamma, lmbda, clip, update_every, epochs)
total_reward = 0
for episode ∈ 1:10000
state = reset!(env)
for _ ∈ 1:1000
action, value, prob = select_action(agent, state)
next_state, reward, done, _ = step!(env, action - 1)
total_reward += reward
take_step(agent, state, action, reward, next_state, done, value, prob[action])
state = next_state
if done
break
end
end
if episode % 20 == 0
avg_reward = total_reward / 20
total_reward = 0
println("Episode : $episode | avg_reward : $avg_reward")
end
end
end
train_loop_gym()
In order to make the Gym package work, you need to modify some of its files.
Make sure you have the latest version of gym installed via pip.
Then install Gym in Julia using add Gym.
Then go to julia/packages (where all your Julia packages are installed), find Gym, go to src, and edit env.jl like this:
include("spaces.jl")
include("spec.jl")
mutable struct GymEnv
name
spec
action_space
observation_space
reward_range
gymenv
end
function GymEnv(id::String, continuous::Bool=false)
gymenv = nothing
try
try
gymenv = gym.make(id, continuous=continuous)
catch e
gymenv = gym.make(id)
end
catch e
error("Error received during the initialization of $id\n$e")
end
spec = Spec(gymenv.spec.id,
gymenv.spec.reward_threshold,
gymenv.spec.nondeterministic,
gymenv.spec.max_episode_steps,
)
action_space = julia_space(gymenv.action_space)
observation_space = julia_space(gymenv.observation_space)
env = GymEnv(id, spec, action_space,
observation_space, gymenv.reward_range, gymenv)
return env
end
reset!(env::GymEnv) = env.gymenv.reset()[1]
function render(env::GymEnv; mode="human")
env.gymenv.render(mode)
end
function step!(env::GymEnv, action)
    # the newer gym API returns (obs, reward, terminated, truncated, info);
    # fold truncation into done so time-limited episodes still terminate
    ob, reward, terminated, truncated, information = env.gymenv.step(action)
    return ob, reward, terminated || truncated, information
end
close!(env::GymEnv) = env.gymenv.close()
seed!(env::GymEnv, seed=nothing) = env.gymenv.seed(seed)
Then reload the Gym package (you can use build Gym).
To test it with LunarLander-v2, change the env name in the train_loop_gym() function.
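For example:
env = GymEnv("LunarLander-v2")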
Note that I can solve the LunarLander-v2 env with a different implementation of PPO that uses Monte Carlo return estimation instead of the generalized advantage estimation used in the current implementation. If you need the Monte Carlo code, I can provide it, of course.
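Roughly, that variant just replaces the GAE recursion with plain discounted returns. A minimal sketch of the idea (illustrative only; mc_returns is a name made up for this sketch, not the exact code):
function mc_returns(rewards::Vector{Float32}, dones::Vector{Bool}, γ::Float32)
    # discounted return G_t = r_t + γ * G_{t+1}, reset at episode boundaries
    T = length(rewards)
    returns = zeros(Float32, T)
    G = 0.0f0
    for t ∈ T:-1:1
        G = rewards[t] + γ * G * (1 - dones[t])
        returns[t] = G
    end
    returns
end
The advantage is then returns .- values, and returns serves directly as the critic target (no bootstrapping from the critic).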