I am trying to build an RNN with a self-attention mechanism using Knet (maybe I should have stuck to Python …), and after fixing all sorts of puzzling errors I am now faced with the following one:
AssertionError: r.c == nothing || (r.c == 0 || vec(value(r.c)) isa WTYPE && (ndims(r.c) <= 3 && (size(r.c, 1), size(r.c, 2), size(r.c, 3)) == HSIZE))
Here is the code:
using Knet, Random, Statistics, JLD # load() reads the .jld data file
EPOCHS=10 # Number of training epochs
BATCHSIZE=5 # Number of instances in a minibatch
HIDDENRNN=30 # Hidden layer size
HIDDENATT=20
ATTENTIONHIDDEN=10
INPUTDIM = 3
MAXLEN=300 # maximum length of an input sequence; pad shorter sequences, truncate longer ones
VELSIZE=127 # number of output classes (velocity values)
DROPOUT=0.3 # Dropout rate
LR=0.001 # Learning rate
BETA_1=0.9 # Adam optimization parameter
BETA_2=0.999 # Adam optimization parameter
EPS=1e-08 # Adam optimization parameter
struct velocity_RNN
    rnn
    output
    pdrop
end
velocity_RNN(input::Int, hidden::Int, output::Int; pdrop = 0.2) = velocity_RNN(RNN(input, hidden, dataType = Float64), param(output, hidden), pdrop)
function (vel_rnn::velocity_RNN)(input)
    hidden = vel_rnn.rnn(input)
    hidden = dropout(hidden, vel_rnn.pdrop)
    return vel_rnn.output * mat(hidden, dims = 1)
end
mutable struct attention_layer
    w1
    w2
    b
    f
end
attention_layer(i::Int, h::Int, f=tanh) = attention_layer(param(h,i), param(1,h), param0(h), f)
# takes as input ALL the previous hidden states to allow batch computation.
function (attention::attention_layer)(current_h, previous_hs)
    # repeat the current hidden state so it becomes a 3-dimensional array that can be concatenated with the past values.
    repeated_h = [ch for ch in current_h[:,1,1], d in 1:1, l in 1:size(previous_hs)[3]]
    concatenated_hs = vcat(repeated_h, previous_hs)
    scores = attention.w2 * attention.f.(attention.w1 * mat(concatenated_hs) .+ attention.b) # might have to use mat() and/or transpose.
    weights = softmax(scores)
    h = [previous_hs[i,j,k] * weights[k] for i in 1:size(previous_hs)[1], j in 1:1, k in 1:length(weights)]
    context_vector = [sum(h[i,1,:]) for i in 1:size(h)[1], j in 1:1, k in 1:1]
    return weights, context_vector
end
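(Not part of the training script, just to make the intended shapes explicit; sizes and variable names are only for illustration:)
att_check  = attention_layer(2*HIDDENRNN, HIDDENATT)
h_current  = zeros(HIDDENRNN, 1, 1)   # current hidden state
h_previous = zeros(HIDDENRNN, 1, 5)   # five previous hidden states
w, ctx = att_check(h_current, h_previous)
# w should hold one attention weight per previous state,
# and ctx should be a single (HIDDENRNN, 1, 1) context vector.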
mutable struct attention_RNN
    seq_len
    rnn
    attention
end
attention_RNN(len::Int, input::Int, hidden_rnn::Int, hidden_attention::Int, output::Int) = attention_RNN(len, velocity_RNN(input, hidden_rnn, output), attention_layer(2*hidden_rnn, hidden_attention))
function (arnn::attention_RNN)(x)
    input = (x[:,:,i:i] for i in 1:size(x)[3])
    h = param(zeros(arnn.rnn.rnn.hiddenSize, 1, arnn.seq_len))
    outputs = param(zeros(size(arnn.rnn.output)[1], 1, arnn.seq_len))
    print(size(outputs))
    weights = []
    for (index, value) in enumerate(input)
        w, context = arnn.attention(h[:,:,index:index], h[:,:,1:index])
        arnn.rnn.rnn.c = context          # feed the attention context in as the cell state
        arnn.rnn.rnn.h = h[:,:,index:index]
        print(typeof(arnn.rnn(value)))
        outputs[:,:,index] .= arnn.rnn(value)
        if index < arnn.seq_len
            h[:,:,index+1:index+1] .= arnn.rnn.rnn.h
        else
            break
        end
        push!(weights, w)
    end
    return outputs, weights
end
function NLLloss(scores::Param{Array{Float64,3}}, y)
    keeped_elements = findall(x -> x != -100, vec(y)) # indices of non-padding targets (currently unused)
    expscores = exp.(scores)
    probabilities = expscores ./ sum(expscores, dims=1)
    answerprobs = (probabilities[y[i], 1, i] for i in 1:length(y) if y[i] != -100)
    return mean(-log.(answerprobs))
end
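(Only to illustrate the convention: -100 marks padded positions and should be ignored by the loss. Made-up numbers:)
scores_check = Param(rand(VELSIZE, 1, 4))        # made-up scores for 4 time steps
labels_check = reshape([3, 7, -100, 12], (4, 1)) # the third step is padding
NLLloss(scores_check, labels_check)              # averages -log(p) over the 3 non-padding steps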
# functions that define the behavior of the models when both input and ground truth are given.
(vel_rnn::velocity_RNN)(x, y) = NLLloss(vel_rnn(x), y)
(arnn::attention_RNN)(x, y) = NLLloss(arnn(x)[1],y)
x, y = load("..\\src\\RNN_data.jld")
x_trn, y_trn = Array{Float64,3}[], Array{Int64,2}[]
for (index, data) in enumerate(x[2])
    push!(x_trn, reshape(data, (size(data)[1], 1, size(data)[2])))
end
for (index, data) in enumerate(y[2])
    push!(y_trn, reshape(data, (size(data)[1], 1)))
end
#defining custom minibatcher
idxs = collect(1:length(x_trn)) #the index to be shuffled for minibatching
coribatch() = ((x_trn[i], y_trn[i]) for i in Random.shuffle(idxs)[1:BATCHSIZE])
# the functions defining the behavior of the models when presented with a custom batch
(vel_rnn::velocity_RNN)(batch, isbatch::Bool) = mean(vel_rnn(x,y) for (x,y) in batch)
(arnn::attention_RNN)(batch, isbatch::Bool) = mean(arnn(x,y) for (x,y) in batch)
function sgdupdate!(func, args; lr=0.1)
    fval = @diff func(args...)
    for param in params(fval)
        ∇param = grad(fval, param)
        param .-= lr * ∇param
    end
end
x,y = first(coribatch())
arnn = attention_RNN(MAXLEN, INPUTDIM, HIDDENRNN, HIDDENATT, VELSIZE)
sgdupdate!(arnn,(x,y))
The error stacktrace reads:
Stacktrace:
[1] #call#600(::Nothing, ::RNN, ::Array{Float64,3}) at C:\Users\cnelias\.julia\packages\Knet\LjPts\src\rnn.jl:196
[2] (::RNN)(::Array{Float64,3}) at C:\Users\cnelias\.julia\packages\Knet\LjPts\src\rnn.jl:190
[3] (::velocity_RNN)(::Array{Float64,3}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:48
[4] (::attention_RNN)(::Array{Float64,3}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:97
[5] (::attention_RNN)(::Array{Float64,3}, ::Array{Int64,2}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:127
[6] (::getfield(Main, Symbol("##1971#1972")){attention_RNN,Tuple{Array{Float64,3},Array{Int64,2}}})() at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:205
[7] #differentiate#3(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(AutoGrad.differentiate), ::Function) at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:144
[8] differentiate at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:135 [inlined]
[9] #sgdupdate!#1970(::Float64, ::typeof(sgdupdate!), ::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
[10] sgdupdate!(::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
[11] top-level scope at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:162
[12] include_string(::Module, ::String, ::String) at .\loading.jl:1064
[13] (::getfield(Atom, Symbol("##139#144")){String,String,Module})() at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:138
[14] withpath(::getfield(Atom, Symbol("##139#144")){String,String,Module}, ::String) at C:\Users\cnelias\.julia\packages\CodeTools\sf1Tz\src\utils.jl:30
[15] withpath at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:47 [inlined]
[16] #138 at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:135 [inlined]
[17] with_logstate(::getfield(Atom, Symbol("##138#143")){String,String,Module}, ::Base.CoreLogging.LogState) at .\logging.jl:395
[18] with_logger at .\logging.jl:491 [inlined]
[19] #137 at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:134 [inlined]
[20] hideprompt(::getfield(Atom, Symbol("##137#142")){String,String,Module}) at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\repl.jl:85
[21] macro expansion at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:133 [inlined]
[22] macro expansion at C:\Users\cnelias\.julia\packages\Media\ItEPc\src\dynamic.jl:24 [inlined]
[23] (::getfield(Atom, Symbol("##136#141")))(::Dict{String,Any}) at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:122
[24] handlemsg(::Dict{String,Any}, ::Dict{String,Any}) at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\comm.jl:164
[25] (::getfield(Atom, Symbol("##19#21")){Array{Any,1}})() at .\task.jl:268
ERROR: LoadError: AssertionError: r.c == nothing || (r.c == 0 || vec(value(r.c)) isa WTYPE && (ndims(r.c) <= 3 && (size(r.c, 1), size(r.c, 2), size(r.c, 3)) == HSIZE))
Stacktrace:
[1] #differentiate#3(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(AutoGrad.differentiate), ::Function) at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:148
[2] differentiate at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:135 [inlined]
[3] #sgdupdate!#1970(::Float64, ::typeof(sgdupdate!), ::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
[4] sgdupdate!(::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
[5] top-level scope at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:162
in expression starting at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:162
I think this is because the context vector provided to the RNN has to be of the same type and dimensions as the internal parameters of the RNN (otherwise it fails the WTYPE check), but I have already checked and all the dimensions should be correct, so I don't understand what is going on. Moreover, the forward pass (without @diff) works perfectly.
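For reference, this is roughly how I checked the dimensions by hand, outside of @diff (if I read Knet's rnn.jl correctly, the assertion requires the cell state to be a 3-D array of size (hiddenSize, batchSize, numLayers) whose vec has the same array type as the RNN weights; the snippet below is only illustrative):
arnn_check = attention_RNN(MAXLEN, INPUTDIM, HIDDENRNN, HIDDENATT, VELSIZE)
h0 = param(zeros(HIDDENRNN, 1, MAXLEN))
w, ctx = arnn_check.attention(h0[:,:,1:1], h0[:,:,1:1]) # first step: a single previous state
@show typeof(ctx) size(ctx)           # size should be (HIDDENRNN, 1, 1), i.e. (hiddenSize, batchSize, numLayers)
@show arnn_check.rnn.rnn.hiddenSize   # this is what size(ctx, 1) has to match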
Does anyone have any idea what is going on here? I have been bashing my head against this for so long that I am thinking of completely quitting machine learning in Julia …