I’m trying to create a multi-headed attention layer, which is composed several “heads”. Each head has its own parameters (three matrices) and the “wrapper” layer has one output parameter (matrix). This is the code to do the multiheaded attention.
using Flux
using Flux: glorot_uniform, @treelike
using Flux.Tracker: param
struct Head_Attn{T}
W_k::T
W_q::T
W_v::T
function Head_Attn(d_v::Integer,d_model::Integer)
W_k,W_q,W_v = [param(glorot_uniform(d_v,d_model)) for i in 1:3]
new{typeof(W_k)}(W_k,W_q,W_v)
end
end
function (h_attn::Head_Attn)(query::AbstractArray{<:Real,2},target::AbstractArray{<:Real,2})
W_k,W_q,W_v = (h_attn.W_k, h_attn.W_q, h_attn.W_v)
K_i = W_k * target
V_i = W_v * target
Q_i = W_q * query
d_k = size(W_k,1)
V_i * softmax((K_i' * Q_i)/d_k)
end
@treelike Head_Attn
struct Multi_Head_Attn{T,W}
heads::Vector{T}
W_o::W
function Multi_Head_Attn(d_model::Integer,num_heads::Integer)
d_v::Integer = ceil(d_model/num_heads)
heads = [Head_Attn(d_v,d_model) for _ in 1:num_heads]
W_o = param(glorot_uniform(d_model,d_v*num_heads))
new{eltype(heads),typeof(W_o)}(heads,W_o)
end
end
function (attn::Multi_Head_Attn)(query::AbstractArray{<:Real,2},target::AbstractArray{<:Real,2})
heads = attn.heads
W_o = attn.W_o
conc_heads = vcat([head(query,target) for head in heads]...)
return W_o * conc_heads
end
@treelike Multi_Head_Attn
Then when creating one of this objects and getting its parameters (for later optimizer creation), the layer should return all parameters, not only the ones from the wrapper (Multi_Head_Attn).
attn = Multi_Head_Attn(5,2)
for p in params(attn)
println(size(p))
end
println("-------")
println("e.g. these parameters are ignored")
for p in params(attn.heads[1])
println(size(p))
end
(5, 6)
e.g. these parameters are ignored
(3, 5)
(3, 5)
(3, 5)