Translating from torch to Flux

(cross-post from Zulip)
I’m a recent ML and Flux learner, currently trying to learn from and translate karpathy/makemore on GitHub (an autoregressive character-level language model for making more things) to Flux. I finished the translation of the MLP and am now studying RNNs. I want to build one from scratch, using this torch code as a guideline:

import torch
import torch.nn as nn
from torch.nn import functional as F

class RNNCell(nn.Module):
    """
    the job of a 'Cell' is to:
    take input at current time step x_{t} and the hidden state at the
    previous time step h_{t-1} and return the resulting hidden state
    h_{t} at the current timestep
    """

    def __init__(self, config):
        super().__init__()
        self.xh_to_h = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)

    def forward(self, xt, hprev):
        xh = torch.cat([xt, hprev], dim=1)
        ht = F.tanh(self.xh_to_h(xh))
        return ht


class RNN(nn.Module):
    def __init__(self, config, cell_type):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.start = nn.Parameter(
            torch.zeros(1, config.n_embd2)
        )  # the starting hidden state
        self.wte = nn.Embedding(
            config.vocab_size, config.n_embd
        )  # token embeddings table
        if cell_type == "rnn":
            self.cell = RNNCell(config)
        elif cell_type == "gru":
            self.cell = GRUCell(config)
        self.lm_head = nn.Linear(config.n_embd2, self.vocab_size)

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()

        # embed all the integers up front and all at once for efficiency
        emb = self.wte(idx)  # (b, t, n_embd)

        # sequentially iterate over the inputs and update the RNN state each tick
        hprev = self.start.expand((b, -1))  # expand out the batch dimension
        hiddens = []
        for i in range(t):
            xt = emb[:, i, :]  # (b, n_embd)
            ht = self.cell(xt, hprev)  # (b, n_embd2)
            hprev = ht
            hiddens.append(ht)

        # decode the outputs
        hidden = torch.stack(hiddens, 1)  # (b, t, n_embd2)
        logits = self.lm_head(hidden)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )

        return logits, loss

The most important bit is the forward pass: in torch there are list appends and the index order is opposite to Julia’s, so I’m probably making a mistake there, as the model is not training at all and is stuck at 0.3 loss. Here is my translation of the relevant part to Flux; there must be some weird mistake that I cannot figure out.

function RNNCell(config)
    xh2h = Dense(config.nembedding + config.nembedding2, config.nembedding2)
    return RNNCell(config, xh2h)
end

function (m::RNNCell)(x, h)
    xh = vcat(x, h)
    ht = Flux.tanh.(m.xh2h(xh))
    return ht
end

Base.@kwdef struct RNN{T<:Integer}
    blocksize::T
    vocabsize::T
    start::Matrix{Float32}
    wte::Embedding
    cell::Union{RNNCell,GRUCell}
    lmhead::Dense
    config::Config
end

function RNN(config)
    emb = Embedding(config.vocabsize, config.nembedding)
    return RNN(
        config.blocksize,
        config.vocabsize,
        zeros(Float32, config.nembedding2, 1),
        emb,
        RNNCell(config),
        Dense(config.nembedding2, config.vocabsize),
        config
    )
end

Flux.@functor RNNCell (xh2h,)
Flux.@functor RNN (cell,)

# The problem is likely in this function
# most probably I'm making a mistake using Zygote.Buffer

function (m::RNN)(index)
    emb = m.wte(index) # (n_emb, t, b) opposite to torch
    t, b = size(index)
    # Someone can probably make this better (ideally avoiding the copies)
    hprev = hcat([m.start for i = 1:b]...)
    hiddens = Flux.Zygote.Buffer(emb, m.cell.config.nembedding2, t, b)
    for i in 1:t
        xt = emb[:, i, :]
        ht = m.cell(xt, hprev)
        hprev = ht
        hiddens[:, i, :] .= hprev
    end
    hidden = copy(hiddens) # (n_emb2, t, b)
    logits = m.lmhead(hidden)
    return logits
end
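
To make the “opposite index order” point concrete, this is roughly what I mean (illustrative only, the sizes are made up): Flux’s Embedding puts the feature dimension first, so a (t, b) matrix of token indices comes out as (n_emb, t, b), the reverse of torch’s (b, t) -> (b, t, n_embd).

using Flux

wte = Embedding(27, 8)   # same constructor style as above; made-up 27-token vocab, 8-dim embeddings
idx = rand(1:27, 5, 4)   # (t, b) = (5, 4)
size(wte(idx))           # (8, 5, 4), i.e. (n_emb, t, b)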

Quick thing to try: change hiddens[:, i, :] .= hprev to hiddens[:, i, :] = hprev.
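
For context on why that matters: my understanding is that Zygote only differentiates through a Buffer via plain getindex/setindex! and the final copy, so indexed assignment with = is the tracked path. A tiny sketch of that pattern in isolation (untested, just to show the shape of it):

using Flux   # Flux re-exports gradient; Buffer is reached the same way as in your code

function double_columns(x)
    buf = Flux.Zygote.Buffer(x)      # buffer with the same size and eltype as x
    for j in axes(x, 2)
        buf[:, j] = 2 .* x[:, j]     # plain `=` setindex!, which Zygote tracks
    end
    return sum(copy(buf))            # copy() freezes the buffer into an ordinary array
end

Flux.gradient(double_columns, ones(Float32, 3, 4))   # should come out as all 2s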

Less quick but still quick thing to try: append to the buffer instead of writing to it. This requires some rejigging of dimensions to better match Julia conventions.

function (m::RNN)(index) # index has size (b, t)
    emb = m.wte(index) # (n_emb, b, t) still opposite to torch?
    b, t = size(index)

    # shouldn't m.start be a vector?
    hprev = repeat(m.start, 1, b) # (x, y) -> (x, y * b)

    hiddens = Flux.Zygote.Buffer([])
    for i in 1:t
        xt = @view emb[:, :, i] # i-th time step; using a view is faster during inference
        ht = m.cell(xt, hprev)
        hprev = ht
        push!(hiddens, hprev)
    end

    # If you're using Julia <1.9, you can get stack() from the Compat.jl package
    hidden = stack(copy(hiddens)) # ((n_emb, b), t) -> (n_emb, b, t)
    logits = m.lmhead(hidden)
    return logits
end

(Note: untested)
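
A rough smoke test for the shapes, in case it’s useful (also untested; config here is whatever Config instance you already build elsewhere, and the batch of 4 is made up):

m = RNN(config)
idx = rand(1:config.vocabsize, 4, config.blocksize)   # (b, t), matching the comment above
logits = m(idx)
size(logits)   # expect (vocabsize, 4, blocksize), i.e. (vocab, b, t)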


Thanks, I actually resolved this on my own eventually. For those who might be wondering: changing hiddens[:, i, :] .= hprev to hiddens[:, i, :] = hprev does need to be done just to get it working, as that is the only form Flux supports. I had tried both ways before, but for some reason posted the wrong initial code here.
The real culprit was a rookie mistake of not declaring the RNN fields as “learnable”:
Flux.@functor RNN (wte, cell, lmhead,) should be used instead of Flux.@functor RNN (cell,), otherwise the embedding table and the output head never receive gradients.
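
In case it helps anyone else hitting the same thing, a quick way to spot it (a sketch using the definitions above) is to count the parameter arrays Flux actually sees:

using Flux

model = RNN(config)
# With Flux.@functor RNN (cell,) only the cell's Dense is trainable, so this
# should report 2 arrays (its weight and bias). With (wte, cell, lmhead,) it
# should report 5, because the embedding table and the output head are included too.
length(Flux.params(model))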
