Flux custom model - feedback of the output layer to the input layer

Hi,

I'm trying to feed the output of the output layer of a fully feedforward neural network back to the input layer of the network.

I would like to create a structure similar to the one in the illustration.

Detailed explanation: Imagine you have a technical system with input u and output y, and you would like to train a nonlinear, dynamic, data-based model to approximate the system behaviour. You have already excited the system and gathered the information in two time series, one for the input u and one for the output y. The q^{-1} in the illustration are unit delays, which means the time series is delayed by one discrete time step. Now I would like to feed a fully feedforward neural network, e.g. one with a single dense hidden layer, with the delayed time series of the input u and the delayed time series of the model output \hat{y} (the output of the neural network). In other words, I would like to feed the delayed output of the neural network back to the input layer of the network.
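
Written as an equation, the structure I am after (with three output delays and one input delay, as in the code below) computes

$$\hat{y}(k) = f\big(u(k-1),\, \hat{y}(k-1),\, \hat{y}(k-2),\, \hat{y}(k-3)\big),$$

where f is the neural network.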

I started to implement such a structure and I think it is correct (here a single-input single-output system with 3 states for the model output and one delay of the input signal), but if I now call ps = Flux.params(model), the ps variable is empty. Does anybody know why? Furthermore, I would like to know whether there is a more elegant way to stack the layers, also for custom layers where the layers aren't just directly connected, or where the feedback is included in a cell.

using Flux

mutable struct RNNCellNOE
    W
    b
    h
    σ
end

RNNCellNOE(in::Integer, out::Integer, states::Integer, σ = tanh) =
  RNNCellNOE(randn(out, in), randn(out), zeros(states), σ)

function (m::RNNCellNOE)(x)
    σ, W, b, h = m.σ, m.W, m.b, m.h
    # shift the stored delayed outputs back by one time step
    for i in 2:length(h)
        h[i-1] = h[i]
    end
    # store the newest (scalar) output as the most recent state
    h[end] = σ.(W*x .+ b)[1]
    return h[end], h
end

layer1 = Dense(4, 10, σ)                         # input: u(k-1) plus the 3 delayed outputs
rnn_noe = RNNCellNOE(10, 1, 3)
model(x) = rnn_noe(layer1(vcat(x, rnn_noe.h)))   # feed the delayed outputs back to the input
ps = Flux.params(model)

u = vcat([0], rand(99, 1))'   # input time series u as a 1×100 row
y = model.(u)                 # simulate the model over the series

The reason why params does not return anything is that it needs Flux.trainable(l::RNNCellNOE) to return the parameters you want to train. In general you need to implement Flux.functor(l::RNNCellNOE) for other utilities such as gpu to work. Here is the relevant section from the docs: Advanced Model Building · Flux
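
A minimal sketch of what that looks like for the code above (assuming a Flux version that provides the @functor macro, which exists from Flux 0.10 on):

Flux.@functor RNNCellNOE                      # register the fields with Flux
Flux.trainable(m::RNNCellNOE) = (m.W, m.b)    # train only W and b, not the state h

ps = Flux.params(layer1, rnn_noe)             # pass the structs, not the closure `model`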

I think that your RNNCellNOE is not differentiable by Zygote (Flux's AD mechanism) because it does array mutation (it changes the contents of the array h).
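
You can see the limitation in isolation with the gradient function that Flux re-exports from Zygote (a quick sketch):

julia> Flux.gradient(x -> (x[1] = 2.0; sum(x)), [1.0, 1.0])
ERROR: Mutating arrays is not supported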

I think you should be able to achieve what you want using Flux.Recur. Here is a simple example:

julia> struct DummyCell end

julia> function (::DummyCell)(qs, x)
       q2,q1 = qs
       @show q2
       @show q1
       output = x + 1
       return (q1, output), output
       end

julia> rr = Flux.Recur(DummyCell(), (0, 0))
Recur(DummyCell())

julia> rr(0)
q2 = 0
q1 = 0
1

julia> rr(1)
q2 = 0
q1 = 1
2

julia> rr(2)
q2 = 1
q1 = 2
3

Your real cell will then contain a Dense layer which it uses to compute the output, where the input is a concatenation of x, q1 and q2. Don't forget to implement functor :slight_smile:
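
A hedged sketch of what such a cell could look like (the names here are mine, not from Flux):

struct NOECell{D}
    dense::D
end
NOECell() = NOECell(Dense(3, 1, tanh))   # input: x plus the two delayed outputs

Flux.@functor NOECell

function (m::NOECell)((q2, q1), x)
    y = m.dense(vcat(x, q1, q2))         # compute the new output
    return (q1, y), y                    # shift the delay line and emit y
end

rr = Flux.Recur(NOECell(), ([0.0], [0.0]))
rr([0.5])                                # returns a 1-element output vector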


I think you're right. I managed to get params() to return the parameters (ps = Flux.params(layer1, rnn_noe.W, rnn_noe.b)), but I can't train it. Although I can evaluate the model, I get this error message:

ERROR: Mutating arrays is not supported

Honestly, I have to admit that I don't really understand your simple example :sweat_smile:
What is the purpose of the second argument of Flux.Recur(), and why can you call rr(1) with just one argument and get just one return value?

The array mutation error occurs because you change the contents of h, which is an array, in your forward pass.
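
The non-mutating alternative is to build a fresh state each step instead of writing into h, e.g. with a tuple (a sketch; the helper name is my own invention):

shift_state(h::Tuple, y) = (h[2:end]..., y)   # drop the oldest delay, append the newest output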

As for the example, here is the relevant portion of the docs: Recurrence · Flux

The tl;dr version is that Flux.Recur manages the hidden state for you. The second argument is the initial value for qs.

I really recommend looking at the source code of Flux when questions like this appear. It is for the most part very readable. I think both Juno and VS Code support Ctrl+clicking functions (and other variables) to open their definition.

Here is a short rundown anyway:

julia> cell = DummyCell()
DummyCell()

# Initial values of the delayed output
julia> qs = (0,0)
(0, 0)

julia> qs, y = cell(qs, 0)
q2 = 0
q1 = 0
((0, 1), 1)

julia> qs, y = cell(qs, 1)
q2 = 0
q1 = 1
((1, 2), 2)

julia> qs, y = cell(qs, 2)
q2 = 1
q1 = 2
((2, 3), 3)

# Idea: Why not just wrap cell in a closure which handles the hidden state for us?
julia> function handle_state_for_me_please(cell, qs = (0, 0))
           return function(x)
               qs, y = cell(qs, x)
               return y
           end
       end
handle_state_for_me_please (generic function with 2 methods)

julia> rr = handle_state_for_me_please(cell, (0,0))
#329 (generic function with 1 method)

julia> rr(0)
q2 = 0
q1 = 0
1

julia> rr(1)
q2 = 0
q1 = 1
2

julia> rr(2)
q2 = 1
q1 = 2
3

# Now we have reimplemented Flux.Recur
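
For comparison, Flux.Recur itself boils down to roughly this (a simplified paraphrase of the Flux source of that time; see the docs link above):

mutable struct Recur{T}
    cell::T
    state
end

function (m::Recur)(xs...)
    m.state, y = m.cell(m.state, xs...)
    return y
end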

Thanks again for your help. You're right. Now I have played around with it a little bit.

using Flux

mutable struct RNNCellNOE
    W
    b
    σ
end

RNNCellNOE(in::Integer, out::Integer, σ = tanh) =
    RNNCellNOE(randn(out, in), randn(out), σ)

function (m::RNNCellNOE)(h, x)
    h3, h2, h1 = h
    σ, W, b = m.σ, m.W, m.b
    # @show h1
    # @show h2
    # @show h3
    y = σ.(W*x .+ b)
    # @show y
    return (h[2:end]..., y), y
end

rnn_noe = RNNCellNOE(10, 1, tanh)
rnn = Flux.Recur(rnn_noe, (0.0, 0.0, 0.0))
layer1 = Dense(4, 10, σ)


model(x) = rnn(layer1(vcat(x, rnn.state...)))
X_train = hcat([0], rand(1, 99))
y_train = rand(1, 100)
model.(X_train)

ps = Flux.params(layer1, rnn_noe)
loss(x, y) = Flux.mse(model.(x), y)
opt = ADAM(0.001, (0.9, 0.999))
evalcb() = @show(loss(X_val, y_val))

@time Flux.train!(loss, ps, [(X_train, y_train)], opt)

I don't know if this is what you meant by using the Recur function.

Actually, now I get an error in interface.jl, in the function pullback at line 164, `y, back = _pullback(cx, f)`, when I call the `Flux.train!()` function. I also tried to debug it, but the debugger did not step into every function in the stack trace, so I haven't got a clue what the error is all about.

ERROR: MethodError: no method matching -(::Array{Float64,1}, ::Float64)

Ooh, I know that one! Somewhere, you are subtracting a Float64 from an Array of Float64. If that were .-, it would work fine, but since it is just -, you get the error. The only place I see a possible subtraction is inside Flux.mse, so try this:

loss(x, y) = Flux.mse(model.(x)[1], y)

This returns a single Float64 instead of a one-element array.
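
To see the mismatch in isolation (a minimal REPL sketch):

julia> [1.0] - 0.5        # what mse ends up doing element-wise here
ERROR: MethodError: no method matching -(::Array{Float64,1}, ::Float64)

julia> [1.0][1] - 0.5     # extract the scalar first and it works
0.5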


Thanks. That was it :slight_smile:

Can someone (maybe @contradict or @DrChainsaw) explain to me what I have to do so that the above code works with the latest Flux version 0.12.1?

In the latest version the RNNCell contains a state0 variable, but I don't see in the code where it is updated by the model output. I would expect that this should be done inside functor.

struct RNNCell{F,A,V,S}
  σ::F
  Wi::A
  Wh::A
  b::V
  state0::S
end

function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
  h = σ.(Wi*x .+ Wh*h .+ b)
  sz = size(x)
  return h, reshape(h, :, sz[2:end]...)
end

reset!(m::Recur) = (m.state = m.cell.state0)

state0 is there to store the initial state, so that reset! doesn't need any arguments. The recurrent state itself is managed by Recur. Unless you want to do some tricky state management yourself, you probably want to use the RNN layer rather than RNNCell.
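
For example (a short sketch against the Flux 0.12 API):

rnn = RNN(4, 10, tanh)        # Recur wrapped around an RNNCell(4, 10, tanh)
y = rnn(rand(Float32, 4))     # the state is carried over between calls
Flux.reset!(rnn)              # restore the state from cell.state0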

Actually, I would like to update the above code of RNNCellNOE so that I can use Flux.reset!(model). Therefore, I think I have to add state0 to the struct of RNNCellNOE and create an outer constructor to initialize the states. Furthermore, I'm not sure how state0 will be updated; that is what I meant above.

IIRC, state0 is never updated, because you need a clean state to reset the Recur state with. Even though it will be returned by functor (which params calls under the hood), it will never have a gradient (and thus no gradient update), because it is never used in the forward pass.
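
A sketch of how RNNCellNOE could carry a state0 in the 0.12 style (my untested adaptation of the snippet above; per the explanation, state0 ends up in params via functor but never receives a gradient):

struct RNNCellNOE{F,A,V,S}
    σ::F
    W::A
    b::V
    state0::S
end

RNNCellNOE(in::Integer, out::Integer, σ = tanh) =
    RNNCellNOE(σ, randn(out, in), randn(out), ntuple(_ -> zeros(out), 3))

Flux.@functor RNNCellNOE

function (m::RNNCellNOE)(h, x)
    y = m.σ.(m.W*x .+ m.b)
    return (h[2:end]..., y), y           # shift the delay line, as before
end

cell = RNNCellNOE(10, 1)
rnn = Flux.Recur(cell, cell.state0)      # Flux.reset!(rnn) now restores cell.state0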