I am building a neural net and want to be able to differentiate the network with respect to the model input. I would then like to compute the gradient (with respect to the network weights and biases) of this derivative. I can get the input derivatives by doing the following:
using Flux
using ForwardDiff
u(x, W, b) = sum(σ.(W*x + b))
# ∂ₓu: first component of the input gradient, ∂u/∂x₁
uₓ(x, W, b) = ForwardDiff.gradient(z -> u(z, W, b), x)[1]
# ∂ₓₓu: second derivative w.r.t. the first input component, ∂²u/∂x₁²
uₓₓ(x, W, b) = ForwardDiff.gradient(z -> uₓ(z, W, b), x)[1]
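Evaluating these at some sample values works as expected (just a quick check; the test names and random values here are only for illustration):

# quick check that the nested ForwardDiff calls evaluate to scalars
Wtest, btest, xtest = rand(1, 2), rand(1), rand(2)
uₓ(xtest, Wtest, btest)    # ∂u/∂x₁ at xtest
uₓₓ(xtest, Wtest, btest)   # ∂²u/∂x₁² at xtest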
Now I want to find the gradient w.r.t. the network parameters. I can do this if I use ForwardDiff:
W₀ = rand(1,2)
b₀ = rand(1)
x₀ = rand(2)
∇W₀ = ForwardDiff.gradient(W₀) do W   # differentiate uₓₓ w.r.t. the weights
    uₓₓ(x₀, W, b₀)
end
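For reference, a crude finite-difference check on one weight entry agrees with this (just a sketch; the perturbation size, tolerance, and the names ϵ, Wp, fd are arbitrary choices for illustration):

# finite-difference check of ∂(uₓₓ)/∂W[1,1]
ϵ = 1e-6
Wp = copy(W₀); Wp[1, 1] += ϵ
fd = (uₓₓ(x₀, Wp, b₀) - uₓₓ(x₀, W₀, b₀)) / ϵ
isapprox(fd, ∇W₀[1, 1]; rtol=1e-3)   # holds approximately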
For this computation, however, I would like to use reverse-mode AD. Using the Flux gradient function returns a grad object filled with "nothing" (I have put roughly what I ran at the bottom of this post).

I know that this question has surfaced in some form or another before, but I have not been able to figure out how to get it working. I am still very new to all of this, but the fact that I can compute the gradient with ForwardDiff makes me think it can be done with reverse-mode AD. I have seen suggestions to define a custom chain rule, but I am struggling to do so (I am not familiar with ChainRules.jl and am also unsure how to write an rrule that can deal with a ForwardDiff call). I know that @ChrisRackauckas suggested looking over the ZygoteRules given at
but my lack of experience is making it hard to translate this to ChainRules. Anyway, if anyone has any suggestions I would greatly appreciate it. Thanks in advance!
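For completeness, this is roughly the reverse-mode attempt that comes back with nothing (a sketch of what I ran, using the implicit-params style; the exact call and the names ps, gs may have differed slightly):

# reverse-mode attempt: Zygote (via Flux.gradient) over the nested ForwardDiff calls
ps = Flux.params(W₀, b₀)
gs = Flux.gradient(() -> uₓₓ(x₀, W₀, b₀), ps)
gs[W₀]   # nothing
gs[b₀]   # nothing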