Hi,
I’m writing a deep RL (PPO) implementation using Lux.jl and want to add support for recurrent networks. My issue is that I have a batch of environments, and at every time step some of them may reset because the agent failed the task. When this happens, I’d like to reset the hidden state of the network only for the corresponding environments.
So I essentially have a `reset_mask::Vector{Bool}` of length `batch_size`, and I would like to reset the state of every recurrent layer in my chain to `initialstate`, but only for the batch entries where `reset_mask` is true. Is there some utility for that?
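For the per-step case, one way to express the reset directly on the arrays is to choose, column by column, between the current state and an initial state. Below is a minimal sketch in plain Julia. It assumes the batch is the last dimension of the hidden state (the Lux convention) and that you already have an initial-state array `h0` of the same shape (for an all-zero initial state, `zero(h)`). `masked_reset` is a hypothetical helper, not a Lux.jl function.

```julia
# Hypothetical helper (not part of Lux.jl).
# `h` and `h0` are features × batch; columns where `reset_mask` is true are
# taken from `h0`, all other columns are kept from `h`.
function masked_reset(h::AbstractMatrix, h0::AbstractMatrix,
                      reset_mask::AbstractVector{Bool})
    m = reshape(reset_mask, 1, :)  # 1 × batch, broadcasts over the feature rows
    return ifelse.(m, h0, h)       # elementwise select, no mutation of `h`
end

# For cells whose carried state is a tuple (e.g. an LSTM-style `(h, c)`),
# apply the same column-wise reset to every component.
masked_reset(state::Tuple, state0::Tuple, reset_mask::AbstractVector{Bool}) =
    map((s, s0) -> masked_reset(s, s0, reset_mask), state, state0)

# Example: reset only the second environment of a batch of three.
h = [1.0 2.0 3.0; 4.0 5.0 6.0]
h_new = masked_reset(h, zero(h), [false, true, false])
```

You would apply it to the carried state right before each call to the recurrent cell, so that flagged environments start from `h0` while the others keep their history.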
It would also be nice to be able to do the same thing for `Recurrence`, passing a `time × batch_size` matrix as the reset mask.
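For a whole sequence with a `time × batch_size` mask, one sketch is to unroll the recurrence by hand and apply the same column-wise reset before every step. This does not use Lux's `Recurrence` layer. `unroll_with_resets` is hypothetical, `step(x_t, h) -> (y_t, h_new)` stands in for whatever one-step update you use, and `reset_mask[t, b] == true` is assumed to mean that environment `b` was reset just before step `t`.

```julia
# Hypothetical manual unroll (not Lux's `Recurrence`).
# `xs` is a vector of per-step inputs (features × batch), `h0` the initial
# state (features × batch), and `reset_mask` a time × batch Bool matrix.
# Assumption: `reset_mask[t, b]` means environment `b` restarts before step `t`.
function unroll_with_resets(step, xs::AbstractVector{<:AbstractMatrix},
                            h0::AbstractMatrix, reset_mask::AbstractMatrix{Bool})
    @assert size(reset_mask) == (length(xs), size(h0, 2)) "mask must be time × batch"
    h = h0
    # Assumes each output has the same array type as the state.
    ys = Vector{typeof(h0)}(undef, length(xs))
    for t in eachindex(xs)
        m = reshape(reset_mask[t, :], 1, :)  # 1 × batch mask for step t
        h = ifelse.(m, h0, h)                # restart flagged environments from h0
        y, h = step(xs[t], h)                # one recurrent update
        ys[t] = y
    end
    return ys, h
end

# Example with a toy "accumulator" step and a reset of environment 2 at t = 2.
step = (x, h) -> (h .+ x, h .+ x)
xs = [ones(1, 2) for _ in 1:3]
mask = [false false; false true; false false]
ys, hT = unroll_with_resets(step, xs, zeros(1, 2), mask)
```

Unrolling explicitly trades the convenience of a single layer call for full control over when each environment's state is restarted.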
Perhaps there is some trivial way to do this to the `(x, hidden)` arrays directly, before passing them in, that I’m missing?
Thanks.