Issue with backprop in a custom recurrent `Lux` layer

I am trying to implement an autoencoder in Lux, inspired by the excellent one I found at github.com/lkulowski/LSTM_encoder_decoder, which is written in PyTorch.


(diagram taken from the link)

Model overview

The encoder is a simple recurrent LSTM (I used an RNN for simplicity), which can be implemented easily with

encoder = Recurrence(RNNCell(1 => 15)) # hidden_dim = 15
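
For reference, this is how I call the encoder (a minimal sketch, assuming Lux's default `BatchLastIndex` ordering, i.e. input of size `(features, time, batch)`; by default `Recurrence` returns only the output of the last time step, which for an RNN cell is the final hidden state):

using Lux, Random

rng = Xoshiro(1994)
encoder = Recurrence(RNNCell(1 => 15))
ps_enc, st_enc = Lux.setup(rng, encoder)

x = randn(rng, Float32, 1, 5, 2)       # 1 feature, 5 time steps, batch of 2
h_last, st_enc = encoder(x, ps_enc, st_enc)
size(h_last)                           # (15, 2): final hidden state for each batch element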

The decoder is slightly more complicated:

  • input: the last hidden state and the last time step, $(\vec h_\tau, x^{(\tau)})$
  • recurrent structure: generates outputs $\hat y$ by passing each hidden state through a dense layer (not depicted in the original diagram). Hidden states are built from the previous predictions:

$$
\begin{aligned}
\hat y^{(1)} &= \text{Dense}\!\left(\text{RNN}(\vec h_\tau,\, x^{(\tau)})\right) \\
\hat y^{(2)} &= \text{Dense}\!\left(\text{RNN}(\vec h_{\tau+1},\, \hat y^{(1)})\right) \\
&\;\;\vdots
\end{aligned}
$$

(Note: the input at each step is the previous prediction, after the dense pass).
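
To fix ideas, the feedback loop described above looks roughly like this (just a conceptual sketch with toy stand-ins `rnn_step` and `dense`, not the actual Lux layers):

# Conceptual sketch of the decoder feedback loop (toy stand-ins, not Lux code)
function decode_sketch(h, ŷ; n_steps = 5)
    rnn_step(h, x) = tanh.(h .+ sum(x))   # placeholder for the RNN cell
    dense(h)       = [sum(h)]             # placeholder for the dense layer
    predictions = Vector{Vector{Float64}}()
    for t in 1:n_steps
        h = rnn_step(h, ŷ)                # advance the hidden state from the previous prediction
        ŷ = dense(h)                      # map the new hidden state to a prediction
        push!(predictions, ŷ)             # plain mutation here; the Lux forward pass below cannot do this
    end
    return predictions
end

decode_sketch(randn(15), randn(1))        # stand-ins for (h_last, x_last)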

Custom decoder layer implementation

I am trying to implement this in Lux; since I could not find a ready-made way to do it, I am writing a custom layer.

struct Decoder{R,T} <: LuxCore.AbstractLuxContainerLayer{(:cell, :dense)}
    cell::R
    dense::T
end

# Initializer
function Decoder(in_dim, hidden_dim, out_dim)
    return Decoder(RNNCell(in_dim=>hidden_dim), Dense(hidden_dim=>out_dim))
end
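
Since the layer subtypes `AbstractLuxContainerLayer{(:cell, :dense)}`, `Lux.setup` returns parameters and states as NamedTuples with `cell` and `dense` fields, which is what the forward pass below indexes into. A quick sketch of what I expect:

dec = Decoder(1, 15, 1)
ps, st = Lux.setup(Xoshiro(1994), dec)

keys(ps)    # (:cell, :dense): parameters of the RNNCell and of the Dense layer
keys(st)    # (:cell, :dense): their states
ps.dense    # NamedTuple with the Dense weights, e.g. (weight = ..., bias = ...)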

To implement the forward pass, a key challenge is to efficiently accumulate the successive outputs $\hat y^{(1)}, \hat y^{(2)}, \hat y^{(3)}, \dots$ during the decoding phase.

Since Zygote.jl does not support array mutation, it is not possible to initialize an empty vector ([]) and push outputs into it within the loop, nor to preallocate an array and fill it in place (a standalone illustration of this is shown right after the forward pass). This limitation complicates the design of a stateful decoder, as shown in my current approach:

# Forward pass
function (net::Decoder)(ydata::AbstractArray{T,3}, (h_last,x_last), ps::NamedTuple, st::NamedTuple) where T
    function f_recur((output, (hₜ,ŷₜ), s), x_t) 
        (_,(hₜ₊₁,)), st_cell = net.cell((ŷₜ,(hₜ,)), ps.cell, s.cell)
        ŷₜ₊₁, st_dense       = net.dense(hₜ₊₁    , ps.dense, s.dense)

        # Stack output
        out = (output...,ŷₜ₊₁)

        # Update carry
        carry = (hₜ₊₁,ŷₜ₊₁)

        return out, carry, (cell=st_cell, dense=st_dense)
    end

    # We do recursion by hand because foldl gives me an error caused by subarray types 
    buffer = ((), (h_last,x_last), st)
    for x_t in Lux.eachslice(ydata,dims=2)
        buffer = f_recur(buffer, x_t)
    end
    
    # unpacking for clarity
    (out, memory, state) = buffer
    
    # stacking results to compute loss
    out = hcat([reshape(y,size(y,1),1,size(y,2)) for y in out]...)

    return out, state
end
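
As a standalone illustration of the mutation limitation mentioned above (completely unrelated to the decoder, just to show the kind of accumulation Zygote rejects):

using Zygote

# Accumulating into a vector with push! inside a differentiated function
function sum_with_push(x)
    acc = Float64[]
    for xi in x
        push!(acc, 2xi)       # array mutation inside the differentiated code
    end
    return sum(acc)
end

sum_with_push([1.0, 2.0, 3.0])                   # forward pass is fine: 12.0
Zygote.gradient(sum_with_push, [1.0, 2.0, 3.0])  # fails: Zygote does not support array mutation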

Problem: Gradient Issues with Zygote

However, this approach fails with Zygote when calling the gradient function.

For a minimal example of usage,

rng = Xoshiro(1994)

model = Decoder(1,15,1)
pars,state = Lux.setup(rng,model)

h_last = randn(rng, Float32, 15, 2)
x_last = randn(rng, Float32, 1, 2)
encoded_state = (h_last,x_last)

y_test = randn(rng, Float32, 1, 5, 2)

The error can be reproduced by running

train_state = Training.TrainState(model, pars, state, Adam(3.0f-4))
const lossMSE = MSELoss()
generic_loss(_model, _ps, _st, (y,enc_state)) = begin
    ŷ,st = _model(y,enc_state, _ps, _st)
    return lossMSE(ŷ,y), st, 0       
end
Training.compute_gradients(
    AutoZygote(), generic_loss, (y_test,encoded_state), train_state
)

or, directly with the pullback

loss(p,s) = begin
    ŷ,st = model(y_test,encoded_state, p, s)
    lossMSE(ŷ,y_test), st    
end

(loss_eval, state, ), loss_back = pullback(loss,pars,state)
grad, _ = loss_back((1.,nothing))

The same issues occur when using another AD backend, like AutoEnzyme().

Question

How can I efficiently accumulate outputs in the forward pass without breaking the differentiability with Zygote? Are there better design patterns for this type of recurrent decoding in Lux?

This sounds very suspicious; can you share the full MWE and the stack traces for both backends?

For Zygote.jl

using Random
using Lux, Zygote, Optimisers

rng = Xoshiro(1994)

struct Decoder{R,T} <: LuxCore.AbstractLuxContainerLayer{(:cell, :dense)}
    cell::R
    dense::T
end

# Initializer
function Decoder(in_dim, hidden_dim, out_dim)
    return Decoder(RNNCell(in_dim=>hidden_dim), Dense(hidden_dim=>out_dim))
end

# Forward pass
function (net::Decoder)(ydata::AbstractArray{T,3}, (h_last,x_last), ps::NamedTuple, st::NamedTuple) where T
    function f_recur((output, (hₜ,ŷₜ), s), x_t) 
        (_,(hₜ₊₁,)), st_cell = net.cell((ŷₜ,(hₜ,)), ps.cell, s.cell)
        ŷₜ₊₁, st_dense       = net.dense(hₜ₊₁    , ps.dense, s.dense)

        # Stack output
        out = (output...,ŷₜ₊₁)

        # Update carry
        carry = (hₜ₊₁,ŷₜ₊₁)

        return out, carry, (cell=st_cell, dense=st_dense)
    end

    # We do recursion by hand because foldl gives me an error caused by subarray types 
    buffer = ((), (h_last,x_last), st)
    for x_t in Lux.eachslice(ydata,dims=2)
        buffer = f_recur(buffer, x_t)
    end
    
    # unpacking for clarity
    (out, memory, state) = buffer
    
    # stacking results to compute loss
    out = hcat([reshape(y,size(y,1),1,size(y,2)) for y in out]...)

    return out, state
end




model = Decoder(1,15,1)
pars,state = Lux.setup(rng,model)

h_last = randn(rng, Float32, 15, 2)
x_last = randn(rng, Float32, 1, 2)
encoded_state = (h_last,x_last)

y_test = randn(rng, Float32, 1, 5, 2)


train_state = Training.TrainState(model, pars, state, Adam(3.0f-4))
const lossMSE = MSELoss()
generic_loss(_model, _ps, _st, (y,enc_state)) = begin
    ŷ,st = _model(y,enc_state, _ps, _st)
    return lossMSE(ŷ,y), st, 0       
end
Training.compute_gradients(
    AutoZygote(), generic_loss, (y_test,encoded_state), train_state
)

produces the following error

ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::@NamedTuple{contents::Tuple{Matrix{Float32}}})
The function `+` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:596
  +(::ChainRulesCore.ZeroTangent, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_arithmetic.jl:99
  +(::Any, ::ChainRulesCore.NotImplemented)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_arithmetic.jl:25
  ...

Stacktrace:
 [1] accum(x::Base.RefValue{Any}, y::@NamedTuple{contents::Tuple{Matrix{Float32}}})
   @ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:17
 [2] Decoder
   @ ~/code/julia/sketches/myODE/w_Lux/bug.jl:19 [inlined]
 [3] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Array{…}, Nothing})
   @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [4] generic_loss
   @ ~/code/julia/sketches/myODE/w_Lux/bug.jl:68 [inlined]
 [5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
   @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [6] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Tuple{Float32, Nothing, Nothing})
   @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
 [7] compute_gradients_impl(::AutoZygote, objective_function::typeof(generic_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
   @ LuxZygoteExt ~/.julia/packages/Lux/DHtyL/ext/LuxZygoteExt/training.jl:13
 [8] compute_gradients(ad::AutoZygote, obj_fn::typeof(generic_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
   @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200
 [9] top-level scope
   @ ~/code/julia/sketches/myODE/w_Lux/bug.jl:71
Some type information was truncated. Use `show(err)` to see complete types.

If you call the loss function directly,

generic_loss(model, pars, state, (y_test,encoded_state))

I don’t get any error, so I don’t understand why this MethodError appears only during the gradient computation.

For Enzyme.jl

Same code, but importing Enzyme.jl and changing `AutoZygote()` to `AutoEnzyme()` in the last line.

The error is

ERROR: Enzyme.Compiler.EnzymeRuntimeActivityError(Cstring(0x00000002c24d00e0))
Stacktrace:
  [1] Decoder (repeats 2 times)
    @ ~/code/julia/sketches/myODE/w_Lux/bug.jl:44
  [2] #4
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:256 [inlined]
  [3] augmented_julia__4_15891_inner_13wrap
    @ ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:0
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/zQBPg/src/compiler.jl:5377 [inlined]
  [5] enzyme_call
    @ ~/.julia/packages/Enzyme/zQBPg/src/compiler.jl:4915 [inlined]
  [6] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/zQBPg/src/compiler.jl:4851 [inlined]
  [7] autodiff
    @ ~/.julia/packages/Enzyme/zQBPg/src/Enzyme.jl:396 [inlined]
  [8] compute_gradients_impl
    @ ~/.julia/packages/Lux/DHtyL/ext/LuxEnzymeExt/training.jl:8 [inlined]
  [9] compute_gradients(ad::AutoEnzyme{…}, obj_fn::typeof(generic_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200
 [10] top-level scope
    @ ~/code/julia/sketches/myODE/w_Lux/bug.jl:71
Some type information was truncated. Use `show(err)` to see complete types.

What if you use Enzyme from within Reactant? (See the Lux README for an example: https://github.com/LuxDL/Lux.jl)
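
Something along these lines, adapted from the README example (untested with your custom `Decoder`; `reactant_device()` moves parameters, states, and data to Reactant, and `Training.single_train_step!` with `AutoEnzyme()` then computes the gradients through the compiled model):

using Lux, Reactant, Enzyme, Optimisers, Random

rng = Xoshiro(1994)
dev = reactant_device()

model  = Decoder(1, 15, 1)
ps, st = Lux.setup(rng, model) |> dev

y_test        = randn(rng, Float32, 1, 5, 2) |> dev
encoded_state = (randn(rng, Float32, 15, 2), randn(rng, Float32, 1, 2)) |> dev

train_state = Training.TrainState(model, ps, st, Adam(3.0f-4))
_, loss, _, train_state = Training.single_train_step!(
    AutoEnzyme(), generic_loss, (y_test, encoded_state), train_state
)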