Updating only a subset of parameters in a neural network using Flux

I am currently playing around with actor-critic methods. In my current problem, the actor has four outputs:

pol = Chain(Dense(6, 32, relu), Dense(32, 4))

However, depending on the state, only a subset of the actions is legal. So for choosing an action I use:

x = get_featurevector(s, pos_a)
probs = softmax(pol(x)[pos_a])
action = sample(1:length(probs), Weights(probs))

where s is the current state and pos_a is a subset of [1,2,3,4] denoting the possible actions.
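
For reference, here is a self-contained sketch of this sampling step, with a random vector standing in for my feature vector and an example pos_a (sample and Weights come from StatsBase):

using Flux                      # Chain, Dense, relu, softmax
using StatsBase: sample, Weights

pol = Chain(Dense(6, 32, relu), Dense(32, 4))

pos_a = [1, 3, 4]               # example: only actions 1, 3 and 4 are legal
x = rand(Float32, 6)            # stand-in for get_featurevector(s, pos_a)

probs = softmax(pol(x)[pos_a])  # renormalised distribution over the legal actions only
action = sample(1:length(probs), Weights(probs))  # index into pos_a, i.e. the chosen action is pos_a[action]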
Now, when I update the actor (pol), my idea is to ignore the outputs, and the corresponding weights and biases, for the actions the actor wasn't allowed to pick. I would like to do something like this:

ps = params(pol[1], pol[2].weight[pos_a, :], pol[2].bias[pos_a])
gs = gradient(ps) do
	Flux.logitcrossentropy(pol(x), target)
end
update!(opt(lr), ps, gs)

Does anyone know how I can make this code work properly?
I hope it is clear what I want to achieve - I tried to keep the explanation as short as possible. If you need a more detailed explanation, let me know!

Is just multiplying the output of the actor by a boolean mask not an option?

pol(x) .* (1:4 .∈ Ref(pos_a))
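
For example, with pos_a = [1, 3, 4] the mask is just a boolean vector that zeroes out the illegal output:

pos_a = [1, 3, 4]
mask = 1:4 .∈ Ref(pos_a)   # Bool[1, 0, 1, 1]
pol(x) .* mask             # the output for action 2 is forced to zero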

Thank you once again, trahflow!
Since I am using logitcrossentropy, I used

pol(x) + mask

with

mask = zeros(4)
for i in 1:4
	if !(i in pos_a)
		mask[i] = -Inf   # illegal actions get probability zero after the softmax
	end
end
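
For completeness, this is roughly how the masked loss slots into my update step (implicit-params Flux API; pol, x, pos_a and target are as above, the optimiser is a placeholder, and this sketch uses a large finite negative value instead of a literal -Inf so that 0 * -Inf cannot turn the cross-entropy into NaN when the target is zero at the masked entries):

using Flux

opt = ADAM(3e-4)   # placeholder optimiser / learning rate

# Large negative logits get (numerically) zero probability after the softmax,
# without the NaN that 0 * -Inf could produce inside logitcrossentropy.
mask = Float32[i in pos_a ? 0 : -1f9 for i in 1:4]

ps = Flux.params(pol)   # update all parameters; the mask takes care of the illegal outputs
gs = gradient(ps) do
	Flux.logitcrossentropy(pol(x) .+ mask, target)
end
Flux.Optimise.update!(opt, ps, gs)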