I am trying to implement an autoencoder in Lux, inspired by the excellent PyTorch implementation at github.com/lkulowski/LSTM_encoder_decoder.
(diagram taken from the link)
Model overview
The encoder is a simple recurrent network (the original repo uses an LSTM; I use a plain RNN cell for simplicity), which can be implemented directly with
encoder = Recurrence(RNNCell(1 => 15)) # hidden_dim = 15
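As a quick sanity check (a minimal sketch; the sequence length 10 and batch size 2 below are arbitrary, not part of my actual setup), the encoder collapses a (features, time, batch) array into the last hidden state:

using Lux, Random

rng = Xoshiro(1994)
encoder = Recurrence(RNNCell(1 => 15))   # hidden_dim = 15
ps, st = Lux.setup(rng, encoder)

# (features, time, batch) with the default BatchLastIndex ordering
x = randn(rng, Float32, 1, 10, 2)
h_last, _ = encoder(x, ps, st)           # output of the last time step, i.e. the last hidden state
size(h_last)                             # (15, 2)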
The decoder is slightly more complicated:
- input: the last hidden state and the last time step, $(\vec h_\tau, x^{(\tau)})$
- recurrent structure: generates outputs $\hat y$ by passing each hidden state through a dense layer (not depicted in the original diagram). Hidden states are built from the previous predictions, i.e. $\vec h^{(t+1)} = \mathrm{RNNCell}(\hat y^{(t)}, \vec h^{(t)})$ and $\hat y^{(t+1)} = \mathrm{Dense}(\vec h^{(t+1)})$, starting from $\vec h^{(0)} = \vec h_\tau$ and $\hat y^{(0)} = x^{(\tau)}$.
(Note: the input at each step is the previous prediction, after the dense pass; a plain-Julia sketch of this loop follows.)
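In plain Julia, ignoring Lux parameters/states and AD for a moment (rnn_step, dense, and n_steps here are placeholders, not Lux calls), the scheme is:

# Sketch of the decoding scheme only -- not the Lux implementation below.
function decode_sketch(rnn_step, dense, h_last, x_last, n_steps)
    h, ŷ = h_last, x_last
    outputs = typeof(x_last)[]
    for _ in 1:n_steps
        h = rnn_step(ŷ, h)   # hidden state from the previous prediction
        ŷ = dense(h)         # prediction from the new hidden state
        push!(outputs, ŷ)    # plain mutation: fine here, but not under Zygote (see below)
    end
    return outputs
end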
Custom decoder layer implementation
I am trying to implement this in Lux, and since I could not find a built-in layer that does it, I am writing a custom layer.
using Lux, LuxCore, Random, Optimisers, Zygote

# Container layer: an RNN cell plus a dense readout of the hidden state
struct Decoder{R,T} <: LuxCore.AbstractLuxContainerLayer{(:cell, :dense)}
    cell::R
    dense::T
end

# Convenience constructor
function Decoder(in_dim, hidden_dim, out_dim)
    return Decoder(RNNCell(in_dim => hidden_dim), Dense(hidden_dim => out_dim))
end
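Since the struct subtypes AbstractLuxContainerLayer{(:cell, :dense)}, Lux.setup should give nested parameter and state NamedTuples keyed by the field names (a quick check, continuing from the definitions above; the dimensions are arbitrary):

rng = Xoshiro(1994)
dec = Decoder(1, 15, 1)
ps, st = Lux.setup(rng, dec)

keys(ps)   # (:cell, :dense) -- one entry per sub-layer
keys(st)   # (:cell, :dense)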
To implement the forward pass, a key challenge is efficiently accumulating the successive outputs $\hat y^{(1)}, \hat y^{(2)}, \hat y^{(3)}, \dots$ during the decoding phase. Since Zygote.jl does not support array mutation, it is not possible to initialize an empty vector ([]) and push! outputs inside the loop, nor to preallocate an array and fill it in place. This limitation complicates the design of a stateful decoder, as shown in my current approach:
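To illustrate the constraint (a minimal, self-contained example, not part of the decoder code): Zygote errors as soon as the differentiated code mutates an array.

using Zygote

# Accumulating with push! inside the differentiated function ...
function sum_of_scaled(x)
    acc = Float32[]
    for i in 1:3
        push!(acc, i * x)   # array mutation
    end
    return sum(acc)
end

# ... makes Zygote throw a "Mutating arrays is not supported" error.
Zygote.gradient(sum_of_scaled, 1f0)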
# Forward pass: unroll the RNN cell, feeding each prediction back as the next input
function (net::Decoder)(ydata::AbstractArray{T,3}, (h_last, x_last), ps::NamedTuple, st::NamedTuple) where {T}
    # One decoding step: previous prediction ŷₜ and hidden state hₜ -> new hidden state and prediction
    function f_recur((output, (hₜ, ŷₜ), s), x_t)
        (_, (hₜ₊₁,)), st_cell = net.cell((ŷₜ, (hₜ,)), ps.cell, s.cell)
        ŷₜ₊₁, st_dense = net.dense(hₜ₊₁, ps.dense, s.dense)
        # Stack output (a growing tuple instead of an array, to avoid mutation)
        out = (output..., ŷₜ₊₁)
        # Update carry
        carry = (hₜ₊₁, ŷₜ₊₁)
        return out, carry, (cell = st_cell, dense = st_dense)
    end
    # We do the recursion by hand because foldl gives me an error caused by SubArray types;
    # ydata is only used to fix the number of decoding steps, its values are never read
    buffer = ((), (h_last, x_last), st)
    for x_t in Lux.eachslice(ydata, dims = 2)
        buffer = f_recur(buffer, x_t)
    end
    # Unpacking for clarity
    (out, memory, state) = buffer
    # Stack results into a (features, time, batch) array to compute the loss
    out = hcat([reshape(y, size(y, 1), 1, size(y, 2)) for y in out]...)
    return out, state
end
Problem: Gradient Issues with Zygote
However, this approach fails with Zygote when calling the gradient function.
For a minimal usage example:
rng = Xoshiro(1994)
model = Decoder(1, 15, 1)
pars, state = Lux.setup(rng, model)

# Pretend encoder output: last hidden state and last observed time step (batch size 2)
h_last = randn(rng, Float32, 15, 2)
x_last = randn(rng, Float32, 1, 2)
encoded_state = (h_last, x_last)

# Target sequence: 1 feature, 5 time steps, batch size 2
y_test = randn(rng, Float32, 1, 5, 2)
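For context, the forward pass itself runs with these shapes; the problem only shows up when differentiating (a quick check under the setup above):

# Forward evaluation only (no gradients): the decoder unrolls over the 5 time steps
ŷ, _ = model(y_test, encoded_state, pars, state)
size(ŷ)   # expected: (1, 5, 2), matching y_test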
The error can be reproduced by running:
train_state = Training.TrainState(model, pars, state, Adam(3.0f-4))
const lossMSE = MSELoss()

# Loss closure in the form expected by Lux.Training: returns (loss, state, stats)
generic_loss(_model, _ps, _st, (y, enc_state)) = begin
    ŷ, st = _model(y, enc_state, _ps, _st)
    return lossMSE(ŷ, y), st, 0
end

Training.compute_gradients(
    AutoZygote(), generic_loss, (y_test, encoded_state), train_state
)
or, directly with the pullback
loss(p, s) = begin
    ŷ, st = model(y_test, encoded_state, p, s)
    return lossMSE(ŷ, y_test), st
end

(loss_eval, st_out), loss_back = pullback(loss, pars, state)
grad, _ = loss_back((1.0f0, nothing))   # seed the loss with 1, no cotangent for the state
The same issue occurs when using another AD backend, such as AutoEnzyme().
Question
How can I efficiently accumulate outputs in the forward pass without breaking differentiability with Zygote? Are there better design patterns for this type of recurrent decoding in Lux?