(cross-post from Zulip)
I’m a recent ML and Flux learner, currently trying to translate GitHub - karpathy/makemore: An autoregressive character-level language model for making more things to Flux. I finished translating the MLP and am now studying RNNs. I want to build one from scratch, using this torch code as a guideline:
class RNNCell(nn.Module):
    """
    the job of a 'Cell' is to:
    take input at current time step x_{t} and the hidden state at the
    previous time step h_{t-1} and return the resulting hidden state
    h_{t} at the current timestep
    """
    def __init__(self, config):
        super().__init__()
        self.xh_to_h = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)

    def forward(self, xt, hprev):
        xh = torch.cat([xt, hprev], dim=1)
        ht = F.tanh(self.xh_to_h(xh))
        return ht

class RNN(nn.Module):
    def __init__(self, config, cell_type):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.start = nn.Parameter(
            torch.zeros(1, config.n_embd2)
        )  # the starting hidden state
        self.wte = nn.Embedding(
            config.vocab_size, config.n_embd
        )  # token embeddings table
        if cell_type == "rnn":
            self.cell = RNNCell(config)
        elif cell_type == "gru":
            self.cell = GRUCell(config)
        self.lm_head = nn.Linear(config.n_embd2, self.vocab_size)

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()

        # embed all the integers up front and all at once for efficiency
        emb = self.wte(idx)  # (b, t, n_embd)

        # sequentially iterate over the inputs and update the RNN state each tick
        hprev = self.start.expand((b, -1))  # expand out the batch dimension
        hiddens = []
        for i in range(t):
            xt = emb[:, i, :]  # (b, n_embd)
            ht = self.cell(xt, hprev)  # (b, n_embd2)
            hprev = ht
            hiddens.append(ht)

        # decode the outputs
        hidden = torch.stack(hiddens, 1)  # (b, t, n_embd2)
        logits = self.lm_head(hidden)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )

        return logits, loss
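Before the translation itself, a quick note on layout: Flux keeps features first and batch last, the opposite of torch's (b, t, n_embd), so this is what I expect shape-wise at each step. A toy sketch with made-up sizes, not the real config:

using Flux

# made-up toy sizes, only to illustrate the axis order (features first, batch last)
vocab, n_embd, n_embd2, t, b = 10, 4, 8, 5, 3

wte     = Embedding(vocab => n_embd)
xh_to_h = Dense(n_embd + n_embd2 => n_embd2)

idx = rand(1:vocab, t, b)                # (t, b) integer indices
emb = wte(idx)                           # (n_embd, t, b); torch gives (b, t, n_embd)

hprev = zeros(Float32, n_embd2, b)       # starting hidden state, one column per batch item
xt    = emb[:, 1, :]                     # first time step: (n_embd, b)
ht    = tanh.(xh_to_h(vcat(xt, hprev)))  # h_t = tanh(W * [x_t; h_{t-1}] + bias), (n_embd2, b)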
The most important bit is the forward pass: in torch there are list appends and the index order is the opposite of Julia's, so I’m probably making a mistake there, since the model is not training at all and stays stuck at a loss of 0.3. Here is my translation of the relevant part to Flux; there must be some weird mistake in it that I cannot figure out.
function RNNCell(config)
    xh2h = Dense(config.nembedding + config.nembedding2, config.nembedding2)
    return RNNCell(config, xh2h)
end

function (m::RNNCell)(x, h)
    xh = vcat(x, h)
    ht = Flux.tanh.(m.xh2h(xh))
    return ht
end
Base.@kwdef struct RNN{T<:Integer}
    blocksize::T
    vocabsize::T
    start::Matrix{Float32}
    wte::Embedding
    cell::Union{RNNCell,GRUCell}
    lmhead::Dense
    config::Config
end

function RNN(config)
    emb = Embedding(config.vocabsize, config.nembedding)
    return RNN(
        config.blocksize,
        config.vocabsize,
        zeros(Float32, config.nembedding2, 1),
        emb,
        RNNCell(config),
        Dense(config.nembedding2, config.vocabsize),
        config,
    )
end
# restrict functor traversal (and therefore the trainable parameters) to the listed fields
Flux.@functor RNNCell (xh2h,)
Flux.@functor RNN (cell,)
# The problem is likely in this function,
# most probably I'm making a mistake using Zygote.Buffer
function (m::RNN)(index)
    emb = m.wte(index)  # (n_embd, t, b), opposite index order to torch
    t, b = size(index)
    # somebody will make this better (better to avoid copying)
    hprev = hcat([m.start for i = 1:b]...)  # (n_embd2, b), starting hidden state per batch item
    hiddens = Flux.Zygote.Buffer(emb, m.cell.config.nembedding2, t, b)
    for i in 1:t
        xt = emb[:, i, :]          # (n_embd, b)
        ht = m.cell(xt, hprev)     # (n_embd2, b)
        hprev = ht
        hiddens[:, i, :] .= hprev  # store the hidden state for this time step
    end
    hidden = copy(hiddens)  # (n_embd2, t, b)
    logits = m.lmhead(hidden)
    return logits
end
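The loss lives outside the model here (in the torch version it is inside forward). For completeness, it would be computed on the (vocabsize, t, b) logits along these lines; the indices and targets below are random placeholders rather than the real dataset, and the torch ignore_index=-1 masking is not handled:

using Flux

# `m` is an RNN built with the constructor above; sizes are made up
t, b = 16, 32
idx     = rand(1:m.vocabsize, t, b)   # placeholder input indices
targets = rand(1:m.vocabsize, t, b)   # placeholder next-character targets

logits = m(idx)                                    # (vocabsize, t, b)
y      = Flux.onehotbatch(targets, 1:m.vocabsize)  # one-hot, same shape as logits
loss   = Flux.logitcrossentropy(logits, y)         # mean cross-entropy over all t and b positions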
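And the optimisation step on top of this would be roughly the standard Flux explicit-gradient pattern (again a sketch: the optimiser, learning rate, and the `idx`/`y` batch from above are placeholders):

# one training step; only fields exposed via @functor get optimiser state
opt_state = Flux.setup(Adam(1f-3), m)
grads = Flux.gradient(m) do model
    Flux.logitcrossentropy(model(idx), y)
end
Flux.update!(opt_state, m, grads[1])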