How to manually update the params in a `Flux.Chain` neural network?

I have a simple Flux neural network, e.g.

using Flux

nn = Chain(
   Dense(20, 20),
   Dense(20, 2)
)

I want to compute the gradient of log(nn(input)) with respect to Flux.Params(nn). I can't use Flux.train! with a loss function, because I need to compute the update manually (specifically, I am following the RL algorithm REINFORCE).

So I try to compute the gradient

using Zygote
function delta(input, i)
  gradient(() -> -log(nn(input)[i]), Flux.Params(nn))
end

first_grad = delta(rand(20), 1)

So far so good: I can see that first_grad is of type Zygote.Grads.

But now, how do I update the params in the direction of the gradient?

I tried

opt = ADAM()
Flux.update!(opt, Flux.Params(nn), first_grad )

and it gives the error

ERROR: KeyError: key Float32[0.08519302 0.011779553 … -0.004205899 0.085937634; 0.0063124006 0.1696027 … 0.09743015 -0.33658051; … ; -0.18922885 -0.22327383 … -0.012914175 -0.27949685; -0.26934287 0.18481864 … 0.2515226 -0.09317841] not found

I also tried

p = Flux.params(nn)
p .+= 0.001first_grad 

which gives the error

ERROR: MethodError: no method matching *(::Float64, ::Zygote.Grads)
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:538
  *(::ChainRulesCore.One, ::Any) at C:\Users\RTX2080\.julia\packages\ChainRulesCore\j0yny\src\differential_arithmetic.jl:78    
  *(::Float64, ::Float64) at float.jl:405

Optimizers like ADAM use an IdDict internally, keyed by the parameter arrays themselves, so you shouldn't create new parameters in each update. Just create them once and reuse the same object both when computing the gradient and when updating.
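As a minimal sketch of that advice (not necessarily the only way to do it), the following collects the parameters once and reuses the same object for the gradient and the update. It assumes a Flux release that still supports the implicit-parameter API (Flux.params, ADAM) that the question uses. The softmax layer and the names ps, loss and gs are additions for illustration only; softmax is there so that the output is positive and log is always defined.

```julia
using Flux  # assumption: a Flux version that still provides Flux.params and ADAM

nn = Chain(
    Dense(20, 20),
    Dense(20, 2),
    softmax,  # assumption, not in the original network: keeps outputs in (0, 1)
)

ps = Flux.params(nn)  # collect the trainable arrays once and keep this object
opt = ADAM()          # optimizer state is stored per parameter array

# Negative log-probability of output i (hypothetical helper name)
loss(input, i) = -log(nn(input)[i])

input = rand(Float32, 20)

# Gradients are looked up by the same arrays that are stored in ps
gs = gradient(() -> loss(input, 1), ps)

# Optimizer step, updating every array in ps in place
Flux.update!(opt, ps, gs)

# Manual alternative: a plain gradient-descent step, one array at a time
for p in ps
    p .-= 0.001f0 .* gs[p]
end
```

The manual loop indexes gs by each parameter array rather than multiplying the Zygote.Grads object by a scalar, which is the operation the MethodError in the question complains about.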