Manually giving gradients in Flux.jl

I’m solving a Q-learning problem and trying to compute the gradient of the TD error. That is,

TD-error = Q^{\theta}(x, u) - (r(x, u) + \min_{v} Q^{\theta}(x^{+}, v))

where \theta = (\theta_1, ..., \theta_N) are the network parameters, and x, u, and x^{+} are the state, control, and successor state, respectively.

The difficulty comes from \min. That is,

\frac{d Q^{\theta}(x^{+}, u^{*}(x^{+}, \theta))}{d \theta_i} = \frac{\partial Q^{\theta}}{\partial \theta_i} + \frac{\partial Q^{\theta}}{\partial u} \frac{\partial u^{*}}{\partial \theta_i}
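For reference, here is a sketch of how the sensitivity \frac{\partial u^{*}}{\partial \theta_i} could be characterized. It assumes that u^{*} is an unconstrained interior minimizer and that Q^{\theta} is twice differentiable and strictly convex in u; none of this is guaranteed for a constrained problem.

```latex
% First-order optimality at the minimizer (assumption: unconstrained, interior)
\frac{\partial Q^{\theta}}{\partial u}\bigl(x^{+}, u^{*}(x^{+}, \theta)\bigr) = 0

% Differentiating this condition with respect to \theta_i
\frac{\partial^{2} Q^{\theta}}{\partial u \, \partial \theta_i}
  + \frac{\partial^{2} Q^{\theta}}{\partial u^{2}} \frac{\partial u^{*}}{\partial \theta_i} = 0

% Solving for the sensitivity (assumption: the Hessian in u is invertible)
\frac{\partial u^{*}}{\partial \theta_i}
  = - \left( \frac{\partial^{2} Q^{\theta}}{\partial u^{2}} \right)^{-1}
      \frac{\partial^{2} Q^{\theta}}{\partial u \, \partial \theta_i}
```

Under the same assumptions, the first line also means that the term \frac{\partial Q^{\theta}}{\partial u} \frac{\partial u^{*}}{\partial \theta_i} is zero at the minimizer. With active constraints on u this simplification may not apply as written.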

So far I have used Flux.jl to compute gradients automatically.
The first term, \frac{\partial Q^{\theta}}{\partial \theta_i}, is easy to obtain with

gs_theta = gradient(params(theta)) do
    Q(x_next, u_star, theta)
end

as shown in the Flux.jl manual.
The factor \frac{\partial Q^{\theta}}{\partial u} in the second term is also easy to obtain with

gs_u = gradient(params(u_star)) do
    Q(x_next, u_star, theta)
end

However, for the last factor, \frac{\partial u^{*}}{\partial \theta_i}, automatic differentiation through the \min does not seem to be supported by Flux.jl.
That factor may be obtainable with DiffOpt.jl; call it gs_u_star_theta_i. The resulting gradient would then be gs_theta + gs_u * [gs_u_star_theta_1, ..., gs_u_star_theta_N].
So I would like to tell Flux manually that “the gradient is given by the above equation”.

My questions are:

  1. How can I manually supply a gradient in Flux.jl?
  2. How can I merge gradients appropriately, for example the gs_u_star_theta_i's?

My description may be unclear, so please leave any comments :slight_smile:

Your θ is the network parameter but where is your nn?

Q(x, u, theta) is itself a network.
More precisely, Q is constructed as a function of nn(x, theta) and u in my case to make sure that Q is convex in u.

Flux relies on Zygote for autodiff. Zygote now uses ChainRules for the set of math rules that define the basic gradient transformations. The preferred way to define a custom gradient now is through ChainRulesCore’s rrule.

(also see Custom Adjoints · Zygote for the alternative legacy method and a bit more verbose explanation than what I just gave)
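A minimal sketch of what such an rrule could look like for the inner minimization. The toy critic `Q`, the closed-form `argmin_u`, and the wrapper `min_Q` are hypothetical stand-ins, not the asker's actual network or solver; in a real setting the sensitivity `du_dtheta` would have to come from something like DiffOpt.jl, and the partial derivatives could come from Zygote rather than being written by hand.

```julia
using ChainRulesCore
using Zygote

# Hypothetical toy critic, convex in u; stands in for the real Q(x, u, theta).
Q(theta, x, u) = 0.5 * (u - theta[1] * x)^2 + theta[2] * x^2

# Stand-in for an external solver returning argmin_u Q (closed form for this toy Q).
argmin_u(theta, x) = theta[1] * x

# Value of the inner minimization; this is the function that gets a custom gradient.
min_Q(theta, x) = Q(theta, x, argmin_u(theta, x))

function ChainRulesCore.rrule(::typeof(min_Q), theta, x)
    u_star = argmin_u(theta, x)
    y = Q(theta, x, u_star)
    function min_Q_pullback(ȳ)
        # Partial derivatives of the toy Q at (theta, u_star), written by hand.
        dQ_dtheta = [-(u_star - theta[1] * x) * x, x^2]
        dQ_du = u_star - theta[1] * x
        # Sensitivity of the minimizer w.r.t. theta (assumed to be supplied externally).
        du_dtheta = [x, 0.0]
        # Merge the pieces: dQ/dtheta + dQ/du * du*/dtheta.
        dtheta = dQ_dtheta .+ dQ_du .* du_dtheta
        # One tangent per argument: the function itself, theta, and x.
        # x is treated as non-differentiable data in this sketch.
        return NoTangent(), ȳ .* dtheta, NoTangent()
    end
    return y, min_Q_pullback
end

# Zygote picks up the rrule instead of differentiating through argmin_u.
grad = Zygote.gradient(theta -> min_Q(theta, 2.0), [1.5, 0.5])[1]
```

For this toy Q the minimum value is theta[2] * x^2, so the gradient with respect to theta should be (0, x^2), which gives a quick sanity check on the rule. The merge step inside the pullback is the same combination as gs_theta + gs_u * [gs_u_star_theta_1, ..., gs_u_star_theta_N] from the question.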
