I am currently playing around with actor-critic methods. In my current problem, the actor has four outputs:
pol = Chain(Dense(6, 32, relu), Dense(32, 4))
However, depending on the state only a subset of the actions are legal. So for choosing an action I use:
x = get_featurevector(s, pos_a)
probs = softmax(pol(x)[pos_a])
action = sample(1:length(probs), Weights(probs))
where s is the current state and pos_a is a subset of [1, 2, 3, 4] denoting the possible actions.
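For context, here is the action-selection step as a self-contained, runnable sketch (the feature vector x and the legal-action set pos_a below are just stand-in values; note that sample returns an index into the subset, so I map it back through pos_a):

```julia
using Flux, StatsBase

pol = Chain(Dense(6, 32, relu), Dense(32, 4))

x = randn(Float32, 6)   # stand-in for get_featurevector(s, pos_a)
pos_a = [1, 3, 4]       # e.g. action 2 is illegal in this state

# softmax over the legal logits only, so the probabilities
# are renormalised over the legal actions
probs = softmax(pol(x)[pos_a])

# sample an index into pos_a, then map back to the original action index
action = pos_a[sample(1:length(probs), Weights(probs))]
```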
Now, when I update the actor (pol), my idea is to ignore the outputs, and the corresponding weights and biases, for the actions the actor wasn't allowed to pick. I would like to do something like:
ps = params(pol[1], pol[2].weight[pos_a, :], pol[2].bias[pos_a])
gs = gradient(ps) do
Flux.logitcrossentropy(pol(x), target)
end
update!(opt(lr), ps, gs)
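For reference, here is a runnable sketch of what I believe should be equivalent: if the loss is taken on pol(x)[pos_a] while ps holds all parameters, the rows of the last layer's weight and bias for illegal actions receive zero gradient automatically, so they are left untouched by the update. (The optimizer, learning rate, and one-hot target below are hypothetical placeholders.)

```julia
using Flux

pol = Chain(Dense(6, 32, relu), Dense(32, 4))
opt = Descent(0.01)     # hypothetical optimizer and learning rate

x = randn(Float32, 6)   # stand-in for get_featurevector(s, pos_a)
pos_a = [1, 3, 4]       # e.g. action 2 is illegal in this state
target = Flux.onehot(2, 1:length(pos_a))  # hypothetical target within the legal subset

ps = Flux.params(pol)   # all parameters; the masking happens in the loss
gs = gradient(ps) do
    # slicing the output means illegal-action rows of pol[2]
    # get zero gradient, so only the legal rows are updated
    Flux.logitcrossentropy(pol(x)[pos_a], target)
end
Flux.Optimise.update!(opt, ps, gs)
```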
Does anyone know how I can make this code work properly?
I hope it is clear what I want to achieve - I tried to keep the explanation as short as possible. If you need a more detailed explanation, let me know!