[ANN] FluxTraining.jl

Hey Peter!

It’s definitely possible to use this with reinforcement learning. You can implement custom training logic by creating a new `Phase` and then implementing `fitepochphase!` and/or `fitbatchphase!` for it. The default implementation should be a good starting point: as you can see, it simply loops over the data iterator, but you can override that. If you only want to change the epoch (i.e. data iteration) logic, make your phase a subtype of `AbstractTrainingPhase`; that way it will use the regular `fitbatchphase!` definition. To make it work with the callbacks, you should also throw the necessary events, as is done in the default implementation.
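As a rough sketch of what that could look like (`ReinforcementPhase` and `collectexperience` are hypothetical names, and the exact signatures of `fitepochphase!`/`fitbatchphase!` and the supertype path may differ between versions, so mirror the default implementation in the source):

```julia
using FluxTraining

# Hypothetical custom phase; subtyping AbstractTrainingPhase means the
# regular fitbatchphase! definition still handles each batch.
struct ReinforcementPhase <: FluxTraining.AbstractTrainingPhase end

# Override only the epoch logic, e.g. to roll out an environment
# instead of iterating a fixed dataset. `collectexperience` is a
# placeholder for your own experience-gathering logic.
function FluxTraining.fitepochphase!(learner, phase::ReinforcementPhase)
    for batch in collectexperience(learner)
        FluxTraining.fitbatchphase!(learner, phase, batch)
    end
    # Remember to also throw the events the default implementation
    # throws so the callbacks keep working.
end
```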

Then you would simply call `fit!(learner, ReinforcementPhase())` (or whatever you called your phase).

Making a tutorial on this for the documentation is on my to-do list.


Adding Weights & Biases support should be even easier, as there is an interface specifically for creating new logger “backends” that can be used with the logging callbacks. The implementation of `TensorBoardBackend` should give you enough info.
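For reference, this is roughly how an existing backend plugs into the logging callbacks (names as in the docs; check them against your version):

```julia
using FluxTraining

# Create a backend and wrap it in a logging callback,
# then pass the callback to your Learner.
backend = TensorBoardBackend("tblogs")
metricscb = LogMetrics(backend)
```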

It boils down to implementing `log_to` methods for the various types that can be logged. Since there is no native Julia client, the easiest way to connect this to W&B would be to use PyCall.jl to wrap the Python client.
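A minimal sketch of how that might look, assuming the `log_to(backend, loggable, name, i; group)` signature and the `Loggables.Value` field match the `TensorBoardBackend` implementation (`WandbBackend` is a hypothetical name here, so double-check the details against the source):

```julia
using PyCall
import FluxTraining: log_to, LoggerBackend, Loggables

# Hypothetical backend that forwards logged values to the Python
# wandb client via PyCall.jl.
struct WandbBackend <: LoggerBackend
    wandb::PyObject
end

function WandbBackend(; project)
    wandb = pyimport("wandb")   # requires `pip install wandb`
    wandb.init(project = project)
    return WandbBackend(wandb)
end

# One `log_to` method per loggable type; scalar values shown here.
function log_to(backend::WandbBackend, value::Loggables.Value, name, i; group = ())
    key = join((group..., name), "/")    # e.g. "Loss/training"
    backend.wandb.log(Dict(key => value.data), step = i)
end
```

Methods for the other loggable types (images, text, ...) would follow the same pattern.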

Let me know if you try your hand at either of these, and feel free to ask for more information :slight_smile:
