For Knet.jl connoisseurs: problem with a WTYPE error

I am trying to build an RNN with a self-attention mechanism using Knet (maybe I should have stuck to Python …), and after fixing all sorts of puzzling errors I am now faced with the following one:

AssertionError: r.c == nothing || (r.c == 0 || vec(value(r.c)) isa WTYPE && (ndims(r.c) <= 3 && (size(r.c, 1), size(r.c, 2), size(r.c, 3)) == HSIZE))

Here is the code:

EPOCHS=10          # Number of training epochs
BATCHSIZE=5      # Number of instances in a minibatch
HIDDENRNN=30     # Hidden layer size
HIDDENATT=20
ATTENTIONHIDDEN=10
INPUTDIM = 3
MAXLEN=300        # maximum length of the input sequence, pad shorter sequences, truncate longer ones
VELSIZE=127       # output size (number of possible velocity values)
DROPOUT=0.3       # Dropout rate
LR=0.001          # Learning rate
BETA_1=0.9        # Adam optimization parameter
BETA_2=0.999      # Adam optimization parameter
EPS=1e-08         # Adam optimization parameter

struct velocity_RNN
        rnn
        output
        pdrop
end
velocity_RNN(input::Int, hidden::Int, output::Int; pdrop = 0.2) = velocity_RNN(RNN(input, hidden, dataType = Float64), param(output, hidden), pdrop)

function (vel_rnn::velocity_RNN)(input)
        hidden = vel_rnn.rnn(input)
        hidden = dropout(hidden, vel_rnn.pdrop)
        return vel_rnn.output * mat(hidden, dims = 1)
end

mutable struct attention_layer
         w1
         w2
         b
         f
end
attention_layer(i::Int, h::Int, f=tanh) = attention_layer(param(h,i), param(1,h), param0(h), f)
#takes as input ALL the previous hidden states to allow batch computation.
function (attention::attention_layer)(current_h, previous_hs)
        #repeat the current hidden state along the third dimension so it can be concatenated with all the past hidden states
        repeated_h = [ch for ch in current_h[:,1,1], d in 1:1, l in 1:size(previous_hs)[3]]
        concatenated_hs = vcat(repeated_h, previous_hs)
        scores = attention.w2 * attention.f.(attention.w1 * mat(concatenated_hs) .+ attention.b) #might have to use "mat()" and/or transpose.
        weights = softmax(scores)
        #weight each previous hidden state by its attention weight, then sum to get the context vector
        h = [previous_hs[i,j,k] * weights[k] for i in 1:size(previous_hs)[1], j in 1:1, k in 1:length(weights)]
        context_vector = [sum(h[i,1,:]) for i in 1:size(h)[1], j in 1:1, k in 1:1]
        return weights, context_vector
end
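
#= Quick standalone shape check of the attention layer with dummy data (not my real
   input; the names below exist only for this check): with five past hidden states I
   expect (1, 5) attention weights and a (HIDDENRNN, 1, 1) context vector. =#
att_check = attention_layer(2*HIDDENRNN, HIDDENATT)
h_cur  = rand(HIDDENRNN, 1, 1)     # dummy current hidden state
h_past = rand(HIDDENRNN, 1, 5)     # five dummy previous hidden states
w_chk, c_chk = att_check(h_cur, h_past)
@show size(w_chk) size(c_chk)      # gives (1, 5) and (HIDDENRNN, 1, 1) here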

mutable struct attention_RNN
        seq_len
        rnn
        attention
end
attention_RNN(len::Int, input::Int, hidden_rnn::Int, hidden_attention::Int, output::Int) = attention_RNN(len, velocity_RNN(input, hidden_rnn, output), attention_layer(2*hidden_rnn, hidden_attention))

function (arnn::attention_RNN)(x)
        input = (x[:,:,i:i] for i in 1:size(x)[3])
        h = param(zeros(arnn.rnn.rnn.hiddenSize, 1, arnn.seq_len))
        outputs = param(zeros(size(arnn.rnn.output)[1], 1, arnn.seq_len))
        print(size(outputs))
        weights = []
        for (index, value) in enumerate(input)
                w, context = arnn.attention(h[:,:,index:index], h[:,:,1:index])
                arnn.rnn.rnn.c = context          # feed the attention context in as the LSTM cell state
                arnn.rnn.rnn.h = h[:,:,index:index]
                print(typeof(arnn.rnn(value)))
                outputs[:,:,index] .= arnn.rnn(value)
                if index < arnn.seq_len
                        h[:,:,index+1:index+1] .= arnn.rnn.rnn.h
                else
                        break
                end
                push!(weights, w)
        end
        return outputs, weights
end

function NLLloss(scores::Param{Array{Float64,3}}, y)
        keeped_elements = findall(x -> x != -100, vec(y))
        expscores = exp.(scores)
        probabilities = expscores ./ sum(expscores, dims=1)
        answerprobs = (probabilities[y[i], 1, i] for i in 1:length(y) if y[i] != -100)
        return mean(-log.(answerprobs))
end

#functions that define the behavior of the model when both input and ground truth are given
(vel_rnn::velocity_RNN)(x, y) = NLLloss(vel_rnn(x), y)
(arnn::attention_RNN)(x, y) = NLLloss(arnn(x)[1],y)

x, y = load("..\\src\\RNN_data.jld")
x_trn, y_trn = Array{Float64,3}[], Array{Int64,2}[]
for (index, data) in enumerate(x[2])
        push!(x_trn,reshape(data,(size(data)[1],1,size(data)[2])))
end
for (index, data) in enumerate(y[2])
        push!(y_trn,reshape(data,(size(data)[1],1)))
end

#defining custom minibatcher
idxs = collect(1:length(x_trn)) #the index to be shuffled for minibatching
coribatch() = ((x_trn[i], y_trn[i]) for i in Random.shuffle(idxs)[1:BATCHSIZE])

#the functions defining the behavior of the model when presented with a custom batch
(vel_rnn::velocity_RNN)(batch, isbatch::Bool) = mean(vel_rnn(x,y) for (x,y) in batch)
(arnn::attention_RNN)(batch, isbatch::Bool) = mean(arnn(x,y) for (x,y) in batch)

function sgdupdate!(func, args; lr=0.1)
     fval = @diff func(args...)
     for param in params(fval)
        ∇param = grad(fval, param)
         param .-= lr * ∇param
     end
end

x,y = first(coribatch())
arnn = attention_RNN(MAXLEN, INPUTDIM, HIDDENRNN, HIDDENATT, VELSIZE)
sgdupdate!(arnn,(x,y))

The error stacktrace reads:

Stacktrace:
 [1] #call#600(::Nothing, ::RNN, ::Array{Float64,3}) at C:\Users\cnelias\.julia\packages\Knet\LjPts\src\rnn.jl:196
 [2] (::RNN)(::Array{Float64,3}) at C:\Users\cnelias\.julia\packages\Knet\LjPts\src\rnn.jl:190
 [3] (::velocity_RNN)(::Array{Float64,3}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:48
 [4] (::attention_RNN)(::Array{Float64,3}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:97
 [5] (::attention_RNN)(::Array{Float64,3}, ::Array{Int64,2}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:127
 [6] (::getfield(Main, Symbol("##1971#1972")){attention_RNN,Tuple{Array{Float64,3},Array{Int64,2}}})() at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:205
 [7] #differentiate#3(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(AutoGrad.differentiate), ::Function) at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:144
 [8] differentiate at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:135 [inlined]
 [9] #sgdupdate!#1970(::Float64, ::typeof(sgdupdate!), ::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
 [10] sgdupdate!(::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
 [11] top-level scope at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:162
 [12] include_string(::Module, ::String, ::String) at .\loading.jl:1064
 [13] (::getfield(Atom, Symbol("##139#144")){String,String,Module})() at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:138
 [14] withpath(::getfield(Atom, Symbol("##139#144")){String,String,Module}, ::String) at C:\Users\cnelias\.julia\packages\CodeTools\sf1Tz\src\utils.jl:30
 [15] withpath at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:47 [inlined]
 [16] #138 at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:135 [inlined]
 [17] with_logstate(::getfield(Atom, Symbol("##138#143")){String,String,Module}, ::Base.CoreLogging.LogState) at .\logging.jl:395
 [18] with_logger at .\logging.jl:491 [inlined]
 [19] #137 at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:134 [inlined]
 [20] hideprompt(::getfield(Atom, Symbol("##137#142")){String,String,Module}) at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\repl.jl:85
 [21] macro expansion at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:133 [inlined]
 [22] macro expansion at C:\Users\cnelias\.julia\packages\Media\ItEPc\src\dynamic.jl:24 [inlined]
 [23] (::getfield(Atom, Symbol("##136#141")))(::Dict{String,Any}) at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\eval.jl:122
 [24] handlemsg(::Dict{String,Any}, ::Dict{String,Any}) at C:\Users\cnelias\.julia\packages\Atom\lBERI\src\comm.jl:164
 [25] (::getfield(Atom, Symbol("##19#21")){Array{Any,1}})() at .\task.jl:268
ERROR: LoadError: AssertionError: r.c == nothing || (r.c == 0 || vec(value(r.c)) isa WTYPE && (ndims(r.c) <= 3 && (size(r.c, 1), size(r.c, 2), size(r.c, 3)) == HSIZE))
Stacktrace:
 [1] #differentiate#3(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(AutoGrad.differentiate), ::Function) at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:148
 [2] differentiate at C:\Users\cnelias\.julia\packages\AutoGrad\pTNVv\src\core.jl:135 [inlined]
 [3] #sgdupdate!#1970(::Float64, ::typeof(sgdupdate!), ::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
 [4] sgdupdate!(::attention_RNN, ::Tuple{Array{Float64,3},Array{Int64,2}}) at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:148
 [5] top-level scope at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:162
in expression starting at C:\Users\cnelias\Desktop\PHD\Swing project\code\script\RNN_attention.jl:162

I think this is because the context vector provided to the RNN has to be of the same type and dimensions as the internal parameters of the RNN (otherwise it triggers this WTYPE assertion), but I have already checked and all the dimensions should be correct, so I don't understand what is going on. Moreover, the forward pass (without @diff) works perfectly.
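
For what it is worth, here is the little check I used (outside of @diff) to convince myself that the context has the right type and shape. The WTYPE and HSIZE definitions below are only my reading of the assertion in Knet's rnn.jl, and r.w is what I believe holds the RNN weights, so treat this as a sketch of the check rather than the library's actual code:

using Knet   # for RNN, param, value

# Rough reconstruction of the failing assertion, to print the context vector's
# type and shape next to what the RNN seems to expect:
function check_context(r::RNN, c)
        WTYPE = typeof(vec(value(r.w)))                              # container type of the RNN weights
        HSIZE = (r.hiddenSize, 1, r.numLayers * (r.direction + 1))   # batch size is 1 in my case
        @show typeof(vec(value(c))) WTYPE
        @show (size(c, 1), size(c, 2), size(c, 3)) HSIZE
        return vec(value(c)) isa WTYPE && ndims(c) <= 3 &&
               (size(c, 1), size(c, 2), size(c, 3)) == HSIZE
end

h0 = param(zeros(HIDDENRNN, 1, 1))       # dummy hidden state, same shape as in the loop
_, ctx = arnn.attention(h0, h0)          # same attention call as at the first time step
check_context(arnn.rnn.rnn, ctx)
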
Does anyone have any idea what is going on here? I have been bashing my head against this for so long that I am thinking of completely quitting machine learning in Julia …