It mostly works now, although it converges a little more slowly than the Python implementation.
I changed this in the train_iters function:
loss, back = Flux.pullback(ps) do
    train(input, target, encoder, decoder)
end # do
grad = back(1f0)   # seed the pullback with 1 to get the gradients with respect to ps
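For context, here is a minimal sketch of how ps might be built and how the gradient from back could be applied inside train_iters; the ADAM optimiser and the learning rate are my assumptions, not code from the original:

ps  = Flux.params(encoder, decoder)       # implicit parameters of both networks
opt = ADAM(0.01)                          # hypothetical optimiser / learning rate

loss, back = Flux.pullback(ps) do
    train(input, target, encoder, decoder)
end # do
grad = back(1f0)
Flux.Optimise.update!(opt, ps, grad)      # apply the gradients to every parameter in ps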
And this is the updated loss function:
function train(input, target, encoder, decoder; max_length = MAX_LENGTH)::Float64
    # reset the encoder's recurrent state before a new sequence (Float32 to match Flux's default weight eltype)
    encoder[2].state = zeros(Float32, hidden_size, 1)
    target_length = length(target)
    range::UnitRange = 1:(input_lang.n_words - 1)
    loss::Float64 = 0.0
    # run the whole input sequence through the encoder
    for letter in input
        encoder(letter)
    end # for
    # start decoding from the SOS token, with the decoder's GRU seeded from the encoder's final state
    decoder_input::Int64 = SOS_token
    decoder[3].state = encoder[2].state
    use_teacher_forcing = rand() < teacher_forcing_ratio
    if use_teacher_forcing
        # Teacher forcing: feed the target token as the next decoder input
        for i in 1:target_length
            output = decoder(decoder_input)
            onehot_target = Flux.onehot(target[i], range)
            loss += Flux.logitcrossentropy(output, onehot_target)
            Zygote.ignore() do
                decoder_input = target[i]   # hide the input switch from the gradient computation
            end # do
        end # for
    else
        # No teacher forcing: feed the decoder's own best guess back in as the next input
        for i in 1:target_length
            output = decoder(decoder_input)
            topv, topi = findmax(output)
            onehot_target = Flux.onehot(target[i], range)
            loss += Flux.logitcrossentropy(output, onehot_target)
            Zygote.ignore() do
                decoder_input = topi
            end # do
            topi == EOS_token && break   # stop once the decoder predicts end-of-sentence
        end # for
    end # if/else
    return loss / target_length
end # train
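The encoder and decoder definitions aren't shown here, but the indexing above (encoder[2].state and decoder[3].state) implies Chains with the GRU layers in those positions. A hypothetical sketch, only to make that indexing concrete; the layer sizes, the Dense projection, and reusing the same 1:(input_lang.n_words - 1) label range are assumptions:

encoder = Chain(
    x -> Flux.onehot(x, 1:(input_lang.n_words - 1)),    # token index -> one-hot column
    GRU(input_lang.n_words - 1, hidden_size))            # encoder[2]: the recurrent state reset above

decoder = Chain(
    x -> Flux.onehot(x, 1:(input_lang.n_words - 1)),
    Dense(input_lang.n_words - 1, hidden_size, relu),    # embedding-like projection
    GRU(hidden_size, hidden_size),                        # decoder[3]: seeded with the encoder state
    Dense(hidden_size, input_lang.n_words - 1))           # logits over the vocabulary, matching range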