I am trying to write a custom training function instead of using `Flux.train!`, following the documentation https://fluxml.ai/Flux.jl/stable/training/training/#Model-parameters-1, but it fails with the error message “Only reference types can be differentiated with `Params`”.
Here is my code (adapted from the trebuchet example in the Flux model-zoo):
using Flux
using Zygote
using Statistics
using Random

function shoot(angle, weight)
    # (body of shoot not shown in this post)
end

model = Chain(Dense(1, 16, σ),
              Dense(16, 64, σ),
              Dense(64, 16, σ),
              Dense(16, 2)) |> f64

θ = params(model)

function loss(target)
    angle, weight = model([target])
    angle = σ(angle) * 90
    weight = weight + 200
    (shoot(angle, weight) - target)^2
end

DIST = (20, 100) # Minimum and maximum target distance
target() = (rand() * (DIST[2] - DIST[1]) + DIST[1])
meanloss() = mean(sqrt(loss(target())) for i = 1:100)

opt = ADAM()
dataset = (target() for i = 1:2000)

@time Flux.train!(loss, θ, dataset, opt,
                  cb = () -> println("meanloss = ", meanloss(),
                                     "; W1 = ", θ.order.data[1][1],
                                     "; b1 = ", θ.order.data[2][1]))
Up to this point everything works. Then I wrote a custom training function and ran the training:
function my_custom_train!(loss, ps, data, opt)
    local training_loss
    ps = Params(ps)
    for d in data
        gs = gradient(ps) do
            training_loss = loss(d...)
            return training_loss
        end
        println("training_loss = ", training_loss, "; gradient[1] = ", gs[1],
                "; W1 = ", ps.order.data[1][1], "; b1 = ", ps.order.data[2][1])
        Flux.update!(opt, ps, gs)
    end
end

@time my_custom_train!(loss, θ, dataset, opt)
I got the following error:
ERROR: Only reference types can be differentiated with `Params`.
[1] error(::String) at .\error.jl:33
[2] getindex(::Zygote.Grads, ::Int64) at C:\Users\maxiao\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:142
[3] my_custom_train!(::typeof(loss), ::Params, ::Base.Generator{UnitRange{Int64},var"#13#14"}, ::ADAM) at .\REPL[33]:21
[4] top-level scope at .\util.jl:175
Does anyone have a clue what is going on here? Thanks a lot!
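For reference, the second frame of the trace is `getindex(::Zygote.Grads, ::Int64)`, which suggests the error is raised by `gs[1]` inside the `println`, not by `gradient` itself: a `Zygote.Grads` object appears to be looked up by the parameter arrays themselves rather than by position. The sketch below uses plain Zygote (no Flux) and assumes a reasonably recent Zygote version; `W`, `b`, `x` and `toy_loss` are hypothetical stand-ins for a layer's weight, bias, input and loss.

```julia
using Zygote

# Hypothetical stand-ins for a Dense layer's weight and bias, plus an input.
W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]
x = [1.0, 1.0]
ps = Zygote.Params([W, b])

# Hypothetical zero-argument loss that closes over W and b (implicit-parameter style).
toy_loss() = sum(W * x .+ b)

gs = gradient(toy_loss, ps)

# Grads is keyed by the parameter arrays themselves:
gW = gs[W]   # gradient w.r.t. W; all ones here because x == [1, 1]
gb = gs[b]   # gradient w.r.t. b; also all ones

# "The first parameter" is the first array in ps, so look it up by identity:
g_first = gs[first(ps)]   # same entry as gs[W]

# Iterating ps pairs each parameter array with its gradient:
for p in ps
    println("size ", size(p), " => gradient ", gs[p])
end

# Indexing by position is rejected; gs[1] throws an error:
try
    gs[1]
catch err
    println("gs[1] failed: ", sprint(showerror, err))
end
```

In `my_custom_train!`, the analogous change would presumably be printing `gs[first(ps)]` (the gradient of the same array printed as W1, i.e. `ps.order.data[1]`) instead of `gs[1]`.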