Hello there!
I am currently trying to implement a word-to-word autoencoder with the Flux.jl package.
It has gone mostly well, but I'm running into some problems now and I hope someone might be able to nudge me in the right direction. Specifically, I'm having trouble with the gradient calculation and with implementing the decoder logic and the steps that happen between the encoder and the decoder.
Alright, here goes: I have an encoder and a decoder:
encoderRNN = Chain(Flux.Embedding(43, 256), GRU(256, 256))
decoderRNN = Chain(Flux.Embedding(43, 256), x -> relu.(x), GRU(256, 256), Dense(256, 43))
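For context, this is roughly how I feed a single token index (an Int in 1:43, with 43 being my vocabulary size) through the two Chains defined above; the 5 is just an arbitrary example:
h = encoderRNN(5)        # 256-element output vector; the GRU also updates its internal state
logits = decoderRNN(5)   # 43-element vector of logits over the vocabulary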
My problem is that I want to train these two together, but I can't get it to work, probably because I don't thoroughly understand how to implement the gradient calculation in my training loop.
Here's my outer training loop, where I tried to stick fairly closely to the custom training loop examples in the documentation (it's not an MWE because of the input data; I'm mostly just trying to show the structure of the code).
function trainIters(encoder, decoder, n_iters; learning_rate=0.01)
    optimizer = Descent(learning_rate)
    local loss
    ps = Flux.params(encoder, decoder)
    # training data is usually called from another function here
    training_pairs = [(input1, target1), (input2, target2), ...,]
    # each iteration corresponds to one pair of words in the training data
    for iter in 1:n_iters
        training_pair = training_pairs[iter]
        input = training_pair[1]
        target = training_pair[2]
        gs = gradient(ps) do
            loss = train(input, target, encoder, decoder)
            return loss
        end # do
        Flux.Optimise.update!(optimizer, ps, gs)
    end # for
end # trainIters
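For comparison, this is the kind of minimal pattern from the docs that I keep seeing (placeholder model and random data here, not my actual code):
using Flux

m = Chain(Dense(10, 5, relu), Dense(5, 2))
loss(x, y) = Flux.logitcrossentropy(m(x), y)
ps = Flux.params(m)
opt = Descent(0.01)
data = [(rand(Float32, 10), Flux.onehot(rand(1:2), 1:2)) for _ in 1:4]  # dummy data

for (x, y) in data
    gs = gradient(() -> loss(x, y), ps)
    Flux.Optimise.update!(opt, ps, gs)
end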
Now, in all the examples I've seen (like the sketch above), the loss is usually computed along the lines of loss(x, y) = SomeLossFunction(m(x), y). But since there are additional steps I have to take that are not represented in the two Chains (encoder, decoder), I pass more arguments to my train function, which looks like this (I promise this is the last piece of code):
function train(input, target, encoder, decoder; max_length = 25)::Float64
    target_length = length(target)
    outputs::Vector{Matrix{Float32}} = []
    targets::Vector{Int64} = []
    # encoder part; easy enough
    for letter in input
        encoder(letter)
    end # for
    # set input and hidden state of decoder
    decoder_input::Int64 = SOS_token
    decoder[3].state = encoder[2].state
    use_teacher_forcing = rand() < teacher_forcing_ratio
    if use_teacher_forcing
        # Teacher forcing: feed the target as the next input
        for i in 1:target_length
            output = decoder(decoder_input)
            Zygote.ignore() do
                push!(outputs, output[:, :])
                push!(targets, target[i])
                decoder_input = target[i] # teacher forcing
            end # do
        end # for
    else
        # No teacher forcing: use the decoder's own prediction as the next input
        for i in 1:target_length
            output = decoder(decoder_input)
            topv, topi = findmax(output)
            Zygote.ignore() do
                push!(outputs, output[:, :])
                push!(targets, target[i])
            end # do
            decoder_input = topi
            if topi == EOS_token
                Zygote.ignore() do
                    outputs = outputs[1:i]
                    targets = targets[1:i]
                end # do
                break
            end # if
        end # for
    end # if/else
    output_matrix::AbstractMatrix = hcat(outputs...)
    onehot_targets = Flux.onehotbatch(targets, 1:(input_lang.n_words - 1))
    loss::Float64 = Flux.logitcrossentropy(output_matrix, onehot_targets)
    return loss
end # train
As you can see, I need to modify and update the decoder's input and hidden state and switch between teacher forcing and no teacher forcing. You can also see that I've frantically tried to get the gradient to work with Zygote.ignore() blocks that are probably terribly wrong. My question is: can I (and if so, how) move the logic of how the data is handled between the encoder and the decoder directly into my model, or is there some way to get the gradient to work with the current structure? Something like the sketch below is what I have in mind for the first option.
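To make the first option more concrete, here is an untested sketch of what I imagine (Seq2Seq is just a name I made up, it only covers the teacher-forcing branch, I hard-coded 1:43 for the one-hot labels, and I have no idea whether Zygote accepts the state copy in there):
struct Seq2Seq
    encoder
    decoder
end
Flux.@functor Seq2Seq

function (m::Seq2Seq)(input, target)
    Flux.reset!(m.encoder)
    Flux.reset!(m.decoder)
    for letter in input                        # run the encoder over the input word
        m.encoder(letter)
    end
    m.decoder[3].state = m.encoder[2].state    # hand the encoder's hidden state to the decoder's GRU
    decoder_inputs = vcat(SOS_token, target[1:end-1])             # teacher forcing
    logits = reduce(hcat, [m.decoder(t) for t in decoder_inputs])
    return Flux.logitcrossentropy(logits, Flux.onehotbatch(target, 1:43))
end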
I know this is a long and maybe a bit convoluted post, so thank you to anyone taking the time to look at it!