Modify existing Lux models

Hi @avikpal I think my question is relevant to this topic so I’d like to ask it here to see whether what I’ve done is reasonable. My aim is to implement a RNN with return_sequence=True and its input and output should be preprocessed and post-processed by Dense layers, like the simple model below:

"""
    GRU_Seq2Seq(nfeatures; gru_out_dim=GRU_OUT, dense_act=DENSE_ACT)

Build a sequence-to-sequence GRU model: two Dense input layers with dropout,
a GRU that returns the full output sequence, a `StackLayer` that stacks the
per-step outputs along dimension 2 so Dense layers can consume them, and two
Dense output layers projecting down to a single value per step.
"""
function GRU_Seq2Seq(nfeatures; gru_out_dim=GRU_OUT, dense_act=DENSE_ACT)
    hidden = gru_out_dim
    return Lux.Chain(
        Lux.Dense(nfeatures => 2hidden, dense_act),   # input projection
        Lux.Dropout(0.5),
        Lux.Dense(2hidden => hidden, dense_act),
        Lux.Recurrence(Lux.GRUCell(hidden => hidden); return_sequence=true),
        StackLayer(2),                                # vector of steps -> array
        Lux.Dense(hidden => 16, dense_act),
        Lux.Dropout(0.2),
        Lux.Dense(16 => 1, identity),
    )
end

where the StackLayer is defined to make the RNN output suitable for Dense layers:

"""
    StackLayer(dims)

Stateless Lux layer that stacks a vector of arrays (e.g. the per-timestep
outputs of `Lux.Recurrence(...; return_sequence=true)`) along dimension
`dims`, producing a single array suitable for downstream `Dense` layers.
"""
struct StackLayer <: Lux.AbstractExplicitLayer
    dims::Int
end

# No trainable parameters (ps is an empty NamedTuple); state passes through unchanged.
@inline (s::StackLayer)(x::AbstractArray, ps, st::NamedTuple) = (stack(x; dims=s.dims), st)

I’d like to know whether there is an idiomatic way to define such a Seq2Seq RNN model without hacking together my own StackLayer. Am I missing something in the Lux.jl docs? It feels like a natural thing to post-process the sequence output with Dense layers.

Moreover, to test the performance of my seq2seq model, I also implemented a seq2last model which returns only the last value, as follows:

"""
    GRU_Seq2Last(nfeatures; gru_out_dim=GRU_OUT, dense_act=DENSE_ACT)

Build a sequence-to-last GRU model: identical to `GRU_Seq2Seq` except the
`Recurrence` returns only the final hidden state (no `return_sequence`), so
no stacking layer is needed before the Dense output head.
"""
function GRU_Seq2Last(nfeatures; gru_out_dim=GRU_OUT, dense_act=DENSE_ACT)
    hidden = gru_out_dim
    return Lux.Chain(
        Lux.Dense(nfeatures => 2hidden, dense_act),   # input projection
        Lux.Dropout(0.5),
        Lux.Dense(2hidden => hidden, dense_act),
        Lux.Recurrence(Lux.GRUCell(hidden => hidden)),
        Lux.Dense(hidden => 16, dense_act),
        Lux.Dropout(0.2),
        Lux.Dense(16 => 1, identity),
    )
end

Now that the GRU_Seq2Last model is trained, I’d like to use it to evaluate the test data but return a full sequence instead of the last value as in the training phase. Currently, I use a really ugly hack to transform the GRU_Seq2Last parameters and states to be usable by GRU_Seq2Seq, and use that model to predict. The code is:

"""
    seq2last_to_seq2seq(ps, st)

Remap the parameters and states of a trained `GRU_Seq2Last` (7 layers) so
they can be used by `GRU_Seq2Seq` (8 layers): layers 1-4 keep their slots,
an empty `NamedTuple` is inserted at `layer_5` for the parameter-free
`StackLayer`, and layers 5-7 shift to slots 6-8.
"""
function seq2last_to_seq2seq(ps, st)
    # Same insertion applies to both parameters and states.
    insert_stack_slot(nt) = (layer_1=nt.layer_1, layer_2=nt.layer_2,
                             layer_3=nt.layer_3, layer_4=nt.layer_4,
                             layer_5=NamedTuple(),
                             layer_6=nt.layer_5, layer_7=nt.layer_6,
                             layer_8=nt.layer_7)
    return insert_stack_slot(ps), insert_stack_slot(st)
end

# ps_seq2last, st_seq2last were trained with GRU_Seq2Last.
# BUG FIX: the original passed `ps_seq2last` twice, so the seq2seq model
# received parameters where its states should be; pass `st_seq2last` second.
ps_seq2seq, st_seq2seq = seq2last_to_seq2seq(ps_seq2last, st_seq2last)
yp_seq2seq = GRU_Seq2Seq(40)(X, ps_seq2seq, st_seq2seq)

Setfield.jl doesn’t seem to help here, since I want to insert a layer (the StackLayer) rather than replace one.