Question about writing a custom training function using Flux.jl


I am trying to write a custom training function instead of using `Flux.train!`, following the documentation, but it fails with the error message “Only reference types can be differentiated with Params”.

Here is my code (modified from the trebuchet example in the model-zoo):

```julia
using Flux
using Zygote
using Statistics
using Random

function shoot(angle, weight)
    # (body not shown in the original post)
end

model = Chain(Dense(1, 16, σ),
              Dense(16, 64, σ),
              Dense(64, 16, σ),
              Dense(16, 2)) |> f64

θ = params(model)

function loss(target)
    angle, weight = model([target])
    angle = σ(angle) * 90
    weight = weight + 200
    (shoot(angle, weight) - target)^2
end

DIST = (20, 100)    # Maximum target distance

target() = rand() * (DIST[2] - DIST[1]) + DIST[1]

meanloss() = mean(sqrt(loss(target())) for i = 1:100)

opt = ADAM()
dataset = (target() for i = 1:2000)

@time Flux.train!(loss, θ, dataset, opt,
                  cb = () -> println("meanloss = ", meanloss(), "; W1 = ", θ[1][1], "; b1 = ", θ[2][1]))
```

This works. Then I wrote a custom training function and ran the training:

```julia
function my_custom_train!(loss, ps, data, opt)
    local training_loss
    ps = Params(ps)
    for d in data
        gs = gradient(ps) do
            training_loss = loss(d...)
            return training_loss
        end
        println("training_loss = ", training_loss, "; gradient[1] = ", gs[1], "; W1 = ", ps[1][1], "; b1 = ", ps[2][1])
        Flux.update!(opt, ps, gs)
    end
end

@time my_custom_train!(loss, θ, dataset, opt)
```

I got the following error:

```
ERROR: Only reference types can be differentiated with `Params`.
 [1] error(::String) at .\error.jl:33
 [2] getindex(::Zygote.Grads, ::Int64) at C:\Users\maxiao\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:142
 [3] my_custom_train!(::typeof(loss), ::Params, ::Base.Generator{UnitRange{Int64},var"#13#14"}, ::ADAM) at .\REPL[33]:21
 [4] top-level scope at .\util.jl:175
```

Does anyone have a clue about this? Thanks a lot!

I think the problem is in your logging statement: `gs[1]` asks for the gradient with respect to the number `1`, which is not one of the parameters. You need to request the gradient with respect to a parameter, for example `gs[ps[1]]`.
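To see why the index matters, here is a minimal sketch using Zygote alone; the arrays `W`, `b`, `x` and the toy loss are hypothetical, not taken from the original post. `gradient(f, ps)` returns a `Zygote.Grads` object keyed by the parameter arrays themselves, so it has to be indexed with an array that is in `ps`. An integer such as `1` is a plain bits value rather than a reference, so Zygote rejects it with the same error as in the stack trace.

```julia
using Zygote  # assumes Zygote.jl is installed; uses its implicit `Params` API

W = [1.0 2.0 3.0; 4.0 5.0 6.0]   # hypothetical 2×3 weight matrix
b = [0.5, -0.5]                   # hypothetical bias vector
x = [1.0, 1.0, 2.0]               # hypothetical input

ps = Params([W, b])               # tracks W and b by identity (reference)

gs = gradient(ps) do
    sum(W * x .+ b)               # toy scalar loss
end

# Index the gradients by the parameter array itself...
gW = gs[W]        # same as gs[ps[1]], because ps[1] === W
gb = gs[b]        # same as gs[ps[2]]

# ...not by position: an Int is not a reference type, so this throws
# "Only reference types can be differentiated with `Params`."
err = try
    gs[1]
    nothing
catch e
    e
end
```

For this loss the gradient with respect to `W` has every row equal to `x'`, and the gradient with respect to `b` is a vector of ones. With a Flux model and `θ = params(model)`, the same rule applies: `gs[θ[1]]` is the gradient of the first parameter array.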

Thank you! Changing gs[1] to gs[ps[1]] works!
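For reference, here is a sketch of the custom loop with the logging keyed by parameters. To keep it self-contained it assumes only Zygote and replaces `Flux.update!(opt, ps, gs)` with a plain gradient-descent step; the linear model, `predict`, the learning rate `η`, the `verbose` keyword and the toy data are all hypothetical illustrations, not part of the original code.

```julia
using Zygote  # assumes Zygote.jl; the manual SGD step stands in for Flux.update!

# Hypothetical one-feature linear model y ≈ W*x + b, fitted to y = 2x + 1.
W = zeros(1, 1)
b = zeros(1)
predict(x) = W * [x] .+ b
loss(x, y) = sum(abs2, predict(x) .- y)

function my_custom_train!(loss, ps, data, η; verbose = false)
    local training_loss
    for d in data
        gs = gradient(ps) do
            training_loss = loss(d...)
            return training_loss
        end
        if verbose
            # Gradients are looked up by parameter (gs[ps[1]]), never by position (gs[1]).
            println("training_loss = ", training_loss,
                    "; gradient[1] = ", gs[ps[1]],
                    "; W1 = ", ps[1][1], "; b1 = ", ps[2][1])
        end
        for p in ps
            p .-= η .* gs[p]   # in-place gradient-descent step on each tracked array
        end
    end
    return training_loss
end

data = [(x, 2x + 1) for x in range(-1, 1; length = 21)]
ps = Params([W, b])
for epoch in 1:200
    my_custom_train!(loss, ps, data, 0.05)
end
```

Passing `verbose = true` prints the per-step log from the question. Because the update mutates `W` and `b` in place, the arrays tracked by `ps` stay the same objects, which is what lets `gs[p]` find their gradients on the next step.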