Flux new explicit API not work but old implicit API works for a simple RNN

I am trying to reproduce the tutorial A Basic RNN using Flux.jl v0.14.6. Using the old Flux API as in the tutorial, the model can be successfully trained. The code is

using Flux

num_samples = 1000
num_epochs = 50

function generate_data(num_samples)
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)

    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2, 1, (x -> x))

function eval_model(x)
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(x, y) = abs(sum(eval_model(x) .- y))

evalcb() = @show(sum(loss.(test_data, test_labels)))

ps = Flux.params(model)

opt = Flux.ADAM(0.1)

for epoch in 1:num_epochs
    Flux.train!(loss, ps, zip(train_data, train_labels), opt, cb = Flux.throttle(evalcb, 1))
end

However, refractor the above code to use the new explicit API, Zygote complains.

ERROR: LoadError: MethodError: no method matching +(::@NamedTuple{cell::@NamedTuple{σ::Nothing, Wi::Matrix{Float32}, Wh::Matrix{Float32}, b::Vector{Float32}, state0::Nothing}, state::Matrix{Float32}}, ::Base.RefValue{Any})

The code is as follows

using Flux
using Statistics

num_samples = 1000
num_epochs = 50

function generate_data(num_samples)
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)

    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2, 1, (x -> x))

function eval_model(model, x)
    # Comment following line to make it run.
    # However, in the Flux doc, the following line is required.
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(model, x, y) = abs(sum(eval_model(model, x) .- y))

opt_state = Flux.setup(Flux.ADAM(0.1), model)

for epoch in 1:num_epochs
    for (x, y) in zip(train_data, train_labels)
        train_loss, grads = Flux.withgradient(model) do m
            loss(m, x, y)
        end
        Flux.update!(opt_state, model, grads[1])
    end
    test_loss = mean(loss.(Ref(model), test_data, test_labels))
    println("Epoch $epoch, loss = $test_loss")
end

# Following codes also failed to run.
# for epoch in 1:num_epochs
#     Flux.train!(model, zip(train_data, train_labels), opt_state) do m, x, y
#         loss(m, x, y)
#     end
# end

When I comment the line Flux.reset!(model), the code can run. But this contradict the requirement of Flux.RNN docs which states that the model should be reset before computing loss.

Anyone know how to make the new Flux API work for RNNs?

1 Like