Flux params() not working for composed layers

I’m trying to create a multi-headed attention layer, which is composed several “heads”. Each head has its own parameters (three matrices) and the “wrapper” layer has one output parameter (matrix). This is the code to do the multiheaded attention.

using Flux
using Flux: glorot_uniform, @treelike
using Flux.Tracker: param
struct Head_Attn{T}
	W_k::T
	W_q::T
	W_v::T
	function Head_Attn(d_v::Integer,d_model::Integer)
		W_k,W_q,W_v = [param(glorot_uniform(d_v,d_model)) for i in 1:3]
		new{typeof(W_k)}(W_k,W_q,W_v)
	end
end

function (h_attn::Head_Attn)(query::AbstractArray{<:Real,2},target::AbstractArray{<:Real,2})
	W_k,W_q,W_v = (h_attn.W_k, h_attn.W_q, h_attn.W_v)
	K_i = W_k * target
	V_i = W_v * target
	Q_i = W_q * query

	d_k = size(W_k,1)

	V_i * softmax((K_i' * Q_i)/d_k)
end

@treelike Head_Attn

struct Multi_Head_Attn{T,W}
	heads::Vector{T}
	W_o::W
	function Multi_Head_Attn(d_model::Integer,num_heads::Integer) 
		d_v::Integer = ceil(d_model/num_heads)
		heads = [Head_Attn(d_v,d_model) for _ in 1:num_heads]
		W_o = param(glorot_uniform(d_model,d_v*num_heads))
		new{eltype(heads),typeof(W_o)}(heads,W_o)
	end
end

function (attn::Multi_Head_Attn)(query::AbstractArray{<:Real,2},target::AbstractArray{<:Real,2})
	heads = attn.heads
	W_o = attn.W_o
	
	conc_heads = vcat([head(query,target) for head in heads]...)
	return W_o * conc_heads
end

@treelike Multi_Head_Attn

Then when creating one of this objects and getting its parameters (for later optimizer creation), the layer should return all parameters, not only the ones from the wrapper (Multi_Head_Attn).

attn = Multi_Head_Attn(5,2)
for p in params(attn)
    println(size(p))
end
println("-------")
println("e.g. these parameters are ignored")
for p in params(attn.heads[1])
    println(size(p))
end

(5, 6)

e.g. these parameters are ignored
(3, 5)
(3, 5)
(3, 5)

I’m not entirely sure that this is right (I’m not all that familiar with multi-headed attention, to be honest), but I got something that seems like it might be on the right track by replacing @treelike Multi_Head_Attn with this:

import Flux.children, Flux.mapchildren
Flux.children(attn::Multi_Head_Attn) = (attn.heads..., attn.W_o)
Flux.mapchildren(f, attn::Multi_Head_Attn) = Multi_Head_Attn(f.(attn.heads), f(attn.W_o))

Does that work?

1 Like

yes it should actually work, although I decided to change the type of heads for a tuple and then it seemed to work.

1 Like