Question about writing a custom training function using Flux.jl

Hello,

I am trying to write a custom training function instead of using `Flux.train!`, following the documentation at https://fluxml.ai/Flux.jl/stable/training/training/#Model-parameters-1, but it fails with the error "Only reference types can be differentiated with Params".

Here is my code (adapted from the trebuchet example in the model-zoo):

```julia
using Flux
using Zygote
using Statistics
using Random

function shoot(angle, weight)
  angle/weight*10
end

Random.seed!(0)

model = Chain(Dense(1, 16, σ),
              Dense(16, 64, σ),
              Dense(64, 16, σ),
              Dense(16, 2)) |> f64

θ = params(model)

function loss(target)
    angle, weight = model([target])
    angle = σ(angle)*90
    weight = weight + 200
    (shoot(angle, weight) - target)^2
end

DIST = (20, 100)  # Maximum target distance

target() = (rand()*(DIST[2]-DIST[1])+DIST[1])

meanloss() = mean(sqrt(loss(target())) for i = 1:100)

opt = ADAM()
dataset = (target() for i = 1:2000)

@time Flux.train!(loss, θ, dataset, opt, cb = () -> println("meanloss = ",meanloss(),"; W1 = ",θ.order.data[1][1],"; b1 = ",θ.order.data[2][1]))
```

This works so far. Then I wrote a custom training function and ran the training:

```julia
function my_custom_train!(loss, ps, data, opt)
  local training_loss
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      return training_loss
    end
    println("training_loss = ",training_loss,"gradient[1] = ",gs[1],"; W1 = ",ps.order.data[1][1],"; b1 = ",ps.order.data[2][1])
    Flux.update!(opt, ps, gs)
  end
end

@time my_custom_train!(loss, θ, dataset, opt)
```

I got the following error:

```
ERROR: Only reference types can be differentiated with `Params`.
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] getindex(::Zygote.Grads, ::Int64) at C:\Users\maxiao\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:142
 [3] my_custom_train!(::typeof(loss), ::Params, ::Base.Generator{UnitRange{Int64},var"#13#14"}, ::ADAM) at .\REPL[33]:21
 [4] top-level scope at .\util.jl:175
```

Does anyone have a clue about this? Thanks a lot!

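I think the problem is in your logging statement: `gs[1]` asks for the gradient with respect to the integer `1`, which is not one of your parameters (and, being a plain bits value rather than a reference type, triggers exactly this error). You need to request the gradient with respect to a parameter, for example `gs[ps[1]]`.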
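To see why the integer index fails, here is a minimal sketch using only Zygote, following the implicit-`Params` pattern from the Zygote documentation (the arrays `W`, `b` and `x` are made up for illustration, and it assumes a Zygote version that still supports implicit `Params`). The `Grads` object returned by `gradient` is keyed by the parameter arrays themselves, so a gradient is looked up with the array, or with `ps[i]` when you want it by position:

```julia
using Zygote: Params, gradient

# Hypothetical parameters: plain arrays collected into a Params object.
W = [1.0 2.0; 3.0 4.0]
b = [0.5, 0.5]
x = [1.0, 1.0]
ps = Params([W, b])

gs = gradient(() -> sum(W * x .+ b), ps)

# Grads is keyed by the arrays (by identity), not by position.
println(gs[W])        # gradient w.r.t. W: every entry is 1.0
println(gs[ps[2]])    # ps[2] is b, so this is the gradient w.r.t. b

# gs[1] asks for the gradient of the integer 1, which is not a parameter;
# Zygote rejects bits values with the error from the question.
try
  gs[1]
catch err
  println(sprint(showerror, err))
end
```

`ps[i]` simply returns the i-th array in the order it was added to `Params`, which is why `gs[ps[1]]` and `gs[W]` refer to the same gradient.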

Thank you! Changing gs[1] to gs[ps[1]] works!
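Putting the fix into the loop from the question gives something like the sketch below. It is a self-contained illustration under stated assumptions, not the original code: the one-weight linear model, the target of `2 * target` in `loss`, and the learning rate `η` are hypothetical, and a plain gradient-descent step over `for p in ps` stands in for `Flux.update!(opt, ps, gs)` so that only Zygote is needed. With Flux you would keep `Flux.update!(opt, ps, gs)` as in the question.

```julia
using Zygote: Params, gradient
using Random

# Hypothetical tiny model standing in for the Flux Chain in the question.
Random.seed!(0)
W = randn(1, 1)
b = zeros(1)
predict(x) = W * [x] .+ b
loss(target) = (predict(target)[1] - 2 * target)^2

# Custom loop with the logging fixed. The plain gradient-descent step
# (learning rate η) is an assumption replacing Flux.update!(opt, ps, gs).
function my_custom_train!(loss, ps, data; η = 0.05)
  local training_loss
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      return training_loss
    end
    # Index the gradients by the parameter array, not by position.
    println("training_loss = ", training_loss,
            "; gradient[1] = ", gs[ps[1]][1], "; W1 = ", ps[1][1])
    for p in ps
      g = gs[p]
      if g !== nothing          # parameters not used in the loss get `nothing`
        p .-= η .* g
      end
    end
  end
  return training_loss
end

final_loss = my_custom_train!(loss, Params([W, b]), (1.0 for _ in 1:50))
```

Iterating `for p in ps` and looking up `gs[p]` is the general way to reach every parameter's gradient without relying on integer positions.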