Flux.jl training loop where data sequentially depends on model output


Hey all!

I’m interested in writing some “closed-loop” generative models that receive their prediction error for the previous input as their next input. Is there a way to do this with Flux.jl’s `train!` loop?
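Here is a minimal, untested sketch of one way to do this. It assumes Flux ≥ 0.14, with the explicit `Flux.setup` / `Flux.withgradient` / `Flux.update!` API, and uses a made-up target sequence `ys`. `Flux.train!` simply iterates over the data collection you pass in and computes the loss on each element separately. Because of that, it is usually simpler to write the loop by hand and unroll the feedback inside the loss: each step's input is the previous step's prediction error, and gradients flow back through the whole loop. The architecture, `closed_loop_loss`, and all sizes below are hypothetical.

```julia
using Flux, Random

Random.seed!(0)

# Hypothetical target sequence (a sine wave) standing in for real data.
ys = sin.(0.1f0 .* (1:100))            # Vector{Float32}

# Hypothetical model: maps the previous step's prediction error
# (a 1-element vector) to the next prediction (a 1-element vector).
model = Chain(Dense(1 => 16, tanh), Dense(16 => 1))

# Unroll the closed loop inside the loss so gradients flow back
# through every fed-back error (full backpropagation through time).
function closed_loop_loss(m, ys)
    err = zeros(Float32, 1)            # no error is available before the first step
    total = 0f0
    for t in eachindex(ys)
        ŷ = m(err)                     # predict from the previous error
        err = ys[t:t] .- ŷ             # this step's error becomes the next input
        total += sum(abs2, err)        # accumulate squared error without mutation
    end
    return total / length(ys)
end

opt_state = Flux.setup(Adam(1f-2), model)

losses = Float32[]
for epoch in 1:200
    # Recompute the whole fed-back sequence with the current parameters.
    loss, grads = Flux.withgradient(m -> closed_loop_loss(m, ys), model)
    Flux.update!(opt_state, model, grads[1])
    push!(losses, loss)
end
```

Unrolling the full sequence is the simplest choice. For long sequences you might truncate the unroll or stop gradients through the fed-back error, and either option changes what the model learns.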




I’m no expert on this, although I have started to look at Flux. I’m not 100% sure I understand your question, but it may be related to my own main interest in Flux: training models of dynamic systems (typically, inputs are denoted $u_t$ and outputs $y_t$). If that is what you are interested in, I think there are two ways:

  • Use a feedforward neural network (FNN, `Dense()`) with at least one hidden layer, say `model = Chain(Dense(Nu => nh, tanh), Dense(nh => Ny))`, mapping $u_t, u_{t-1}, \ldots, u_{t-n_u}, y_{t-1}, y_{t-2}, \ldots, y_{t-n_y}$ to $y_t$, with $N_u = \dim u_t \cdot (n_u+1) + \dim y_t \cdot n_y$ and $N_y = \dim y_t$. This is a NARX structure. Making the number of hidden nodes (`nh`) large should make it possible to describe any such system with sufficient accuracy, I would guess. Here, I have used `tanh` as the activation function; if no activation function is specified, `Dense` defaults to `identity`.
  • Combine an FNN with a recurrent neural network (RNN). My understanding is that an RNN is essentially a state space model of the form $y_t = \sigma(W_\mathrm{f} y_{t-1} + W x_t + b)$, where $\sigma$ is the activation function (e.g. $\tanh$ or any other suitable function) and $x$ is the input (denoted $u$ in control engineering). To allow for more flexibility, one could add an FNN at the input to capture more nonlinearities, and maybe (but I’m not sure) a linear or nonlinear layer at the output. A simple “state space” model could probably be `model = Chain(Dense(nu => nh, tanh), RNN(nh => nx, tanh), Dense(nx => ny))`, where `nh` is the number of hidden nodes and `nx` is the number of states.
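Below is a rough sketch of the first (NARX) option. It has not been checked against any particular Flux release and assumes Flux ≥ 0.14, for the `Dense(in => out, σ)` syntax and the explicit `Flux.setup` / `Flux.gradient` / `Flux.update!` API. The data come from a made-up first-order system $y_t = 0.8\,y_{t-1} + 0.5\tanh(u_t)$ with $n_u = n_y = 1$, so $N_u = 1\cdot(1+1) + 1\cdot 1 = 3$ and $N_y = 1$. The system, lag orders, and hidden size are all hypothetical.

```julia
using Flux, Random

Random.seed!(1)

# Hypothetical system, used only to generate example data:
#   y_t = 0.8 y_{t-1} + 0.5 tanh(u_t)
T = 300
u = randn(Float32, T)
y = zeros(Float32, T)
for t in 2:T
    y[t] = 0.8f0 * y[t-1] + 0.5f0 * tanh(u[t])
end

# NARX regressors with n_u = 1, n_y = 1 (scalar u and y):
# each column holds [u_t, u_{t-1}, y_{t-1}] for one t, so N_u = 3, N_y = 1.
t0 = 2                                                      # first t with all lags available
X = reduce(hcat, [[u[t]; u[t-1]; y[t-1]] for t in t0:T])    # 3 × (T - 1)
Y = reshape(y[t0:T], 1, :)                                  # 1 × (T - 1)

Nu, nh, Ny = size(X, 1), 16, 1
model = Chain(Dense(Nu => nh, tanh), Dense(nh => Ny))

loss(m, x, target) = Flux.mse(m(x), target)
opt_state = Flux.setup(Adam(1f-2), model)

initial = loss(model, X, Y)
for epoch in 1:500
    # One-step-ahead fit: regressors use measured past outputs.
    grads = Flux.gradient(m -> loss(m, X, Y), model)
    Flux.update!(opt_state, model, grads[1])
end
final = loss(model, X, Y)

# Free-run simulation: feed the model's own predictions back
# in place of the measured y_{t-1}.
ŷ = zeros(Float32, T)
ŷ[1] = y[1]
for t in 2:T
    ŷ[t] = model([u[t]; u[t-1]; ŷ[t-1]])[1]
end
```

Training on measured past outputs (one-step-ahead) is the easy case. The final loop runs the trained model in free-run mode, feeding back its own predictions, which is closer to the closed-loop setting in the question.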

As I indicated, I haven’t tested this, so I don’t know whether it works. I’ll test it over the next couple of weeks, when I have time.