Reinforcement learning packages for CartPole example with Julia v1.11 or v1.10?

I got it working. Here’s an animation of a trained agent balancing a pole on a cart after training for 8000 interaction steps. The arrow indicates whether the agent is pushing the cart to the left or to the right.
[animation: trained agent balancing the pole on the cart]

I used ReinforcementLearning.jl, which provides the CartPole environment. However, most of the NN solvers in that package currently seem to be broken due to API changes in its dependencies, so I used the solver from DeepQLearning.jl instead, after converting the CartPole environment to the POMDPs.jl interface it expects. My full code is below in case anyone wants to try it. First, install the dependencies in a fresh project environment with Julia 1.11:

]add DeepQLearning, Flux, QuickPOMDPs, POMDPTools, Plots, ReinforcementLearningEnvironments

Then run

import DeepQLearning: DeepQLearningSolver, solve
import Flux: Chain, Dense, relu, gelu
import QuickPOMDPs: QuickPOMDP
import POMDPTools: Deterministic, ImplicitDistribution, EpsGreedyPolicy, LinearDecaySchedule
import Plots: plot
import ReinforcementLearningEnvironments: CartPoleEnv, state, act!, reset!, is_terminated, reward

# Convert ReinforcementLearningEnvironments.CartPoleEnv() to the POMDPs.jl interface
cartpole_mdp = QuickPOMDP(
    actions = [1, 2],
    discount = 0.99,
    gen = function (s, a, rng)
        sp = deepcopy(s)   # copy the env so the input state is left untouched
        act!(sp, a)        # step the copied environment with the chosen action
        o = state(sp)      # observation = the 4-element CartPole state vector
        r = reward(sp)     # per-step reward from the environment
        (;sp, o, r)        # next state, observation, reward
    end,
    initialstate = ImplicitDistribution((rng) -> CartPoleEnv()),
    isterminal = is_terminated,
    initialobs = s -> Deterministic(state(s))
)
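
# Optional sanity check of the environment functions used in `gen` above
# (a small sketch; the state should be a 4-element vector, actions are 1 and 2):
let env = CartPoleEnv()
    reset!(env)
    act!(env, 1)                                     # apply the first action once
    @show state(env) reward(env) is_terminated(env)  # observation, per-step reward, done flag
end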

exploration = EpsGreedyPolicy(cartpole_mdp, LinearDecaySchedule(start=1.0, stop=0.05, steps=8000));
# workaround: the semicolon above suppresses buggy printing of the exploration policy
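
# Purely illustrative: the decay schedule is callable on the training step
# (EpsGreedyPolicy uses it that way), so you can peek at how exploration decays:
decay = LinearDecaySchedule(start=1.0, stop=0.05, steps=8000)
@show decay(1) decay(4000) decay(8000)  # roughly 1.0, 0.5, 0.05
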
model = Chain(Dense(4, 128, gelu), Dense(128, 128, gelu), Dense(128, 64, gelu), Dense(64, 2))
solver = DeepQLearningSolver(qnetwork = model, max_steps=8000, 
                             exploration_policy = exploration, learning_rate=0.001,
                             log_freq=1000, eval_freq=1000, batch_size=128,
                             max_episode_length=200, train_start=1000, double_q=true,
                             dueling=true, prioritized_replay=true, buffer_size=5000)
@time policy = solve(solver, cartpole_mdp) # about 22 seconds excluding compilation

# Now test the performance
# Pick the greedy action: evaluate the trained Q-network on the current state
function action_choice(policy, env)
    action_vals = policy.qnetwork(convert(Vector{Float32}, state(env)))
    argmax(action_vals)   # returns 1 or 2, matching the `actions` defined above
end

max_sim_steps = 200
e = CartPoleEnv()
total_reward = 0f0
for i in 1:max_sim_steps+1
    a = action_choice(policy, e)
    act!(e, a)
    display(plot(e))  # explicitly display so each frame shows up inside the loop
    global total_reward += reward(e)
    if is_terminated(e)
        break
    end
    sleep(0.05) # animation time step
end
@show total_reward # Reaches the perfect score of 200, though this depends
                   # on the random initialization of the neural network
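
# To record a GIF like the one at the top, one option is Plots' Animation API,
# sketched here; it reuses `action_choice` and the same `plot` recipe as above:
import Plots: Animation, frame, gif
anim_env = CartPoleEnv()
anim = Animation()
for i in 1:max_sim_steps
    act!(anim_env, action_choice(policy, anim_env))
    frame(anim, plot(anim_env))  # append the current frame to the animation
    is_terminated(anim_env) && break
end
gif(anim, "cartpole.gif", fps=20)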
