Problems regarding an implementation of a word2word Autoencoder

Hello there!

I am currently trying to implement a word2word autoencoder with the Flux.jl package.

It went mostly well, but I’m running into some problems now, and I hope someone might be able to nudge me in the right direction. Specifically, I have trouble with the gradient calculation and with implementing the logic of the decoder and the steps that happen between the encoder and decoder.

Alright, here goes. I have an Encoder and a Decoder:

# Encoder: embedding lookup (vocabulary of 43 tokens -> 256 dims) feeding a GRU
encoderRNN = Chain(Flux.Embedding(43, 256), GRU(256, 256))

# Decoder: embedding -> ReLU -> GRU -> Dense projection back onto the vocabulary
decoderRNN = Chain(Flux.Embedding(43, 256), x -> relu.(x), GRU(256, 256), Dense(256, 43))

My problem is that I want to train these together, but I can’t get it to work, probably because I don’t thoroughly understand how to implement the gradient calculation in my training loop.

Here’s my outer training loop, where I tried to stick to the custom training loop examples in the documentation (it’s not an MWE because of the input data; I’m mostly just trying to show the structure of the code).

function trainIters(encoder, decoder, n_iters; learning_rate=0.01)

    optimizer = Descent(learning_rate)  # plain SGD
    local loss
    # collect the parameters of both models so they are trained together
    ps = Flux.params(encoder, decoder)
    # training data is usually called from another function here
    training_pairs = [(input1, target1), (input2, target2), ...,]
    # each iteration corresponds to one pair of words in the training data
    for iter in 1:n_iters
        input, target = training_pairs[iter]

        gs = gradient(ps) do
            loss = train(input, target, encoder, decoder)
            return loss
        end # do

        Flux.Optimise.update!(optimizer, ps, gs)
    end # for
end # trainIters

Now, in all the examples I’ve seen, the loss is usually computed along the lines of loss(x, y) = SomeLossFunction(m(x), y). But since there are additional steps I have to take that are not represented in the two Chains (Encoder, Decoder), I pass more arguments to my train function.
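For contrast, here is the usual pattern as a minimal sketch (m, x, and y are placeholders here, not my actual code):

loss(x, y) = Flux.logitcrossentropy(m(x), y)
gs = gradient(() -> loss(x, y), Flux.params(m))

And here is my train function (I promise this is the last piece of code):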

function train(input, target, encoder, decoder; max_length = 25)::Float64

    target_length = length(target)
    outputs::Vector{Matrix{Float32}} = []
    targets::Vector{Int64} = []

    # encoder part; easy enough
    for letter in input
        encoder(letter)
    end # for

    # set input and hidden state of decoder
    decoder_input::Int64 = SOS_token
    decoder[3].state = encoder[2].state

    use_teacher_forcing = rand() < teacher_forcing_ratio

    if use_teacher_forcing
        # Teacher forcing: feed the target as the next input
        for i in 1:target_length
            output = decoder(decoder_input)
            Zygote.ignore() do
                push!(outputs, output[:, :])
                push!(targets, target[i])
                decoder_input = target[i]  # teacher forcing
            end # do
        end # for
    else
        # No teacher forcing: feed the decoder’s own prediction back in
        for i in 1:target_length
            output = decoder(decoder_input)
            topv, topi = findmax(output)
            Zygote.ignore() do
                push!(outputs, output[:, :])
                push!(targets, target[i])
            end # do
            decoder_input = topi
            if topi == EOS_token
                Zygote.ignore() do
                    outputs = outputs[1:i]
                    targets = targets[1:i]
                end # do
                break
            end # if
        end # for
    end # if/else

    output_matrix::AbstractMatrix = hcat(outputs...)
    onehot_targets = Flux.onehotbatch(targets, 1:(input_lang.n_words - 1))
    loss::Float64 = Flux.logitcrossentropy(output_matrix, onehot_targets)
    return loss
end # train
As you can see, I need to modify and update the decoder’s input and hidden state, and switch between teacher forcing and no teacher forcing. You can also see that I’ve frantically tried to get the gradient to work with Zygote.ignore() blocks that are probably terribly wrong. My question is: can I, and if so how, move the logic of how the data is handled between the encoder and decoder directly into my model? Or is there a way to get the gradient to work with the current structure?

I know this is a long and maybe a bit convoluted post, so thank you to anyone taking the time to look at this!

What is the error you get?

I’m not getting an error per se; the code finishes without an error message. But the loss isn’t converging during training, which I think stems from the gradient calculation not working correctly the way I set it up.

It mostly works now, even though it converges a little more slowly than the Python implementation.

Changed this in the trainIters function:

loss, back = Flux.pullback(ps) do
    train(input, target, encoder, decoder)
end # do

grad = back(1f0)
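As far as I understand, pullback followed by back(1f0) is exactly what gradient does internally, so the two forms should be equivalent; I think the actual fix was computing the loss inside the differentiated code instead of pushing outputs around inside Zygote.ignore. For reference, the equivalent gradient form (same ps and optimizer as before):

grad = gradient(ps) do
    train(input, target, encoder, decoder)
end # do

Flux.Optimise.update!(optimizer, ps, grad)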

And this is the updated loss function:

function train(input, target, encoder, decoder; max_length = MAX_LENGTH)::Float64

    # reset the encoder’s hidden state before each pair (Float32 to match the model parameters)
    encoder[2].state = zeros(Float32, hidden_size, 1)

    target_length = length(target)

    # renamed from `range` so it doesn’t shadow Base.range
    label_range::UnitRange = 1:(input_lang.n_words - 1)

    loss::Float64 = 0.0

    for letter in input
        encoder(letter)
    end # for

    # hand the encoder’s final hidden state over to the decoder
    decoder_input::Int64 = SOS_token
    decoder[3].state = encoder[2].state

    use_teacher_forcing = rand() < teacher_forcing_ratio

    if use_teacher_forcing
        # Teacher forcing: feed the target as the next input
        for i in 1:target_length
            output = decoder(decoder_input)
            onehot_target = Flux.onehot(target[i], label_range)
            loss += Flux.logitcrossentropy(output, onehot_target)
            Zygote.ignore() do
                decoder_input = target[i]
            end # do
        end # for
    else
        # No teacher forcing: feed the decoder’s own prediction back in
        for i in 1:target_length
            output = decoder(decoder_input)
            topv, topi = findmax(output)
            onehot_target = Flux.onehot(target[i], label_range)
            loss += Flux.logitcrossentropy(output, onehot_target)
            Zygote.ignore() do
                decoder_input = topi
            end # do
            topi == EOS_token && break
        end # for
    end # if/else

    return loss / target_length
end # train
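One side note: instead of zeroing the GRU state by hand at the top, I believe Flux’s built-in reset! does the same thing (a sketch, assuming the default initial state is zeros):

Flux.reset!(encoder)  # resets the recurrent state; should match the manual zeroing above
Flux.reset!(decoder)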

Great that you got it working. As for the convergence question, the default choice of hyperparameters is not necessarily consistent across Flux and PyTorch. Initialization is the big one that comes to mind, but optimizer parameters may be worth a look too.


Thank you for the follow-up. I have in the meantime updated the initialization as well, to mirror the Python code:

encoderRNN = Chain(Flux.Embedding(input_lang.n_words - 1, hidden_size), 
                   GRU(hidden_size, hidden_size, init=Flux.kaiming_uniform))

decoderRNN = Chain(Flux.Embedding(output_lang.n_words - 1, hidden_size), x -> relu.(x), 
                   GRU(hidden_size, hidden_size, init=Flux.kaiming_uniform), 
                   Dense(hidden_size, output_lang.n_words - 1, init=Flux.kaiming_uniform)
                   )

The Julia implementation is still converging a bit slower, but I’m mostly happy that it’s finally running.

The Python code has snippets like this…

decoder_input = topi.squeeze().detach()

… and I was thinking that maybe I need to exclude the corresponding lines in the Julia code from the gradient calculation like this:

Zygote.ignore() do
    decoder_input = target[i]
end
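(Zygote also has dropgrad, which looks like a one-line analogue of .detach(); though since decoder_input is just an Int index into the embedding, I’m not sure gradients flow through it in the first place:)

decoder_input = Zygote.dropgrad(target[i])  # sketch; should behave like the ignore block above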

But even then the convergence is still slower, so maybe it is like you said, and I actually have to take a closer look at the difference between Gradient() in Julia and optim.SGD() in Python.

By slower, do you mean number of steps taken or wall clock time? If the former, then it’s probably something algorithmic (e.g. initializations, which you’ve already checked). If it’s the latter, then I’d love to see some comparative timings and/or profiling info 🙂
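Even something quick would help, e.g. a sketch with BenchmarkTools (names taken from your snippets above):

using BenchmarkTools
@btime train($input, $target, $encoderRNN, $decoderRNN)  # time one loss evaluation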

One more note: optim.SGD == Flux.Optimise.Descent. If you’re using momentum in your PyTorch code, then you’ll want to use Momentum or Nesterov instead.
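For example (the learning rate and momentum here are hypothetical; use whatever your PyTorch script has):

# PyTorch: optim.SGD(params, lr=0.01)               corresponds to  Descent(0.01)
# PyTorch: optim.SGD(params, lr=0.01, momentum=0.9) corresponds to  Momentum(0.01, 0.9)
optimizer = Momentum(0.01, 0.9)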

Regarding the speed: both the number of steps and the wall clock time, actually, though I was mainly talking about the number of steps. When I get back to working on this problem sometime this week, I’ll try some timings.

About the second part: I apparently totally blanked in my last post. It shouldn’t say Gradient() but Descent(), which is what I’ve used. So I guess there is still some algorithmic difference left for me to find, since both the optimizer and the initialization are the same in both versions now.
