I am trying to reproduce the tutorial A Basic RNN using Flux.jl v0.14.6. Using the old Flux API as in the tutorial, the model can be successfully trained. The code is
using Flux
num_samples = 1000
num_epochs = 50
function generate_data(num_samples)
train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
train_labels = (v -> sum(v)).(train_data)
test_data = 2 .* train_data
test_labels = 2 .* train_labels
train_data, train_labels, test_data, test_labels
end
train_data, train_labels, test_data, test_labels = generate_data(num_samples)
model = Flux.RNN(2, 1, (x -> x))
function eval_model(x)
Flux.reset!(model)
out = [model(view(x, :, t)) for t in axes(x, 2)]
out[end] |> first
end
loss(x, y) = abs(sum(eval_model(x) .- y))
evalcb() = @show(sum(loss.(test_data, test_labels)))
ps = Flux.params(model)
opt = Flux.ADAM(0.1)
for epoch in 1:num_epochs
Flux.train!(loss, ps, zip(train_data, train_labels), opt, cb = Flux.throttle(evalcb, 1))
end
However, refractor the above code to use the new explicit API, Zygote complains.
ERROR: LoadError: MethodError: no method matching +(::@NamedTuple{cell::@NamedTuple{σ::Nothing, Wi::Matrix{Float32}, Wh::Matrix{Float32}, b::Vector{Float32}, state0::Nothing}, state::Matrix{Float32}}, ::Base.RefValue{Any})
The code is as follows
using Flux
using Statistics
num_samples = 1000
num_epochs = 50
function generate_data(num_samples)
train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
train_labels = (v -> sum(v)).(train_data)
test_data = 2 .* train_data
test_labels = 2 .* train_labels
train_data, train_labels, test_data, test_labels
end
train_data, train_labels, test_data, test_labels = generate_data(num_samples)
model = Flux.RNN(2, 1, (x -> x))
function eval_model(model, x)
# Comment following line to make it run.
# However, in the Flux doc, the following line is required.
Flux.reset!(model)
out = [model(view(x, :, t)) for t in axes(x, 2)]
out[end] |> first
end
loss(model, x, y) = abs(sum(eval_model(model, x) .- y))
opt_state = Flux.setup(Flux.ADAM(0.1), model)
for epoch in 1:num_epochs
for (x, y) in zip(train_data, train_labels)
train_loss, grads = Flux.withgradient(model) do m
loss(m, x, y)
end
Flux.update!(opt_state, model, grads[1])
end
test_loss = mean(loss.(Ref(model), test_data, test_labels))
println("Epoch $epoch, loss = $test_loss")
end
# Following codes also failed to run.
# for epoch in 1:num_epochs
# Flux.train!(model, zip(train_data, train_labels), opt_state) do m, x, y
# loss(m, x, y)
# end
# end
When I comment the line Flux.reset!(model)
, the code can run. But this contradict the requirement of Flux.RNN docs which states that the model should be reset before computing loss.
Anyone know how to make the new Flux API work for RNNs?