I have a simple Flux neural network, e.g.
using Flux
nn = Chain(
    Dense(20, 20),
    Dense(20, 2)
)
I wish to compute the gradient of log(nn(input)) with respect to Flux.Params(nn). I can't simply use Flux.train! with a loss function, because I need to apply the update manually (in particular, I am following the RL algorithm REINFORCE).
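For context, the update REINFORCE calls for has the form
θ ← θ + α * G * ∇θ log πθ(a | s)
where α is a step size and G is the sampled return, which is why I need direct access to the gradient of the log-output rather than a loss handed to Flux.train!.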
So I try to compute the gradient
using Zygote
function delta(input, i)
    gradient(() -> -log(nn(input)[i]), Flux.Params(nn))
end
first_grad = delta(rand(20), 1)
So far so good: I can see that first_grad is of type Zygote.Grads.
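From the Zygote docs my understanding is that a Grads maps each parameter array to its gradient, so I would expect to be able to read individual entries like this (just a sketch; given the errors below I am not sure my Grads actually has the weight arrays as keys):
W1 = first(Flux.params(nn))    # weight matrix of the first Dense layer
dW1 = first_grad[W1]           # should be an array of the same size as W1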
But now, how do I update the params in the direction of the gradient?
I tried
opt = ADAM()
Flux.update!(opt, Flux.Params(nn), first_grad)
and it gives the error
ERROR: KeyError: key Float32[0.08519302 0.011779553 … -0.004205899 0.085937634; 0.0063124006 0.1696027 … 0.09743015 -0.33658051; … ; -0.18922885 -0.22327383 … -0.012914175 -0.27949685; -0.26934287 0.18481864 … 0.2515226 -0.09317841] not found
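Reading that error, it seems the weight arrays I pass to update! are not keys of the dictionary inside first_grad. My guess (and it is only a guess) is that I should build the parameter collection once with the lowercase Flux.params and reuse exactly that object for both the gradient call and the update, roughly like this:
x = rand(20)                              # dummy input, as above
ps = Flux.params(nn)                      # built once and reused
gs = gradient(() -> -log(nn(x)[1]), ps)   # same loss as in delta; assumes nn(x)[1] > 0
opt = ADAM()
Flux.update!(opt, ps, gs)                 # same ps object that produced gs
but I don't know whether that is the intended pattern.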
I also tried
p = Flux.params(nn)
p .+= 0.001 * first_grad
which gives the error
ERROR: MethodError: no method matching *(::Float64, ::Zygote.Grads)
Closest candidates are:
*(::Any, ::Any, ::Any, ::Any...) at operators.jl:538
*(::ChainRulesCore.One, ::Any) at C:\Users\RTX2080\.julia\packages\ChainRulesCore\j0yny\src\differential_arithmetic.jl:78
*(::Float64, ::Float64) at float.jl:405
...
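So apparently, at least on my versions, a Grads cannot be scaled and added as a single object. Presumably I could fall back to looping over the individual parameter arrays myself, along these lines (again assuming the arrays really are keys of first_grad, which the KeyError above makes me doubt):
for p in Flux.params(nn)
    p .+= 0.001 .* first_grad[p]    # plain gradient step, no optimiser state
end
But that bypasses ADAM entirely, so I would still like to know the intended way to combine a manually computed Grads with Flux.update!.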