Understanding GRU forward pass as implemented in Knet

Hello all,

I tried to reimplement the forward pass of a simple GRU layer so that I can evaluate the RNN in other languages after training it in Julia. So far I've had no luck, and I'm running out of ideas. Has anyone had a similar problem and could give me a hint? Or has someone even implemented this already?

I based my implementation on the documentation of Knet.RNN and Knet.rnnparam, but the results already differ after a single time step. Here is my minimal test code:

using Knet

# Define a recurrent layer in Knet
nX = 2 # Number of inputs
nH = 3 # Number of hidden states
knet_gru = RNN(nX, nH; rnnType = :gru, dataType = Float32) # recurrent layer (gru) implementation in Knet
rnn_params = rnnparams(knet_gru) # extract the params

# Define forward pass of gru again, using information of "@doc RNN" and "@doc rnnparam"
function my_gru(Wr, Wi, Wn, Rr, Ri, Rn, bWr, bWi, bWn, bRr, bRi, bRn, x, h_in)
    r = sigm.(Wr' * x .+ Rr' * h_in .+ bWr .+ bRr) # reset gate
    i = sigm.(Wi' * x .+ Ri' * h_in .+ bWi .+ bRi) # input gate
    n = tanh.(Wn' * x .+ r .* (Rn' * h_in .+ bRn) .+ bWn) # new gate
    h_out = (1 .- i) .* n .+ i .* h_in
    return h_out
end

my_gru(W, x, h) = my_gru(W..., x, h)

# Compare both
x = randn(Float32, nX)
h = randn(Float32, nH)

knet_gru.h = h #set starting state for Knet RNN
res1 = knet_gru(x)
res2 = my_gru(rnn_params, x, h)
print("Difference between Knet and my_gru: ")
println(res1-res2)
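For reference, these are the GRU equations from "@doc RNN" that my_gru is meant to reproduce, written with the same transposes as the code (the matrices returned by rnnparams are applied as W' and R'). Note that the reset gate r_t multiplies the whole recurrent term R_n' h + b_Rn, not the input term:

```latex
\begin{aligned}
r_t &= \sigma\!\left(W_r^\top x_t + R_r^\top h_{t-1} + b_{Wr} + b_{Rr}\right) \\
i_t &= \sigma\!\left(W_i^\top x_t + R_i^\top h_{t-1} + b_{Wi} + b_{Ri}\right) \\
n_t &= \tanh\!\left(W_n^\top x_t + r_t \odot \left(R_n^\top h_{t-1} + b_{Rn}\right) + b_{Wn}\right) \\
h_t &= (1 - i_t) \odot n_t + i_t \odot h_{t-1}
\end{aligned}
```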

Hello all,
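After digging into the Knet source code on GitHub, I found the missing piece: the hidden state h has to be a 3-D array (hidden size × batch size × number of layers, here nH × 1 × 1). If it is a plain vector, Knet silently sets it to nothing during evaluation, without any warning, so the layer starts from a zero state instead of the h you provided.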
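To make the expected layout concrete, here is a small Base-Julia sketch (no Knet needed) that converts between a plain state vector and the 3-D layout. The helper names to_knet_state and from_knet_state are hypothetical, not part of Knet, and the sketch assumes batch size 1 and a single layer, as in this example:

```julia
# Hypothetical helper (not part of Knet): reshape a plain state vector into the
# 3-D (hidden size, batch size, number of layers) layout, with batch = layers = 1.
to_knet_state(h::AbstractVector) = reshape(h, length(h), 1, 1)

# Hypothetical inverse: flatten a 3-D state back to a vector, e.g. to feed it
# into a hand-written step function that works on plain vectors.
from_knet_state(h::AbstractArray{T,3}) where {T} = vec(h)

nH = 3
h = randn(Float32, nH)
h3 = to_knet_state(h)
println(size(h3))                  # (3, 1, 1)
println(from_knet_state(h3) == h)  # true
```

Since reshape and vec share memory with the original array, the conversion costs nothing; it only changes how the same numbers are indexed.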
A working example is thus:

using Knet
Knet.gpu(false)

# Define a recurrent layer in Knet
nX = 2 # Number of inputs
nH = 3 # Number of hidden states
knet_gru = RNN(nX, nH; rnnType = :gru, dataType = Float32) # recurrent layer (gru) implementation in Knet
rnn_params = rnnparams(knet_gru) # extract the params

# Define forward pass of gru again, using information of "@doc RNN" and "@doc rnnparam"
function my_gru(Wr, Wi, Wn, Rr, Ri, Rn, bWr, bWi, bWn, bRr, bRi, bRn, x, h_in)
    h_prev = vec(h_in) # flatten the nH×1×1 Knet state to a vector
    r = sigm.(Wr' * x .+ Rr' * h_prev .+ bWr .+ bRr) # reset gate
    i = sigm.(Wi' * x .+ Ri' * h_prev .+ bWi .+ bRi) # input gate
    n = tanh.(Wn' * x .+ r .* (Rn' * h_prev .+ bRn) .+ bWn) # new gate
    h_out = (1 .- i) .* n .+ i .* h_prev
    return h_out
end

my_gru(W, x, h) = my_gru(W..., x, h)


# Compare both
x = randn(Float32, nX)
h = randn(Float32, nH, 1, 1) # 3-D initial state: (hidden size, batch size, number of layers)

knet_gru.h = h # set starting state for Knet RNN
res1 = knet_gru(x)
res2 = my_gru(rnn_params, x, h)
print("Difference between Knet and my_gru: ")
println(res1-res2)
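Since the goal was to evaluate the trained layer outside Knet, it may help to have a version of the step that uses only Base Julia (no sigm from Knet) and runs over a whole sequence, which is easier to port line by line to another language. This is a minimal sketch, assuming the twelve parameters come in the same rnnparams order used by my_gru above, with W matrices of size nX × nH and R matrices of size nH × nH (hence the transposes), and that the input sequence is an nX × T matrix with one column per time step. The names sigmoid, gru_step and gru_sequence are hypothetical:

```julia
# Logistic sigmoid, written out so the code does not depend on Knet's `sigm`.
sigmoid(z) = one(z) / (one(z) + exp(-z))

# One GRU step with the same equations as `my_gru` above.
function gru_step(Wr, Wi, Wn, Rr, Ri, Rn, bWr, bWi, bWn, bRr, bRi, bRn,
                  x::AbstractVector, h::AbstractVector)
    r = sigmoid.(Wr' * x .+ Rr' * h .+ bWr .+ bRr)      # reset gate
    i = sigmoid.(Wi' * x .+ Ri' * h .+ bWi .+ bRi)      # input gate
    n = tanh.(Wn' * x .+ r .* (Rn' * h .+ bRn) .+ bWn)  # new gate
    return (1 .- i) .* n .+ i .* h
end

# Apply the step to every column of `xs` (nX × T); returns the nH × T hidden states.
function gru_sequence(params, xs::AbstractMatrix, h0::AbstractVector)
    h = h0
    hs = similar(h0, length(h0), size(xs, 2))
    for t in axes(xs, 2)
        h = gru_step(params..., view(xs, :, t), h)
        hs[:, t] = h
    end
    return hs
end

# Sanity check with all weights and biases set to zero: then r = i = 0.5 and
# n = tanh(0) = 0, so every step simply halves the hidden state.
nX, nH = 2, 3
zW = zeros(Float32, nX, nH); zR = zeros(Float32, nH, nH); zb = zeros(Float32, nH)
params = (zW, zW, zW, zR, zR, zR, zb, zb, zb, zb, zb, zb)
hs = gru_sequence(params, randn(Float32, nX, 4), ones(Float32, nH))
println(hs)  # columns are 0.5, 0.25, 0.125, 0.0625
```

The all-zero check is deliberately trivial; for a real comparison you would pass the trained rnn_params instead and compare against Knet column by column, remembering that the Knet state must be reshaped to nH × 1 × 1.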