Custom RNN with Flux gives an error with Flux.train!

I am making a custom RNN layer based on the code in recurrent.jl in the Flux code base.
However, when I try to train the model on the task from the character RNN model zoo example (char-rnn.jl), I get an error saying `back! was already used`.
Any idea what might be going on?
Here’s the code:

using Flux
using Flux.Tracker: param, back!, grad
using Flux: onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition

mutable struct Recur{T}
  cell::T
  init
  state
end

Recur(m, h = hidden(m)) = Recur(m, h, h)

function (m::Recur)(xs...)
  h, y = m.cell(m.state, xs...)
  m.state = h
  return y
end

Flux.@treelike Recur cell, init

reset!(m::Recur) = (m.state = m.init)
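
# For reference, a quick check of the stateful wrapper with a made-up toy cell
# (`toy_cell` is illustrative only, not part of the model): the wrapper feeds its
# stored state back into the cell on every call and keeps the new state around.
#   toy_cell(h, x) = (h .+ x, h .+ x)      # a cell returns (new_state, output)
#   r = Recur(toy_cell, zeros(2), zeros(2))
#   r(ones(2)); r.state                    # -> [1.0, 1.0]
#   r(ones(2)); r.state                    # -> [2.0, 2.0]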

# INITIALISATION
glorot_uniform(dims...) = (rand(Float64, dims...) .- 0.5) .* sqrt(24.0/sum(dims))
glorot_normal(dims...) = randn(Float64, dims...) .* sqrt(2.0/sum(dims))
  
gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
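
# `gate` slices out the n-th block of `h` rows from the stacked gate parameters,
# e.g. gate(collect(1:9), 3, 2) == [4, 5, 6].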

mutable struct custom_rnn_cell{A,V}
  Wi::A
  Wh::A
  b::V
  h::V
end

custom_rnn_cell(in, out; init = glorot_normal) =
  custom_rnn_cell(param(init(out * 3, in)), param(init(out * 3, out)),
          param(init(out * 3)), param(zeros(out)))
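
# Wi is (3*out) × in and Wh is (3*out) × out: the three row blocks hold the reset,
# update and candidate weights, which `gate` picks out by block number below.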

function (m::custom_rnn_cell)(h, x)
  b, o = m.b, size(h, 1)
  gx, gh = m.Wi*x, m.Wh*h
  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
  h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
  h′ = (1 .- z).*h̃ .+ z.*h
  return h′, h′
end
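
# This is the usual GRU update h′ = (1 .- z) .* h̃ .+ z .* h. A shape sanity check
# on the cell in isolation (sizes are arbitrary, just for illustration):
#   c = custom_rnn_cell(10, 5)
#   h, y = c(zeros(5), rand(10))   # both come back as length-5 (tracked) vectors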

hidden(m::custom_rnn_cell) = m.h

Flux.@treelike custom_rnn_cell

custom_rnn(a...; ka...) = Recur(custom_rnn_cell(a...; ka...))



### ---------------------------- CHARACTER RNN EXAMPLE FROM MODEL ZOO

cd(@__DIR__)

isfile("input.txt") ||
  download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
           "input.txt")

text = collect(String(read("input.txt")))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)

N = length(alphabet)
seqlen = 50
nbatch = 50

Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen))
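
# Each element of Xs/Ys is one training example for `loss` below: a length-`seqlen`
# vector of one-hot batches, each an N × nbatch matrix (Ys is Xs shifted by one
# character), following the char-rnn preprocessing from the model zoo.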

custom_RNN = custom_rnn(128, 128);

chained_model = Chain(
  Dense(N, 128),
  custom_RNN,
  Dense(128, N),
  softmax)
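
# Dense(N, 128) embeds the one-hot character, the custom recurrent layer carries
# the hidden state across time steps, and Dense(128, N) followed by softmax gives
# the next-character distribution, mirroring the model zoo example.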

net_params = params(chained_model);


function loss(xs, ys)
  l = sum(crossentropy.(chained_model.(xs), ys))
  Flux.truncate!(chained_model)
  #Flux.truncate!(custom_RNN.cell)
  return l
end
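
# `Flux.truncate!` is (as far as I understand) meant to drop the hidden state's
# gradient history between batches so that back! only sees the current sequence;
# the commented-out line was an attempt to apply it to the custom cell directly.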


opt = ADAM(0.01)

Flux.train!(loss, net_params, zip(Xs, Ys), opt)

# this works:
# l = loss(Xs[1], Ys[1]); back!(l)
# but this fails with `back! was already used` after the line above:
# l = loss(Xs[2], Ys[2]); back!(l)