I’ve been looking at reinforcement learning packages in Julia. Unfortunately, some of them have fallen out of maintenance and have broken dependencies. Does anyone have a fully working CartPole training example (a well-known pedagogical problem in reinforcement learning) on Julia v1.10 or v1.11? I found the package DeepQLearning.jl useful, though I struggled to find a working implementation of the CartPole environment with the required POMDPs.jl interface. (It wouldn’t be much work to re-implement it from scratch, but I’d like to skip that if possible.)
Have you seen GitHub - JuliaReinforcementLearning/ReinforcementLearning.jl (a reinforcement learning package for Julia)?
I got it working. Here’s an animation of a trained agent balancing a pole on a cart after training on 8000 steps of interaction. The arrow indicates whether the agent is pushing the cart to the left or to the right.
I used ReinforcementLearning.jl, which provides the CartPole environment. However, most of the NN solvers in that package seem to be broken due to API changes in their dependencies, so I used the solver from DeepQLearning.jl after converting the CartPole environment to its interface. My full code is below in case anyone wants to try it. First, install the dependencies in a fresh project environment with Julia 1.11:
]add DeepQLearning, Flux, QuickPOMDPs, POMDPTools, Plots, ReinforcementLearningEnvironments
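Equivalently, you can set up the fresh environment from the Pkg API in a script (untested sketch; "cartpole-dqn" is just a placeholder name for the project directory):

import Pkg
Pkg.activate("cartpole-dqn")   # create/activate a fresh project environment
Pkg.add(["DeepQLearning", "Flux", "QuickPOMDPs", "POMDPTools", "Plots",
         "ReinforcementLearningEnvironments"])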
Then run
import DeepQLearning: DeepQLearningSolver, solve
import Flux: Chain, Dense, relu, gelu
import QuickPOMDPs: QuickPOMDP
import POMDPTools: Deterministic, ImplicitDistribution, EpsGreedyPolicy, LinearDecaySchedule
import Plots: plot
import ReinforcementLearningEnvironments: CartPoleEnv, state, act!, reset!, is_terminated, reward
# Convert ReinforcementLearningEnvironments.CartPoleEnv() to the POMDPs.jl interface.
# The "state" of the wrapped POMDP is an entire CartPoleEnv object; the observation is
# its 4-element state vector (cart position, cart velocity, pole angle, pole angular velocity).
cartpole_mdp = QuickPOMDP(
    actions = [1, 2],              # the two discrete push actions
    discount = 0.99,
    gen = function (s, a, rng)
        sp = deepcopy(s)           # copy the environment so the previous state is not mutated
        act!(sp, a)
        o = state(sp)
        r = reward(sp)
        (;sp, o, r)
    end,
    initialstate = ImplicitDistribution((rng) -> CartPoleEnv()),
    isterminal = is_terminated,
    initialobs = s -> Deterministic(state(s))
)
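# Optional sanity check of the conversion, my addition, using the generic POMDPs.jl
# generative interface (this needs the POMDPs package itself on top of the list above:
# ]add POMDPs). It samples an initial state and advances it one step with action 1.
import POMDPs, Random
s0 = rand(Random.default_rng(), POMDPs.initialstate(cartpole_mdp))
sp, o, r = POMDPs.@gen(:sp, :o, :r)(cartpole_mdp, s0, 1, Random.default_rng())
# sp is the successor CartPoleEnv, o its 4-element observation vector, r the one-step reward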
# epsilon-greedy exploration, with epsilon decaying linearly from 1.0 to 0.05 over the 8000 training steps
exploration = EpsGreedyPolicy(cartpole_mdp, LinearDecaySchedule(start=1.0, stop=0.05, steps=8000));
# workaround: the semicolon above suppresses the buggy printing of this object
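# (If I remember the POMDPTools API correctly, the decay schedule is callable with the
#  step index, e.g. LinearDecaySchedule(start=1.0, stop=0.05, steps=8000)(4000) ≈ 0.525,
#  i.e. epsilon falls linearly from 1.0 to 0.05 over the 8000 training steps.)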
model = Chain(Dense(4, 128, gelu), Dense(128, 128, gelu), Dense(128, 64, gelu), Dense(64, 2))
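# Quick shape check (my addition): the Q-network maps a 4-element CartPole observation
# to one Q-value per action.
@assert length(model(rand(Float32, 4))) == 2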
solver = DeepQLearningSolver(qnetwork = model, max_steps=8000,
                             exploration_policy = exploration, learning_rate=0.001,
                             log_freq=1000, eval_freq=1000, batch_size=128,
                             max_episode_length=200, train_start=1000, double_q=true,
                             dueling=true, prioritized_replay=true, buffer_size=5000)
@time policy = solve(solver, cartpole_mdp) # about 22 seconds excluding compilation
# Now test the performance
# Greedy action: pick the argmax of the Q-values the network predicts for the current observation
function action_choice(policy, env)
    action_vals = policy.qnetwork(convert(Vector{Float32}, state(env)))
    argmax(action_vals)
end
max_sim_steps = 200
e = CartPoleEnv()
total_reward = 0f0
for i in 1:max_sim_steps+1
    a = action_choice(policy, e)
    act!(e, a)
    display(plot(e))    # draw the current cart/pole configuration
    global total_reward += reward(e)
    if is_terminated(e)
        break
    end
    sleep(0.05) # animation time step
end
@show total_reward # Reaches the perfect score 200, though this depends
# on the random initialization of the neural network
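If you would rather save the rollout as a GIF than watch it live, something along these lines should work (untested sketch; "cartpole.gif" is just a placeholder filename, and it reuses the Plots recipe that ReinforcementLearningEnvironments provides for CartPoleEnv):

import Plots: @animate, gif
e2 = CartPoleEnv()
anim = @animate for i in 1:max_sim_steps
    is_terminated(e2) || act!(e2, action_choice(policy, e2))
    plot(e2)  # each frame draws the current cart/pole configuration
end
gif(anim, "cartpole.gif", fps = 20)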