Hi,
I’m solving a Q-learning problem.
I’m trying to compute the gradient of the TD error with respect to the network parameters. That is, I need

$$
\frac{\partial}{\partial \theta_i} \min_{u} Q^{\theta}(x^{+}, u),
$$

where \theta = (\theta_1, \ldots, \theta_N) is the network parameter, and x, u, and x^{+} are the state, control, and successor state, respectively.

The difficulty comes from the \min. Writing u^{*} = \arg\min_{u} Q^{\theta}(x^{+}, u), the chain rule gives

$$
\frac{\partial}{\partial \theta_i} \min_{u} Q^{\theta}(x^{+}, u)
= \frac{\partial Q^{\theta}}{\partial \theta_i}(x^{+}, u^{*})
+ \frac{\partial Q^{\theta}}{\partial u}(x^{+}, u^{*}) \, \frac{\partial u^{*}}{\partial \theta_i}.
$$
I used to use Flux.jl to get gradients automatically. The first term, \frac{\partial Q^{\theta}}{\partial \theta_i}, would be easy to get:

```julia
gs_theta = gradient(params(theta)) do
    Q(x_next, u_star, theta)
end
```

as shown in the Flux.jl manual.
The second term would also be easy:

```julia
gs_u = gradient(params(u_star)) do
    Q(x_next, u_star, theta)
end
```
However, for the last term, \frac{\partial u^{*}}{\partial \theta_i}, one has to differentiate through the inner minimisation itself, which does not seem to be supported by Flux.jl's automatic differentiation. It may be achievable with DiffOpt.jl; call the result gs_u_star_theta_i. The full gradient would then be gs_theta + gs_u * [gs_u_star_theta_1, ..., gs_u_star_theta_N].
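To sanity-check the chain-rule expression, here is a self-contained toy example with a made-up scalar quadratic Q (my own illustration, not the actual network), hand-coded derivatives, and a finite-difference check:

```julia
# Hypothetical toy Q(x, u, θ) = θ[1]*u^2 + θ[2]*u*x + x^2; its inner
# minimiser has the closed form u*(θ) = -θ[2]*x / (2θ[1]).
Q(x, u, θ) = θ[1]*u^2 + θ[2]*u*x + x^2
ustar(x, θ) = -θ[2]*x / (2θ[1])

# The three pieces entering the chain rule, computed by hand
dQ_dθ(x, u, θ)  = [u^2, u*x]                         # ∂Q/∂θᵢ at fixed u
dQ_du(x, u, θ)  = 2θ[1]*u + θ[2]*x                   # ∂Q/∂u
dustar_dθ(x, θ) = [θ[2]*x/(2θ[1]^2), -x/(2θ[1])]     # ∂u*/∂θᵢ

x, θ = 1.5, [2.0, 0.7]
u = ustar(x, θ)
total = dQ_dθ(x, u, θ) .+ dQ_du(x, u, θ) .* dustar_dθ(x, θ)

# Check against central finite differences of m(θ) = min_u Q(x, u, θ)
m(θ) = Q(x, ustar(x, θ), θ)
h = 1e-6
fd = [(m(θ .+ h*e) - m(θ .- h*e)) / (2h) for e in ([1.0, 0.0], [0.0, 1.0])]
@assert maximum(abs.(total .- fd)) < 1e-6
```

Incidentally, at an unconstrained interior minimum dQ_du(x, u*, θ) is zero, so the correction term vanishes (envelope theorem); with a constrained inner problem, as handled by DiffOpt.jl, it need not.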
So I would like to manually tell Flux.jl that “the gradient is the expression above”.
My questions are:

- How can I manually supply a gradient (a custom adjoint) in Flux.jl?
- How can I merge gradients appropriately, for example the gs_u_star_theta_i’s?
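For context, this is roughly the kind of custom rule I have in mind, as a sketch via `ChainRulesCore.rrule` (which Zygote, Flux's AD backend, respects). Here `solve_inner_min`, `dQ_dθ`, `dQ_du`, and `dustar_dθ` are hypothetical placeholders for the inner solver and the DiffOpt.jl sensitivities, not real API:

```julia
using ChainRulesCore  # Zygote picks up rrules defined this way

# minQ(θ) = min_u Q(x_next, u, θ); u* and ∂u*/∂θ come from an external
# solver (e.g. DiffOpt.jl). All helper names below are placeholders.
function ChainRulesCore.rrule(::typeof(minQ), θ)
    u_star = solve_inner_min(θ)          # assumed: argmin_u Q(x_next, u, θ)
    y = Q(x_next, u_star, θ)
    function minQ_pullback(ȳ)
        gθ = dQ_dθ(x_next, u_star, θ)    # ∂Q/∂θ at fixed u*
        gu = dQ_du(x_next, u_star, θ)    # ∂Q/∂u at u*
        Ju = dustar_dθ(θ)                # ∂u*/∂θ, e.g. from DiffOpt.jl
        return NoTangent(), ȳ .* (gθ .+ gu .* Ju)
    end
    return y, minQ_pullback
end
```

Is this the right mechanism, and if so, how should the pieces (in particular the gs_u_star_theta_i’s) be assembled?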
My description may be unclear; please leave any comments or questions.