# Flux Custom Loss Function Not Working Properly

Hello, I am trying to implement the A3C reinforcement learning algorithm in Flux.jl using the custom loss function described in the algorithm. I am having trouble with this custom loss function: the parameters of my model do not update when I compute gradients with it. A minimal working example (MWE) is:

```julia
using Flux

# set up common model params
state_dim = 5
action_dim = 4
model = Chain(Dense(state_dim, 128, relu),
              Dense(128, action_dim+1))
θ = params(model)
opt = ADAM()

# this loss function does not work
π_sa = 0.3
A_sa = 10
actor_loss_function(π_sa, A_sa) = log(π_sa)*A_sa
dθ = gradient(()->actor_loss_function(π_sa, A_sa), θ)
display(θ)
Flux.update!(opt, θ, dθ)
display(θ)  # unchanged

# this loss function works
s_t = rand(5)
a_t = rand(5)
loss(x, y) = Flux.Losses.mse(model(x), y)
dθ_mse = gradient(()->loss(s_t, a_t), θ)
display(θ)
Flux.update!(opt, θ, dθ_mse)
display(θ)  # changed
```


Does anyone have an idea why my custom loss function is not working? Both loss functions return scalars in this example. actor_loss_function() generally returns negative values (π_sa = 0.3 is a probability, so it is at most 1 and its log() is non-positive), whereas Flux.Losses.mse returns non-negative values, if that makes a difference. Any thoughts or feedback are greatly appreciated.

Umm, I'm not a Flux expert, but your custom loss does not call the model at all, in contrast to the mse loss. Also, if you want to maximize the probability, you should minimize (that's what all optimizers do) the negative log (assuming A_sa is always positive).

@maxfreu Thanks for the response. My (potentially flawed) understanding is that the parameters (θ) are linked to the model internally via the params() function, so when you call update!() on the parameters, the model itself is updated implicitly. This is based on my interpretation of the "Model parameters" section of the Flux documentation.

Basically I don’t think I need to call the model in my loss function but I am not a Flux expert either. Also thanks for the comment on the optimization, I will keep that in mind.

I think the problem you have is that the function you take the derivative of simply doesn't depend on θ, so the gradient is zero and there is no update.
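A minimal sketch of this point, calling Zygote directly in the same implicit Params style as the thread, with a hypothetical weight matrix W standing in for the model: a loss built only from constants leaves W's gradient entry as nothing, while a loss that actually uses W produces a real gradient.

```julia
using Zygote

# Hypothetical stand-in for the model: a single weight matrix tracked as a parameter
W = [1.0 2.0; 3.0 4.0]
ps = Params([W])
x = [1.0, 1.0]

# Loss built only from constants: W never enters the computation
π_sa, A_sa = 0.3, 10.0
g_const = gradient(() -> log(π_sa) * A_sa, ps)

# Loss that runs the "model" (here just W * x): W now receives a gradient
g_model = gradient(() -> sum(W * x), ps)

g_const[W]   # nothing
g_model[W]   # ones(2) * x', i.e. every entry equals 1.0
```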

That makes sense. What is the best way to “enforce” that dependence?

For reference, my model/network takes in some state and outputs a 5-node vector. The first 4 nodes are probabilities (π_sa) and the 5th node is just a value (v_s). I updated my network architecture to reflect this:

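```julia
# custom output layer: softmax over the first out-1 entries (π_sa), raw last entry (v_s)
struct A3C_Output
    W
end
A3C_Output(in::Int64, out::Int64) =
    A3C_Output(randn(out, in))

function (m::A3C_Output)(x)
    y = m.W * x
    n = size(m.W, 1)
    return (softmax(y[1:n-1]),  # π_sa
            y[n])               # v_s
end

Flux.@functor A3C_Output
model = Chain(
    Dense(state_dim, 128, relu),
    A3C_Output(128, action_dim+1)
)
```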


I am trying to update my model based on these two gradients (from the A3C paper). Letting $\theta$ be the network parameters, we update the network parameters with the gradients $d\theta$ and $d\theta_v$ such that:

$$
\begin{aligned}
d\theta &= \nabla_\theta \log \pi(a_i \mid s_i,\theta)\, A_{sa} \\
d\theta_v &= \nabla_\theta A_{sa}^2
\end{aligned}
$$

Here $\pi(a_i \mid s_i,\theta)$ is the probability of the selected action (π_sa is a vector, so $\pi(a_i \mid s_i,\theta)$ is one element of it) and $A_{sa}$ is a scalar computed from v_s.

What I am getting at is that the loss function is defined by the network output, but I am unsure how to express that programmatically. Any thoughts? Thanks.
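One way to express this is to make each loss take the model and run the forward pass itself. The sketch below uses Flux's explicit-gradient API (Flux 0.14 or later, rather than the params style used in this thread) and a plain Dense output layer instead of the custom A3C_Output; the helper policy_and_value is hypothetical. The actor term is negated because optimizers minimize.

```julia
using Flux

state_dim, action_dim = 5, 4
model = Chain(Dense(state_dim => 128, relu), Dense(128 => action_dim + 1))

# Hypothetical helper: first action_dim outputs -> π(:|s) via softmax, last output -> V(s)
function policy_and_value(m, s)
    y = m(s)
    return softmax(y[1:action_dim]), y[end]
end

# Actor term: the forward pass happens inside the loss, so θ is part of the computation
function actor_loss(m, s, a, A)
    probs, _ = policy_and_value(m, s)
    return -log(probs[a]) * A    # negated: gradient descent minimizes
end

# Critic term: squared error between the observed return R and V(s)
function critic_loss(m, s, R)
    _, v = policy_and_value(m, s)
    return (R - v)^2
end

s_t, a_t, A_sa, R = rand(Float32, state_dim), 2, 10f0, 10f0
g_actor = Flux.gradient(m -> actor_loss(m, s_t, a_t, A_sa), model)[1]
g_critic = Flux.gradient(m -> critic_loss(m, s_t, R), model)[1]
```

Each result is a nested NamedTuple mirroring the model (for example g_actor.layers[1].weight) and can be passed to Flux.update! together with an optimizer state from Flux.setup.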

As a note, if I compute the gradients via

```julia
dθ = gradient(()->actor_loss_function(π_sa, A_sa), θ)
display(dθ.grads)
```

Returns:

```
IdDict{Any,Any} with 3 entries:
Float32[0.155933 -0.0405449 … -0.0203858 -0.0733626; -0.0666687 0.114271 … 0.… => nothing
Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0… => nothing
[-0.99497 -0.806638 … 1.58456 -0.522958; 0.597227 -0.304895 … -0.861661 0.827… => nothing
```


When I compute the gradients dθ_mse, the values in the IdDict are not nothing.

When we say the model is updated implicitly, that means something like the following:

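```julia
for param in θ
    grad = dθ[param]              # nothing if param was not used in the loss
    grad === nothing && continue  # no gradient, so no update
    param .-= learning_rate .* grad
end
```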


If grad is nothing (which, as your last snippet shows, it is for every entry of the IdDict), then param will not change at all! There's no magic here: Flux will not return a gradient for a parameter that is not used in the loss function, because the loss has no derivative with respect to that parameter.

Thanks for the reply. I understand that if a parameter is not used in the loss function, its gradient is zero. Technically, my loss function does use output from the model (and hence the parameterization $\theta$). The process is:

My network outputs a 5×1 vector. I transform the first four elements into probabilities using softmax(); this is $\pi(:\mid s_i,\theta)$. The last element is left unmodified; this is $v(s_i)$.

Later in my code, I sample an element from the probability vector $\pi(:\mid s_i,\theta)$ according to the probabilities themselves. This selected element is $\pi(a_i \mid s_i,\theta)$.

My loss function takes $\left(\pi(a_i \mid s_i,\theta), v(s_i)\right)$ (both scalars) as arguments.

What I am trying to say is that the loss function is using the outputs from the model after some modification. I think the crux of my problem is that I don’t know how to “make Flux” understand this. Does that make sense? Any insight is much appreciated.
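Because the loss has to rerun the model to get a gradient, what is worth storing at collection time is the sampled action index rather than the probability value itself. A plain-Julia sketch of drawing an index from a probability vector by inverse-CDF sampling (sample_action is a hypothetical helper, not part of Flux):

```julia
using Random

# Draw an index i with probability probs[i] (inverse-CDF sampling)
function sample_action(rng::AbstractRNG, probs::AbstractVector{<:Real})
    u = rand(rng) * sum(probs)       # rescale in case probs do not sum to exactly 1
    c = zero(float(eltype(probs)))
    for (i, p) in enumerate(probs)
        c += p
        u < c && return i
    end
    return lastindex(probs)          # guard against floating-point round-off
end

rng = MersenneTwister(0)
probs = [0.1, 0.2, 0.3, 0.4]         # e.g. the softmax output for one state
a_i = sample_action(rng, probs)      # store this index; recompute probs[a_i] inside the loss
```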

Why not run the model, generate the outputs and modify them in the loss function? Currently all actor_loss_function does is multiply two scalars together, so I assume there’s more you want to add in there.

If I am understanding you correctly, I think that I already am doing that. My original MWE is the relevant portion of a much larger reinforcement learning algorithm implementation (A3C, see algorithm below if interested).

Basically I have some agents that interact with an environment, get a state, and select an action by feeding that state through the network. They do this a bunch and I process/store the network output which I then use for my loss function.

The gradients I am updating are:

Note that R is a scalar quantity as well, so my loss function is as simple as actor_loss_function in the MWE above. R, $V(s_i,\theta)$ and $\pi(a_i \mid s_i,\theta)$ are all derived from the agents' interactions with the environment, which in turn depend on the model's inputs and outputs, so the gradients with respect to $\theta$ should be non-zero.

Full A3C algorithm for reference:

I'm roughly aware of how A3C and other semi-gradient methods function. I think you may be mistaken about how model fits into $\pi(a_i \mid s_i,\theta)$. My understanding is that evaluating the policy $\pi$ for a particular state-action pair involves a forward pass through model. Otherwise, why parameterize $\pi$ in terms of $\theta$ at all? If you accept this interpretation, then it's easy to see that:

1. Anywhere $\pi$ is used, you'll be invoking the forward pass of model.
2. Thus, the gradient calculation $d\theta \leftarrow \ldots$ will call model as part of the loss calculation.
3. The above points also apply to $V$ and $d\theta_v \leftarrow \ldots$, because $V$ is also a network.
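For reference, the inner loop of the A3C paper's pseudocode (primes mark the thread-local parameter copies) runs for $i = t-1, \dots, t_{\text{start}}$, with $R$ initialized to $0$ at a terminal state and to $V(s_t; \theta'_v)$ otherwise. Every accumulated term contains $\pi$ or $V$ evaluated at the current parameters, i.e. a forward pass inside the gradient:

$$
\begin{aligned}
R &\leftarrow r_i + \gamma R \\
d\theta &\leftarrow d\theta + \nabla_{\theta'} \log \pi(a_i \mid s_i; \theta')\,\bigl(R - V(s_i; \theta'_v)\bigr) \\
d\theta_v &\leftarrow d\theta_v + \frac{\partial \bigl(R - V(s_i; \theta'_v)\bigr)^2}{\partial \theta'_v}
\end{aligned}
$$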

It may be worth first implementing a simpler semi-gradient method, with no asynchronous parts and a single set of parameters $\theta$, to confirm that the above holds. You might also get some insights from the examples in the JuliaReinforcementLearning/ReinforcementLearningZoo.jl repository on GitHub.

I agree with everything in the first paragraph; well stated. Each step of the simulation does indeed involve a forward pass through model: each action is selected from the probabilities the model outputs for the current state. The workflow is basically:

1. I collect a history of these state-action pairs via forward passes through model (without updating the parameterization $\theta$).
2. After step 1, I train $\theta$ on the collected values.

What I think is problematic in what you are proposing is the forward pass through model during the training step: you're potentially generating different state-action pairs without observing the associated reward. The loss function is called for each "piece" of data that I train on and theta is updated accordingly. This means that the action-selection probabilities the model outputs for each state will probably differ between training step 2 and the state/action pairs collected in step 1. Because the action probabilities are different, I am not entirely sure the rewards (R) are still valid.

Does that make sense? Thanks again for the insightful response.

Also noting the simpler-method recommendation: this is good advice, and I'll note that my implementation isn't really asynchronous. It involves multiple agents in the same environment. They all interact with the environment and collect data, then all the data is trained on at once; I only maintain one global model/parameterization $\theta$. A3C is the motivating algorithm, but I am implementing it a little loosely. I will also check out the RL Zoo repo you sent.

I don't believe this is the case, though. If we assume that the model's forward pass does not mutate any parameters on its own (which it should not), then the output $\pi(a_i \mid s_i,\theta)$ should be exactly the same in steps 1 and 2 for a given $i$. Note that the step labelled "Accumulate gradients wrt $\theta^\prime$: …" isn't actually updating $\theta^\prime$ (or $\theta^\prime_v$) in the loop, but rather the gradients $d\theta$. In other words, "theta is [not] updated accordingly" while the policy $\pi$ is still being queried to generate/accumulate the gradients in the loop; it is updated only once that loop terminates (per the line "Perform asynchronous update …").
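A small check of this claim, sketched with Flux's explicit setup/update! API (Flux 0.14 or later, not the params style in this thread) and assuming the model has no stochastic layers such as Dropout: computing a gradient only reads the weights, so the policy output seen during training matches the one seen during collection until update! runs.

```julia
using Flux

model = Chain(Dense(5 => 16, relu), Dense(16 => 5))
s = rand(Float32, 5)

y_collect = model(s)                          # step 1: forward pass while collecting data
g = Flux.gradient(m -> sum(m(s)), model)[1]   # computing a gradient does not touch the weights
y_train = model(s)                            # step 2: same weights, so the same output
y_train == y_collect                          # true

opt_state = Flux.setup(Adam(1f-2), model)
Flux.update!(opt_state, model, g)             # only the update step writes new weights
y_after = model(s)                            # now differs from y_collect
```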

What you could do, as an optimization to avoid recalculating $\pi(a_i \mid s_i,\theta)$ in step 2, is to save the gradients for each timestep in step 1 (the line "Perform $a_t$ according to policy …"). That is merely an optimization, however, and should not affect the final output.

Alright, I think I understand what you are proposing. Putting it into code, I would implement it this way. Note that my custom struct above is a little messed up, so I'll redefine it here:

```julia
using Flux

# define custom A3C output model layer
struct a3c
    W
end
a3c(in::Integer, out::Integer) = a3c(randn(out, in))
# forward pass: softmax over all but the last row (π_sa), raw last row (v_s)
function (m::a3c)(x)
    y = m.W * x
    return softmax(y[1:end-1]), y[end]
end
Flux.@functor a3c

# redefine model with new struct
state_dim = 5
action_dim = 4
model = Chain(Dense(state_dim, 128, relu),
              a3c(128, action_dim+1))
θ = params(model)
opt = ADAM()

# [have agents interact with environment and collect data]
# ...
# dummy data for one interaction:
s_t = rand(5)
R = 10
action_index = 2  # action 2 was selected from outputted probabilities during agent/environment interaction

# duplicate data as an example
d = [(s_t, R, action_index),
     (s_t, R, action_index)]

# define updated loss function that calls model
function actor_loss_function(data)
    s_t, R, action_index = data
    probs, v_s = model(s_t)   # forward pass inside the loss
    π_sa = probs[action_index]
    return log(π_sa)*(R-v_s)
end

# accumulate gradients, do initial gradient outside of loop
dθ = gradient(()->actor_loss_function(d[1]), θ)
for data in d[2:end]
    dθ += gradient(()->actor_loss_function(data), θ)
    # don't call update!() here because we want to preserve model state during computation of gradients
end

# finally call update once at the end of the run
display(θ)
Flux.update!(opt, θ, dθ)
display(θ)  # changed
```


Is this what you're describing? If so, I am still a little confused about the "gradient accumulation". Flux/Zygote doesn't seem to have a way to add gradients that I know of (I will look into this more), e.g.:

```
julia> dθ+dθ
ERROR: MethodError: no method matching +(::Zygote.Grads, ::Zygote.Grads)
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:538
  +(::ChainRulesCore.One, ::Any) at /Users/riomcmahon/.julia/packages/ChainRulesCore/7d1hl/src/differential_arithmetic.jl:94
  +(::ChainRulesCore.DoesNotExist, ::Any) at /Users/riomcmahon/.julia/packages/ChainRulesCore/7d1hl/src/differential_arithmetic.jl:23
  ...
Stacktrace:
 top-level scope at REPL:1
```


Any insight into a way to accumulate gradients for all the data so I only call update!() once? Thanks again for the help.

Adding Grads is supported as of Zygote 0.6.4 (this should be compatible with the latest Flux; it just needs an ]up). See utils.md in FluxML/Zygote.jl at v0.6.7 on GitHub.
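A minimal sketch of Grads arithmetic in recent Zygote (0.6.4 or later), using two hypothetical parameter vectors w and b: gradients taken with respect to the same Params can be summed with .+, and accumulated in place with .+=.

```julia
using Zygote

w, b = [1.0, 2.0], [0.5, -0.5]
ps = Zygote.Params([w, b])
x1, x2 = [1.0, 0.0], [0.0, 1.0]

gs1 = gradient(() -> sum(w .* x1 .+ b), ps)
gs2 = gradient(() -> sum(w .* x2 .+ b), ps)

gs = gs1 .+ gs2      # out-of-place sum, keyed by the same Params
gs[w]                # [1.0, 1.0]
gs[b]                # [2.0, 2.0]

acc = gs1 .+ gs2
acc .+= gs1          # in-place accumulation, as in `dθ .+= gradient(...)`
```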

Awesome, thanks for pointing me in that direction. I think it is working but I need to train my model to confirm. Currently my training function looks like:

```julia
using Flux

function A3C_policy_train(model)
    # create loss functions; both run the forward pass inside the loss
    function actor_loss_function(R, s_t, a_t)
        π_s, v_s = model.RL.params.model(s_t)   # (action probabilities, value)
        π_sa = π_s[a_t]
        return log(π_sa)*(R-v_s)
    end

    function critic_loss_function(R, s_t)
        _, v_s = model.RL.params.model(s_t)
        return (R-v_s)^2
    end

    opt = ADAM()
    global_reward = 0
    display(model.RL.params.θ)
    for i in 1:model.num_agents
        # compute initial gradients at end of series (tmax-1)
        tmax = model.ModelStep
        _, R = model.RL.params.model(model.RL.params.s_t[i, :, tmax-1])
        s_t = model.RL.params.s_t[i, :, tmax-1]
        a_t = model.RL.params.a_t[i, tmax-1]
        dθ = gradient(()->actor_loss_function(R, s_t, a_t), model.RL.params.θ)
        dθ_v = gradient(()->critic_loss_function(R, s_t), model.RL.params.θ)

        # accumulate gradients for rest of series, starting at tmax-2
        for t in reverse(1:tmax-2)
            R = model.RL.params.r_sa[i, t] + model.RL.γ*R
            s_t = model.RL.params.s_t[i, :, t]
            a_t = model.RL.params.a_t[i, t]

            dθ .+= gradient(()->actor_loss_function(R, s_t, a_t), model.RL.params.θ)
            dθ_v .+= gradient(()->critic_loss_function(R, s_t), model.RL.params.θ)
        end
        Flux.update!(opt, model.RL.params.θ, dθ)
        Flux.update!(opt, model.RL.params.θ, dθ_v)
        global_reward += sum(model.RL.params.r_sa[i, :])
    end
    display(model.RL.params.θ)  # changed
    return global_reward
end
```
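The return recursion in the inner loop above, R ← r_t + γR, can also be computed once up front. A plain-Julia sketch, where discounted_returns is a hypothetical helper and bootstrap plays the role of the initial R taken from the value head:

```julia
# n-step discounted returns, computed backwards: R_t = r_t + γ * R_{t+1}
function discounted_returns(rewards::AbstractVector{<:Real}, γ::Real, bootstrap::Real)
    R = float(bootstrap)
    returns = similar(rewards, typeof(R))
    for t in reverse(eachindex(rewards))
        R = rewards[t] + γ * R
        returns[t] = R
    end
    return returns
end

discounted_returns([1.0, 1.0, 1.0], 0.5, 0.0)   # [1.75, 1.5, 1.0]
```

With the returns precomputed, the gradient loop can run over every timestep with the same body, which is one way to avoid the duplicated first step.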


Note that I have my model and model parameterization θ wrapped up in some structs, as are my state/action/reward histories. Two more questions:

1. Do you know of a way to initialize an empty gradient, so I can avoid code duplication and put everything in the for t in reverse(1:tmax-2) ... loop? I looked a little and didn't see anything in the Zygote documentation.
2. Does the above code snippet capture what we've discussed? I only call update!() after accumulating gradients, but I seem to have taken a performance hit. Any performance tips would be greatly appreciated.

Thank you very much for your feedback/help. I’ve been losing my mind over this problem for several days and I really appreciate you taking the time to help me work through it.

Also @maxfreu, I am not sure if you've been following this thread, but your initial answer (that the loss function never calls model) turned out to identify exactly what was wrong. Thanks for the feedback, and I apologize for dismissing your initial answer; I guess I misinterpreted the Flux documentation.

Good question. I think you should be able to construct an empty Grads object like so:

```julia
dθ = Zygote.Grads(IdDict(ps => nothing for ps in model.RL.params.θ), model.RL.params.θ)
```
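A variant of the same idea, sketched with a hypothetical weight matrix and inputs: starting every entry at zero(p) rather than nothing means each timestep, including the first, can go through the same .+= loop body.

```julia
using Zygote

W = [0.5 -0.2; 0.1 0.3]
ps = Zygote.Params([W])
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # hypothetical per-timestep inputs

# All-zero Grads keyed by the same Params (fields: IdDict of gradients, Params)
dθ = Zygote.Grads(IdDict{Any,Any}(p => zero(p) for p in ps), ps)

for x in xs
    dθ .+= gradient(() -> sum(W * x), ps)   # same body for every timestep
end

dθ[W]   # [2.0 2.0; 2.0 2.0], the summed x' rows
```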


No problem, ToucheSir deserves it; he put more love into his answers.

@maxfreu I'm a bit unclear on what you brought up in the statement above. As I understand it, a typical gradient descent method for something like an image classifier compares truth values to the model output; the loss goes down as the model's predictions better match the truth, so $\text{loss} \in [0,\infty)$.

My use case is a reinforcement learning model where an agent interacts with an environment and receives rewards/penalties for certain actions. A well-trained network will choose agent actions that maximize reward. A simulation of one of my agents can have large positive rewards if it performs well, or large negative rewards from incurring lots of penalties due to poor performance, so $\text{loss} \in (-\infty,\infty)$.

I'm unclear on how to convert my reward function into a loss function. From an RL perspective I want to maximize reward, but from a gradient descent perspective I am unclear whether Zygote is trying to find the smallest (potentially negative) loss. Basically, should I multiply my rewards by -1 for Zygote? This is probably a dumb question, but I'm confused and couldn't find anything in the documentation. Thanks.
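On the sign question: Zygote only computes derivatives, and it is the optimizer step (e.g. Flux.update!) that moves parameters against the gradient, i.e. minimizes. A plain-Julia toy sketch (a hypothetical one-parameter logistic policy with a hand-derived gradient, no Flux) showing that minimizing -log(π)·A raises the probability of the taken action when A > 0 and lowers it when A < 0:

```julia
# Hypothetical one-parameter policy: probability of the taken action is logistic(θ)
logistic(θ) = 1 / (1 + exp(-θ))

# Loss = NEGATED objective log(π) * A, because gradient descent minimizes
pg_loss(θ, A) = -log(logistic(θ)) * A

# Hand-derived gradient: d/dθ log(logistic(θ)) = 1 - logistic(θ)
pg_grad(θ, A) = -(1 - logistic(θ)) * A

# Plain gradient descent on the loss
function descend(θ, A; η = 0.1, steps = 50)
    for _ in 1:steps
        θ -= η * pg_grad(θ, A)
    end
    return θ
end

θ0 = 0.0                    # logistic(0.0) == 0.5
θ_good = descend(θ0, 1.0)   # positive advantage: π of the taken action rises
θ_bad = descend(θ0, -1.0)   # negative advantage: π of the taken action falls
```

Only the actor objective is negated here; the rewards themselves keep their sign, and a critic term such as (R - V)^2 is already a quantity to minimize.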