Help wanted with PPO implementation in Flux.jl

My implementation works well on other environments (CartPole, Acrobot, and a custom Snake env), but it doesn't solve LunarLander-v2: it learns, but the average reward never reaches +200.

using Flux
using Flux: huber_loss, ADAM, gradient, params
using Zygote
using Distributions
using Gym
using Statistics: mean, std

using GameZero
using Colors

using BSON: @save, @load
# ------------------------ GAME STUFF -------------------------
HEIGHT = 320
WIDTH = 240
BACKGROUND = colorant"white"
BLOCK_SIZE = 20
DIRECTIONS = ["right", "left", "up", "down"]


# ----------------------- ENV ------------------------------
Base.@kwdef mutable struct SnakeGameEnv
    snake::Vector{Rect} = [Rect(200, 200, BLOCK_SIZE, BLOCK_SIZE), 
                           Rect(180, 200, BLOCK_SIZE, BLOCK_SIZE),
                           Rect(160, 200, BLOCK_SIZE, BLOCK_SIZE)]
    food::Rect = Rect(rand(1:WIDTH), rand(1:HEIGHT), BLOCK_SIZE, BLOCK_SIZE)
    record::Int = 0
    score::Int = 0
    done::Bool = false
    reward::Float32 = 0.0
    step::Int = 0
    direction::String = "right"
end

function place_food(env::SnakeGameEnv)
    x = rand(1:WIDTH - BLOCK_SIZE)
    y = rand(1:HEIGHT - BLOCK_SIZE)
    food = Rect(x, y, BLOCK_SIZE, BLOCK_SIZE)
    # retry until the new food does not overlap the snake
    if any(map(s -> collide(s, food), env.snake))
        return place_food(env)
    else
        return food
    end
end


function check_game_over(head::Rect, env::SnakeGameEnv, check_danger::Bool)
    if head.x > WIDTH - BLOCK_SIZE || head.x < 0 || head.y < 0 || head.y > HEIGHT - BLOCK_SIZE
        if !check_danger
            env.reward = -10
            env.done = true
        end
        return true
    end

    if any(map(x -> collide(x, head), env.snake[3:end]))
        if !check_danger
            env.reward = -100
            env.done = true
        end
        return true
    end

    if !check_danger && env.step > 100 * length(env.snake)
        env.reward = -10
        env.done = true
    end

    return false

end

function check_danger(head, size::Int, env)
    x = head.x + size * BLOCK_SIZE
    y = head.y + size * BLOCK_SIZE
    new_head = Rect(x, y, BLOCK_SIZE, BLOCK_SIZE)
    danger_right = env.direction == "right" && check_game_over(new_head, env, true) ? 1 : 0
    danger_left = env.direction == "left" && check_game_over(new_head, env, true) ? 1 : 0
    danger_up = env.direction == "up" && check_game_over(new_head, env, true) ? 1 : 0
    danger_down = env.direction == "down" && check_game_over(new_head, env, true) ? 1 : 0
    danger_right, danger_left, danger_up, danger_down

end

function get_state_1(env::SnakeGameEnv)
    head = env.snake[1]
    mid = env.snake[Int(floor(length(env.snake) / 2))]
    tail = env.snake[end]
    food = env.food

    danger_right_1, danger_left_1, danger_up_1, danger_down_1 = check_danger(head, 1, env)
    danger_right_2, danger_left_2, danger_up_2, danger_down_2 = check_danger(head, 2, env)
    danger_right_3, danger_left_3, danger_up_3, danger_down_3 = check_danger(head, 3, env)
    danger_right_4, danger_left_4, danger_up_4, danger_down_4 = check_danger(head, 4, env)

    state = [danger_right_1, danger_left_1, danger_up_1, danger_down_1,
             danger_right_2, danger_left_2, danger_up_2, danger_down_2,
             danger_right_3, danger_left_3, danger_up_3, danger_down_3,
             danger_right_4, danger_left_4, danger_up_4, danger_down_4]
    
    head.x > food.x ? push!(state, 1) : push!(state, 0)
    head.y > food.y ? push!(state, 1) : push!(state, 0)
    tail.x > food.x ? push!(state, 1) : push!(state, 0)
    tail.y > food.y ? push!(state, 1) : push!(state, 0)
    mid.x > food.x  ? push!(state, 1) : push!(state, 0)
    mid.y > food.y  ? push!(state, 1) : push!(state, 0)
    head.x > tail.x ? push!(state, 1) : push!(state, 0)
    head.y > tail.y ? push!(state, 1) : push!(state, 0)
    head.x > mid.x  ? push!(state, 1) : push!(state, 0)
    head.y > mid.y  ? push!(state, 1) : push!(state, 0)
    mid.x > tail.x  ? push!(state, 1) : push!(state, 0)
    mid.y > tail.y  ? push!(state, 1) : push!(state, 0)

    return Float32.(state)
              
end

function get_state_2(env::SnakeGameEnv)
    head = env.snake[1]
    mid = env.snake[Int(floor(length(env.snake) / 2))]
    tail = env.snake[end]
    food = env.food

    danger_right_1, danger_left_1, danger_up_1, danger_down_1 = check_danger(head, 1, env)
    danger_right_2, danger_left_2, danger_up_2, danger_down_2 = check_danger(head, 2, env)
    danger_right_3, danger_left_3, danger_up_3, danger_down_3 = check_danger(head, 3, env)
    danger_right_4, danger_left_4, danger_up_4, danger_down_4 = check_danger(head, 4, env)
    state = [danger_right_1, danger_left_1, danger_up_1, danger_down_1,
             danger_right_2, danger_left_2, danger_up_2, danger_down_2,
             danger_right_3, danger_left_3, danger_up_3, danger_down_3,
             danger_right_4, danger_left_4, danger_up_4, danger_down_4,
             head.x / WIDTH, head.y / HEIGHT,
             food.x / WIDTH, food.y / HEIGHT,
             tail.x / WIDTH, tail.y / HEIGHT,
             mid.x / WIDTH, mid.y / HEIGHT,
             findfirst(x -> env.direction == x, DIRECTIONS)[1] / 4,
             length(env.snake) / WIDTH]

    
    head.x > food.x ? push!(state, 1) : push!(state, 0)
    head.y > food.y ? push!(state, 1) : push!(state, 0)
    tail.x > food.x ? push!(state, 1) : push!(state, 0)
    tail.y > food.y ? push!(state, 1) : push!(state, 0)
    mid.x > food.x  ? push!(state, 1) : push!(state, 0)
    mid.y > food.y  ? push!(state, 1) : push!(state, 0)
    head.x > tail.x ? push!(state, 1) : push!(state, 0)
    head.y > tail.y ? push!(state, 1) : push!(state, 0)
    head.x > mid.x  ? push!(state, 1) : push!(state, 0)
    head.y > mid.y  ? push!(state, 1) : push!(state, 0)
    mid.x > tail.x  ? push!(state, 1) : push!(state, 0)
    mid.y > tail.y  ? push!(state, 1) : push!(state, 0)
    
    Float32.(state)
end

get_state(env::SnakeGameEnv) = get_state_2(env)

function reset_env(env::SnakeGameEnv)
    env.snake = [Rect(200, 200, BLOCK_SIZE, BLOCK_SIZE), 
                 Rect(180, 200, BLOCK_SIZE, BLOCK_SIZE),
                 Rect(160, 200, BLOCK_SIZE, BLOCK_SIZE)]
    
    env.food = place_food(env)
    env.score = 0
    env.done = false
    env.reward = 0
    env.step = 0 
    env.direction = "right"
    return get_state(env)
        
end



function move(env::SnakeGameEnv)
   
    head = env.snake[1]
    if env.direction == "right"
        insert!(env.snake, 1, Rect(head.x + BLOCK_SIZE, head.y, BLOCK_SIZE, BLOCK_SIZE))
    
    elseif env.direction == "left"
        insert!(env.snake, 1, Rect(head.x - BLOCK_SIZE, head.y, BLOCK_SIZE, BLOCK_SIZE))
    
    elseif env.direction == "up"
        insert!(env.snake, 1, Rect(head.x, head.y - BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))
    
    else
        insert!(env.snake, 1, Rect(head.x, head.y + BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE))
    end

    head = env.snake[1]
    if collide(head, env.food)
        env.score += 1
        env.reward = 10
        env.food = place_food(env)
    else
        pop!(env.snake)
    end
end

function (env::SnakeGameEnv)(action)
    env.step += 1
    
    env.reward = 0
    env.done = false
    if action == 1 && env.direction != "left"
        env.direction = "right"
    
    elseif action == 2 && env.direction != "right"
        env.direction = "left"
    
    elseif action == 3 && env.direction != "down"
        env.direction = "up"
    
    elseif action == 4 && env.direction != "up"
        env.direction = "down"
    end

    # update snake
    move(env)
    check_game_over(env.snake[1], env, false)
    
    
end

# --------------------- MEMORY -----------------------
mutable struct Memory
    states::Vector{Vector{Float32}}
    actions::Vector{Int}
    rewards::Vector{Float32}
    next_states::Vector{Vector{Float32}}
    probs::Vector{Float32}
    values::Vector{Float32}
    dones::Vector{Bool}
end

function vec2mat(input_vec::Vector{Vector{T}}) where T
    nrows = length(input_vec)
    ncols = length(input_vec[1])
    mat   = zeros(T, ncols, nrows)
    for i ∈ 1:nrows
        mat[:, i] = input_vec[i]
    end
    mat
end

function clear_memory(m::Memory)
    m.states      = []
    m.actions     = []
    m.rewards     = []
    m.next_states = []
    m.probs       = []
    m.values      = []
    m.dones       = []
end

# ------------------------ model -----------------------------
function build_model(input_dim, output_dim)
    actor  = Chain(Dense(input_dim, 128, tanh), Dense(128, output_dim), softmax)
    critic = Chain(Dense(input_dim, 128, tanh), Dense(128, 1))
    actor, critic
end

mutable struct Model
    actor::Chain
    critic::Chain
end

Flux.@functor Model

# -------------------- agent --------------------------------------
mutable struct PPO
    policy::Model
    γ::Float32
    λ::Float32
    clip::Float32
    update_every::Int
    update_step::Int
    epochs::Int
    opt
    memory::Memory
end

function PPOAgent(state_dim, action_dim, lr, gamma, lmbda, clip, update_every, epochs)
    update_step   = 0
    actor, critic = build_model(state_dim, action_dim)
    policy        = Model(actor, critic)
    opt           = ADAM(lr)
    memory        = Memory([], [], [], [], [], [], [])

    PPO(
        policy,
        gamma,
        lmbda,
        clip,
        update_every,
        update_step,
        epochs,
        opt,
        memory
    )
end

function select_action(agent::PPO, state::Vector{Float32})
    probs  = agent.policy.actor(state)      # action probabilities (the actor ends in softmax)
    dist   = Categorical(probs)
    value  = agent.policy.critic(state)[1]
    action = rand(dist)
    action, value, probs
end

function take_step(agent::PPO, state, action, reward, next_state, done, value, prob)
    push!(agent.memory.states, state)
    push!(agent.memory.actions, action)
    push!(agent.memory.rewards, reward)
    push!(agent.memory.next_states, next_state)
    push!(agent.memory.dones, done)
    push!(agent.memory.values, value)
    push!(agent.memory.probs, prob)
    agent.update_step += 1
    if agent.update_step % agent.update_every == 0
        train(agent)
        clear_memory(agent.memory)
    end
end

function train(agent::PPO)
    states      = vec2mat(agent.memory.states)
    next_states = vec2mat(agent.memory.next_states)
    values      = agent.memory.values
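    # CartesianIndex pairs (action, step) so that indexing the actor output with actions picks the probability of the taken action at each step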
    actions     = CartesianIndex.(agent.memory.actions, 1:length(agent.memory.actions))
    for _ ∈ 1:agent.epochs
        ps = params(agent.policy)
        gs = gradient(ps) do
            v          = agent.policy.critic(states) |> vec 
            v_next     = agent.policy.critic(next_states) |> vec
            v_target   = agent.memory.rewards .+ agent.γ .* (1 .- agent.memory.dones) .* v_next
            delta      = v_target .- values
            gae        = 0
            T          = length(agent.memory.rewards)
            advantages = zeros(Float32, T)
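            # generalized advantage estimation, computed backwards over the rollout;
            # ignore() keeps this loop out of the gradient so the advantages act as constants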
            ignore() do
                for i ∈ T:-1:1
                    gae           = delta[i] + agent.γ * agent.λ * gae * (1 - agent.memory.dones[i])
                    advantages[i] = gae
                end
                #advantages    = (advantages .- mean(advantages)) ./ (std(advantages) .+ 1e-5)
            end
            
            returns       = advantages .+ values |> dropgrad
            probs         = agent.policy.actor(states)   # the actor ends in softmax, so these are probabilities
            log_probs     = log.(probs[actions])
            old_log_probs = log.(agent.memory.probs)
            ratio         = exp.(log_probs .- old_log_probs)
            surr1         = ratio .* advantages
            surr2         = clamp.(ratio, 1 - agent.clip, 1 + agent.clip) .* advantages
            loss          = -mean(min.(surr1, surr2)) + 0.5 * Flux.mse(returns, v)
            loss
        end
        Flux.Optimise.update!(agent.opt, ps, gs)
    end
end

# ------------------------------- train ----------------------------------

function train_loop_snake()
    env          = SnakeGameEnv()
    state_dim    = length(get_state(env))
    action_dim   = 4
    lr           = 0.0002
    gamma        = 0.99
    lmbda        = 0.95
    clip         = 0.1
    update_every = 20
    epochs       = 10
    agent        = PPOAgent(state_dim, action_dim, lr, gamma, lmbda, clip, update_every, epochs)
    total_reward = 0
    for episode ∈ 1:100000
        state = reset_env(env)
        
        while !env.done
           
            action, value, prob = select_action(agent, state)
            env(action)
            next_state = get_state(env)
            total_reward += env.reward
            
            take_step(agent, state, action, env.reward, next_state, env.done, value, prob[action])
            state = next_state
            if env.done
                if env.score > env.record
                    env.record = env.score
                    actor_model = agent.policy.actor
                    @save "actor_model_snake.bson" actor_model
                    println("model improved, saved model!")
                end
                break
            end
            

        end
        if episode % 20 == 0
            avg_reward = total_reward / 20
            total_reward = 0
            @info "Episode : $episode | avg_reward : $avg_reward | record : $(env.record)"
        end
    end
    agent.policy
end

function train_loop_gym()
    env          = GymEnv("CartPole-v1")
    state_dim    = env.observation_space.shape[1]
    action_dim   = env.action_space.n
    lr           = 0.0002
    gamma        = 0.99
    lmbda        = 0.95
    clip         = 0.1
    update_every = 20
    epochs       = 10
    agent        = PPOAgent(state_dim, action_dim, lr, gamma, lmbda, clip, update_every, epochs)
    total_reward = 0
    for episode ∈ 1:10000
        state = reset!(env)
        for _ ∈ 1:1000
            action, value, prob = select_action(agent, state)
            next_state, reward, done, _ = step!(env, action - 1)
            total_reward += reward
            take_step(agent, state, action, reward, next_state, done, value, prob[action])
            state = next_state
            if done
                break
            end
        end

        if episode % 20 == 0
            avg_reward   = total_reward / 20
            total_reward = 0
            println("Episode : $episode | avg_reward : $avg_reward")
        end
    end
     

end

train_loop_gym()

In order to make the Gym package work, you need to modify some of its files:
make sure you have the latest version of gym installed via pip,
then install Gym in Julia using add Gym,
then go to julia/packages (where all your Julia packages are installed), find Gym, go to src and edit env.jl like this:

include("spaces.jl")
include("spec.jl")

mutable struct GymEnv
    name
    spec
    action_space
    observation_space
    reward_range
    gymenv
end

function GymEnv(id::String, continuous::Bool=false)
    gymenv = nothing
    try
        
        try 
            gymenv = gym.make(id, continuous=continuous)
        catch e
            
            gymenv = gym.make(id)
        end
    catch e
        error("Error received during the initialization of $id\n$e")
    end
    
    spec = Spec(gymenv.spec.id,
                gymenv.spec.reward_threshold,
                gymenv.spec.nondeterministic,
                gymenv.spec.max_episode_steps,
               )
    action_space = julia_space(gymenv.action_space)
    observation_space = julia_space(gymenv.observation_space)

    env = GymEnv(id, spec, action_space,
                 observation_space, gymenv.reward_range, gymenv)
    return env
end

reset!(env::GymEnv) = env.gymenv.reset()[1]
function render(env::GymEnv; mode="human")
    env.gymenv.render(mode)
end

function step!(env::GymEnv, action)
    # the new gym API returns (obs, reward, terminated, truncated, info);
    # fold terminated and truncated into a single done flag
    ob, reward, terminated, truncated, information = env.gymenv.step(action)
    return ob, reward, terminated || truncated, information
end

close!(env::GymEnv) = env.gymenv.close()

seed!(env::GymEnv, seed=nothing) = env.gymenv.seed(seed)

Then reload the Gym package (you can use build Gym).
To test it with LunarLander-v2, change the env name in the train_loop_gym() function.
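
For example, the only line that needs to change is the one that builds the environment (if it isn't installed yet, LunarLander-v2 also needs the Box2D extra on the python side, pip install gym[box2d]):

env = GymEnv("LunarLander-v2")   # instead of GymEnv("CartPole-v1") in train_loop_gym()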

Note that I can solve LunarLander-v2 with a different PPO implementation that uses Monte Carlo return estimation instead of the generalized advantage estimation used in the current implementation. If you need the Monte Carlo code, I can of course provide it.
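
Roughly, by Monte Carlo estimation I mean computing plain discounted returns at the end of each rollout and using return minus value as the advantage. A minimal sketch of the return computation (not the exact code from my other implementation, just the idea):

# discounted Monte Carlo returns over one rollout buffer; γ is the discount factor
function mc_returns(rewards::Vector{Float32}, dones::Vector{Bool}, γ::Float32)
    T = length(rewards)
    returns = zeros(Float32, T)
    running = 0.0f0
    for t ∈ T:-1:1
        running    = rewards[t] + γ * running * (1 - dones[t])   # reset the running sum at episode ends
        returns[t] = running
    end
    returns
end

# the advantages are then simply the returns minus the stored critic values:
# advantages = mc_returns(agent.memory.rewards, agent.memory.dones, agent.γ) .- agent.memory.values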

Hello and welcome to the community!

You should probably spell out what PPO means; it's not an acronym that most people in this forum can be expected to be familiar with.

I don’t have the time right now to look at your code in detail, sorry, but maybe it helps to have a look at another PPO implementation in Julia. For loading gym (or other) environments, the ReinforcementLearning.jl ecosystem could also be useful.
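
From memory (so treat the exact function names as assumptions rather than a verified snippet), loading a gym environment through that ecosystem looks roughly like this:

using ReinforcementLearning, PyCall     # GymEnv wraps python gym through PyCall
env = GymEnv("CartPole-v1")
reset!(env)
s = state(env)                          # current observation
env(rand(action_space(env)))            # step the environment with a random action
reward(env), is_terminated(env)         # reward of the last step and the done flag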

I'm already familiar with that package :)
