Reinforcement learning packages for CartPole example with Julia v1.11 or v1.10?

I’ve been looking at reinforcement learning packages in Julia. Unfortunately, some of them appear to be unmaintained and have broken dependencies. Does anyone have a fully working training example for CartPole (a well-known pedagogical problem in reinforcement learning) with Julia v1.10 or v1.11? I found the package DeepQLearning.jl useful, though I struggled to find a working implementation of the CartPole environment with the required POMDPs.jl interface. (It wouldn’t be much work to re-implement it from scratch, roughly along the lines of the sketch below, but I’d like to skip the work if possible.)
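
For reference, here is the kind of thing I mean by “from scratch”: a rough, untested sketch of the standard Barto, Sutton & Anderson cart-pole update, using the usual constants from the Gym implementation (all names here are my own):

mutable struct MyCartPole
    x::Float64
    xdot::Float64
    theta::Float64
    thetadot::Float64
end
# start near the upright equilibrium, as the usual implementations do
MyCartPole() = MyCartPole((rand(4) .* 0.1 .- 0.05)...)

function step!(env::MyCartPole, a::Int)  # a in {1, 2}: push left / push right
    g, mc, mp, l, F, tau = 9.8, 1.0, 0.1, 0.5, 10.0, 0.02  # l is the pole half-length
    force = a == 2 ? F : -F
    s, c = sincos(env.theta)
    tmp = (force + mp * l * env.thetadot^2 * s) / (mc + mp)
    thetaacc = (g * s - c * tmp) / (l * (4/3 - mp * c^2 / (mc + mp)))
    xacc = tmp - mp * l * thetaacc * c / (mc + mp)
    env.x     += tau * env.xdot;     env.xdot     += tau * xacc      # Euler integration
    env.theta += tau * env.thetadot; env.thetadot += tau * thetaacc
    return abs(env.x) > 2.4 || abs(env.theta) > 0.2095  # terminated? (about 12 degrees)
end

Reward (+1 per surviving step) and the 200-step episode cap are then only a few more lines, but I’d rather reuse an existing, tested environment if one works.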

Have you seen ReinforcementLearning.jl (https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl), a reinforcement learning package for Julia?

I got it working. Here’s an animation of a trained agent balancing a pole on a cart after training for 8000 steps of interaction. The arrow indicates whether the agent is pushing the cart to the left or to the right.
[animation of the trained agent balancing the pole]

I used ReinforcementLearning.jl, which provides the CartPole environment. However, most of the neural-network solvers in that package seem to be broken due to API changes in their dependencies, so I used the solver from DeepQLearning.jl instead, after converting the CartPole environment to its POMDPs.jl interface. My full code is below in case anyone wants to try it. First, install the dependencies in a fresh project environment with Julia 1.11:

]add DeepQLearning, Flux, QuickPOMDPs, POMDPTools, Plots, ReinforcementLearningEnvironments

Then run

import DeepQLearning: DeepQLearningSolver, solve
import Flux: Chain, Dense, relu, gelu
import QuickPOMDPs: QuickPOMDP
import POMDPTools: Deterministic, ImplicitDistribution, EpsGreedyPolicy, LinearDecaySchedule
import Plots: plot
import ReinforcementLearningEnvironments: CartPoleEnv, state, act!, reset!, is_terminated, reward

# Wrap ReinforcementLearningEnvironments.CartPoleEnv in the POMDPs.jl interface.
# The POMDP "state" is the whole CartPoleEnv object; the observation is its
# 4-element state vector (cart position/velocity, pole angle/velocity).
cartpole_mdp = QuickPOMDP(
    actions = [1, 2],      # push the cart left / right
    discount = 0.99,
    gen = function (s, a, rng)
        sp = deepcopy(s)   # copy so the incoming state is not mutated
        act!(sp, a)        # advance the wrapped environment by one step
        o = state(sp)      # observation: the environment's state vector
        r = reward(sp)     # +1 per step while the pole stays balanced
        (;sp, o, r)
    end,
    initialstate = ImplicitDistribution((rng) -> CartPoleEnv()),  # fresh env each episode
    isterminal = is_terminated,
    initialobs = s -> Deterministic(state(s))
)
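# Optional sanity check of the building blocks the wrapper relies on (not needed
# for training): stepping the raw environment directly should give a 4-element
# state vector, a positive reward, and (eventually) is_terminated == true.
sanity_env = CartPoleEnv()
act!(sanity_env, 1)
@show state(sanity_env) reward(sanity_env) is_terminated(sanity_env)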

exploration = EpsGreedyPolicy(cartpole_mdp, LinearDecaySchedule(start=1.0, stop=0.05, steps=8000));
# workaround: semi-colon above suppresses buggy printing
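# (Optional, and from memory so take it as a sketch: LinearDecaySchedule should be
# callable, so you can check the epsilon used at a given training step, e.g.
# LinearDecaySchedule(start=1.0, stop=0.05, steps=8000)(4000)  # about 0.525, halfway through the decay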
model = Chain(Dense(4, 128, gelu), Dense(128, 128, gelu), Dense(128, 64, gelu), Dense(64, 2))
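# Quick sanity check (optional): the network maps the 4-dimensional observation to
# one Q-value per action, so this should print a 2-element Vector{Float32}.
@show model(rand(Float32, 4))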
solver = DeepQLearningSolver(qnetwork = model, max_steps=8000,          # total training steps
                             exploration_policy = exploration, learning_rate=0.001,
                             log_freq=1000, eval_freq=1000, batch_size=128,
                             max_episode_length=200, train_start=1000,  # warm up the replay buffer before training
                             double_q=true, dueling=true,               # double DQN + dueling architecture
                             prioritized_replay=true, buffer_size=5000) # prioritized replay, 5000-transition buffer
@time policy = solve(solver, cartpole_mdp) # about 22 seconds excluding compilation

# Now test the performance
# Greedy action: evaluate the trained Q-network on the current observation and
# take the index of the largest Q-value (1 or 2, matching the actions above).
function action_choice(policy, env)
    action_vals = policy.qnetwork(convert(Vector{Float32}, state(env)))
    argmax(action_vals)
end

max_sim_steps = 200
e = CartPoleEnv()
total_reward = 0f0
for i in 1:max_sim_steps+1
    a = action_choice(policy, e)
    act!(e, a)
    display(plot(e))  # display() is needed for the frame to actually show inside a loop
    global total_reward += reward(e)
    if is_terminated(e)
        break
    end
    sleep(0.05) # animation time step
end
@show total_reward # Reaches the perfect score 200, though this depends
                   # on the random initialization of the neural network
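
# If you want a less noisy estimate than a single rollout (since the result depends
# on initialization), something like this rough sketch averages the greedy policy's
# return over several fresh episodes:
function evaluate(policy; n_episodes=20, max_steps=200)
    total = 0.0
    for _ in 1:n_episodes
        env = CartPoleEnv()
        for _ in 1:max_steps
            act!(env, action_choice(policy, env))
            total += reward(env)
            is_terminated(env) && break
        end
    end
    return total / n_episodes
end
@show evaluate(policy)   # mean episode return over 20 fresh episodes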
