I am writing a custom RNN layer based on the code in `recurrent.jl` in the Flux code base. However, when I try to train the model on the task from the character-RNN model zoo example (`char-rnn.jl`), I get an error saying `back! was already used`.
Any idea what might be going on?
Here’s the code:

```julia
using Flux
using Flux.Tracker: param, back!, grad
using Flux: onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
mutable struct Recur{T}
  cell::T
  init
  state
end
Recur(m, h = hidden(m)) = Recur(m, h, h)
function (m::Recur)(xs...)
  h, y = m.cell(m.state, xs...)
  m.state = h
  return y
end
Flux.@treelike Recur cell, init
reset!(m::Recur) = (m.state = m.init)
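# Note: this is a copy of Flux's own Recur wrapper; m.state holds whatever the
# cell returned last, so during training it is a tracked array that still
# references the computation graph of earlier calls.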
# INITIALISATION
glorot_uniform(dims...) = (rand(Float64, dims...) .- 0.5) .* sqrt(24.0/sum(dims))
glorot_normal(dims...) = randn(Float64, dims...) .* sqrt(2.0/sum(dims))
gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
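# e.g. gate(x, h, 2) selects rows h+1:2h of x, i.e. the slice for the second gate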
mutable struct custom_rnn_cell{A,V}
  Wi::A
  Wh::A
  b::V
  h::V
end
# As in Flux's GRUCell, the weights are wrapped in `param` so that they are
# tracked and collected by `params`.
custom_rnn_cell(in, out; init = glorot_normal) =
  custom_rnn_cell(param(init(out * 3, in)), param(init(out * 3, out)),
                  param(init(out * 3)), param(zeros(out)))
function (m::custom_rnn_cell)(h, x)
  b, o = m.b, size(h, 1)
  gx, gh = m.Wi*x, m.Wh*h
  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))           # reset gate
  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))           # update gate
  h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))  # candidate state
  h′ = (1 .- z).*h̃ .+ z.*h
  return h′, h′
end
hidden(m::custom_rnn_cell) = m.h
Flux.@treelike custom_rnn_cell
custom_rnn(a...; ka...) = Recur(custom_rnn_cell(a...; ka...))
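# Apart from the Float64 initialisation, the cell above is a copy of Flux's
# GRUCell, and custom_rnn mirrors the GRU(a...; ka...) constructor in recurrent.jl.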
### ---------------------------- CHARACTER RNN EXAMPLE FROM MODEL ZOO
cd(@__DIR__)
isfile("input.txt") ||
  download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
           "input.txt")
text = collect(String(read("input.txt")))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)
N = length(alphabet)
seqlen = 50
nbatch = 50
Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen))
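# As in char-rnn.jl: Xs is a collection of length-50 sequences of one-hot
# batches (N×50 matrices) and Ys is the same text shifted by one character.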
custom_RNN = custom_rnn(128, 128);
chained_model = Chain(
  Dense(N, 128),
  custom_RNN,
  Dense(128, N),
  softmax)
net_params = params(chained_model);
function loss(xs, ys)
  l = sum(crossentropy.(chained_model.(xs), ys))
  Flux.truncate!(chained_model)
  #Flux.truncate!(custom_RNN.cell)
  return l
end
opt = ADAM(0.01)
Flux.train!(loss, net_params, zip(Xs, Ys), opt)
# this works:
# l = loss(Xs[1], Ys[1]); back!(l)
# but run after the lines above, this fails with `back! was already used`:
# l = loss(Xs[2], Ys[2]); back!(l)
```
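For what it’s worth, here is a minimal sketch (no RNN at all) that seems to reproduce the same failure, assuming I understand Tracker’s `back!` correctly: building a second loss on top of a tracked intermediate that has already been backpropagated through errors in the same way, which makes me suspect the state stored in my `Recur` is what is being reused.

```julia
using Flux
using Flux.Tracker: param, back!

W = param(rand(2, 2))
state = W * ones(2)    # tracked intermediate, standing in for the stored RNN state

l1 = sum(state)        # first "loss" built on the stored state
back!(l1)              # works

l2 = sum(2 .* state)   # second "loss" reusing the same tracked node
back!(l2)              # ERROR: `back!` was already used
```

If that is the mechanism, my guess is that `Flux.truncate!(chained_model)` isn’t actually detaching the state of my custom `Recur` the way it does for `Flux.Recur` (perhaps because `truncate!` only recognises Flux’s own `Recur` type?), but I haven’t been able to confirm that.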