PythonCall: Can't get attribute 'CallbackValue'

Hello,

I have a simulator written in Julia with Agents.jl, and I’m trying to do some reinforcement learning with Ray, using PythonCall for the interop. (I know there are reinforcement learning packages in Julia, such as ReinforcementLearning.jl, but I went down that path and it didn’t work for me.)

Specifically, I’m writing a custom environment around my simulator that will be passed to Ray for training. I’ve written a small test environment, and running it gives this error:

```
ERROR: Python: RaySystemError: System error: Can't get attribute 'CallbackValue' on <module 'juliacall' from '/home/jesse/git/aerophase/.CondaPkg/env/lib/python3.10/site-packages/juliacall/__init__.py'>
traceback: Traceback (most recent call last):
  File "/home/jesse/git/aerophase/.CondaPkg/env/lib/python3.10/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
  File "/home/jesse/git/aerophase/.CondaPkg/env/lib/python3.10/site-packages/ray/_private/serialization.py", line 252, in _deserialize_object
    return self._deserialize_msgpack_data(data, metadata_fields)
  File "/home/jesse/git/aerophase/.CondaPkg/env/lib/python3.10/site-packages/ray/_private/serialization.py", line 207, in _deserialize_msgpack_data
    python_objects = self._deserialize_pickle5_data(pickle5_data)
  File "/home/jesse/git/aerophase/.CondaPkg/env/lib/python3.10/site-packages/ray/_private/serialization.py", line 197, in _deserialize_pickle5_data
    obj = pickle.loads(in_band)
AttributeError: Can't get attribute 'CallbackValue' on <module 'juliacall' from '/home/jesse/git/aerophase/.CondaPkg/env/lib/python3.10/site-packages/juliacall/__init__.py'>
```

In my module file I’m importing what I need with:

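```julia
# Python imports
using PythonCall

# NULL Python objects; they are filled in when the module is loaded,
# so nothing Python-side gets baked into the precompiled module
const gym = PythonCall.pynew()
const ray = PythonCall.pynew()
const ppo = PythonCall.pynew()

function __init__()
    # Runs each time the module is loaded, not during precompilation
    PythonCall.pycopy!(gym, pyimport("gymnasium"))
    PythonCall.pycopy!(ray, pyimport("ray"))
    PythonCall.pycopy!(ppo, pyimport("ray.rllib.algorithms.ppo"))
end
```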

My test environment is in its own file and looks like this:

```julia
TestEnv = pytype("TestEnv", (gym.Env,), [
    pyfunc(
        name="__init__",
        function (self, config)
            self.end_pos = config["end_pos"]
            self.cur_pos = 0
            self.action_space = gym.spaces.Discrete(2)
            self.observation_space = gym.spaces.Discrete(2)
            # Set the seed. This is only used for the final (reach goal) reward.
            self.reset()
            return
        end,
    ),
    pyfunc(
        name="reset",
        function (self)
            self.cur_pos = 0
            return [self.cur_pos], Dict()
        end
    ),
    pyfunc(
        name="step",
        function (self, action)
            if pyeq(Bool, action, 0) && pygt(Bool, self.cur_pos, 0)
                self.cur_pos -= 1
            elseif pyeq(Bool, action, 1)
                self.cur_pos += 1
            end
            done = truncated = pyge(Bool, self.cur_pos, self.end_pos)
            return (
                [self.cur_pos],
                done ? rand() * 2.0 : -0.1,
                done,
                truncated,
                Dict(),
            )
        end,
    ),
]
)
```
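In plain Python, the same corridor logic would look roughly like this. It’s only a sketch: `CorridorSketch` is a hypothetical name, it deliberately does not subclass `gymnasium.Env` (so it runs without gymnasium installed), and it just makes the return shapes explicit: `reset` gives `(observation, info)` and `step` gives `(observation, reward, terminated, truncated, info)`, the same five-tuple my Julia `step` returns.

```python
import random


class CorridorSketch:
    """Hypothetical plain-Python transcription of the TestEnv logic.

    Not a gymnasium.Env subclass; it only mirrors the return shapes:
    reset() -> (observation, info)
    step()  -> (observation, reward, terminated, truncated, info)
    """

    def __init__(self, config):
        self.end_pos = config["end_pos"]
        self.cur_pos = 0
        self.reset()

    def reset(self):
        # Start at the left end of the corridor with an empty info dict
        self.cur_pos = 0
        return [self.cur_pos], {}

    def step(self, action):
        # Action 0 moves left (never below 0); action 1 moves right
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        # As in the Julia version, reaching the goal sets both flags
        done = truncated = self.cur_pos >= self.end_pos
        reward = random.random() * 2.0 if done else -0.1
        return [self.cur_pos], reward, done, truncated, {}


env = CorridorSketch({"end_pos": 2})
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(1)
```

The random final reward is kept only to match the original; it plays no role in the serialization problem.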

Finally, the script that I’m running looks like this:

```julia
AeroPhase.ray.init(ignore_reinit_error=true)
test_env_creator = pyfunc(
    name="test_env_creator",
    function (config)
        return AeroPhase.TestEnv(config)
    end,
)
AeroPhase.ray.tune.registry.register_env("TestEnv", test_env_creator)
algo = AeroPhase.ppo.PPO(
    env=AeroPhase.TestEnv,
    config=pydict(
        Dict(
            "framework" => "torch",
            "env_config" => pydict(Dict("end_pos" => 10)))
    )
)
```

Here `AeroPhase` is the name of my module.

Digging through the PythonCall/JuliaCall repo, it looks like `CallbackValue` should be created when juliacall initializes, but somehow it isn’t available when Ray deserializes the object. There’s probably a step I’m missing, since I only started using this package two days ago (and Julia two months ago), but I can’t piece it together from the docs and source.
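As far as I can tell, the last frame of the traceback is a plain `pickle.loads`, and pickle stores a class by reference (module name plus attribute name), so unpickling fails if that attribute doesn’t exist on the module in the process doing the loading. This pure-Python sketch reproduces the same kind of `AttributeError`. `fakejuliacall` is a hypothetical stand-in module, and deleting the attribute only simulates a process where it was never created; it doesn’t show how juliacall actually defines `CallbackValue`.

```python
import pickle
import sys
import types

# Hypothetical stand-in for a module whose attribute is created at runtime
fake = types.ModuleType("fakejuliacall")
sys.modules["fakejuliacall"] = fake


class CallbackValue:
    """Placeholder class that pickle will record as fakejuliacall.CallbackValue."""

    def __init__(self, payload):
        self.payload = payload


# pickle looks classes up by module + qualified name, so point both at the fake module
CallbackValue.__module__ = "fakejuliacall"
CallbackValue.__qualname__ = "CallbackValue"
fake.CallbackValue = CallbackValue

# Pickling works: only the reference "fakejuliacall.CallbackValue" and the
# instance state are stored, not the class itself
data = pickle.dumps(CallbackValue(42))

# Simulate a process where the module is importable but the attribute was never created
del fake.CallbackValue


def try_load(blob):
    """Return the unpickled object, or the AttributeError pickle raised."""
    try:
        return pickle.loads(blob)
    except AttributeError as err:
        return err


result = try_load(data)
print(type(result).__name__, result)
```

Putting the attribute back on the module (`fake.CallbackValue = CallbackValue`) makes the same bytes load again, which is consistent with the error being about lookup in the loading process rather than about the pickled data.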

Thanks in advance for the help!