I would like to use DiffEqFlux to learn a discrete dynamic equation x(k+1) = A * x(k) + B, where A(u(k)) and B(u(k)) are a time-varying matrix and vector that depend on the input u(k) at time step k. The input sequence u is known (it consists of measurements from my system). I think I can use a simple explicit Euler scheme to obtain this discrete dynamic in DiffEqFlux.
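For reference, here is how an explicit Euler step relates a continuous-time model to the discrete map, assuming the input is held constant over each step. The tilde notation for the continuous-time coefficients is introduced only for this derivation:

$$
\dot{x} = \tilde{A}(u)\,x + \tilde{B}(u)
\quad\Longrightarrow\quad
x_{k+1} = x_k + \Delta t\,\big(\tilde{A}(u_k)\,x_k + \tilde{B}(u_k)\big)
= \big(I + \Delta t\,\tilde{A}(u_k)\big)\,x_k + \Delta t\,\tilde{B}(u_k)
$$

So A(u_k) = I + Δt Ã(u_k) and B(u_k) = Δt B̃(u_k). A fixed-step Euler solve is therefore exactly this discrete map, which also means the map could be iterated directly without an ODE solver.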
If ndim is the size of the state vector, we are trying to learn ndim*ndim (for A) + ndim (for B) components that depend on u.
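A minimal plain-Julia sketch of how one output vector of length ndim*ndim + ndim could be split into A and B. `split_AB` is a hypothetical helper name; it builds new arrays instead of mutating globals, which also keeps it usable under reverse-mode AD such as Zygote.

```julia
# Split a vector of length ndim^2 + ndim into A (ndim×ndim) and B (length ndim).
function split_AB(v::AbstractVector, ndim::Integer)
    length(v) == ndim^2 + ndim ||
        throw(ArgumentError("expected length $(ndim^2 + ndim), got $(length(v))"))
    A = reshape(v[1:ndim^2], ndim, ndim)  # column-major: v[1:ndim] becomes A's first column
    B = v[ndim^2+1:end]
    return A, B
end

A, B = split_AB(collect(1.0:6.0), 2)
# A == [1.0 3.0; 2.0 4.0], B == [5.0, 6.0]
```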
However, I don't know how to incorporate the discrete input sequence u into the model. I think it should somehow be passed in as a parameter.
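Since u is measured data rather than something to be trained, one option is to keep it out of the parameters entirely and let the right-hand side close over it. Below is a minimal plain-Julia sketch of that idea, assuming the input is held constant between samples (zero-order hold). `input_at`, `make_rhs`, and `coeffs` are hypothetical names. The returned `f!` uses the in-place `(du, u, p, t)` signature that DifferentialEquations.jl expects, leaving `p` free for trainable weights.

```julia
# Zero-order hold: return the input column active at time t,
# i.e. column k+1 for t in [k*Δt, (k+1)*Δt).
function input_at(u::AbstractMatrix, t::Real, Δt::Real)
    k = floor(Int, t / Δt + 1e-8)  # tolerance absorbs roundoff such as 0.7/0.1 == 6.999...
    return u[:, clamp(k + 1, 1, size(u, 2))]
end

# Build an in-place RHS f!(dx, x, p, t) that closes over the measured input u.
# `coeffs` is any callable mapping an input column to a vector of length
# ndim^2 + ndim (e.g. a neural network); here it is a toy placeholder.
function make_rhs(u, Δt, coeffs, ndim)
    function f!(dx, x, p, t)
        v = coeffs(input_at(u, t, Δt))
        A = reshape(v[1:ndim^2], ndim, ndim)
        B = v[ndim^2+1:end]
        dx .= A * x .+ B
        return nothing
    end
    return f!
end

u = [1.0 2.0 3.0; 4.0 5.0 6.0]                  # 2 sensors × 3 time steps (toy data)
coeffs = v -> [0.0, 0.0, 0.0, 0.0, v[1], v[2]]  # toy stand-in: A = 0, B = u_k
f! = make_rhs(u, 0.1, coeffs, 2)
dx = zeros(2)
f!(dx, [1.0, 1.0], nothing, 0.1)                # dx == [2.0, 5.0]
```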
Toy problem with ndim = 2
```julia
using LinearAlgebra
using Flux, DiffEqFlux, DifferentialEquations

#### Create data
# Input u has 10 rows - think of it as 10 sensors sampled at 101 time steps (t = 0.0:0.1:10.0)
train_u = rand(10, 101)

#### How should I pass the input sequence into the model?
pinput = Flux.params(train_u)

# State x
ndim = 2
train_x = rand(ndim, 101)
u0 = train_x[:, 1]   # initial state, taken from the first measurement

# Create time span
Δt = 0.1
tspan = (0.0, 10.0)
trange = 0.0:Δt:10.0

# Create A and B
A = zeros(ndim, ndim)
B = zeros(ndim)

#### We want to learn a model ẋ = A(u)*x + B(u), where x is the state variable
#### and u is a time-varying input sensed at discrete instants.
# dudt will learn the components of A and B based on the input vector u
dudt = Chain(Dense(10, ndim*ndim + ndim, tanh))
ps = Flux.params(dudt)

## Function to construct A and B from the neural network output
function state_dyn!(A, B, NN, ndim)
    A .= reshape(NN[1:ndim*ndim], (ndim, ndim))
    B .= NN[ndim*ndim+1:end]
    return A, B
end

#### Define dynamical equation (x is the state; the input is looked up by time)
function DuDt(dx, x, p, t)
    # Fixed-step Euler evaluates the RHS at grid times t = k*Δt, so round (not ceil)
    # recovers k even when t carries floating-point roundoff
    state_dyn!(A, B, dudt(pinput[:, 1 + round(Int, t/Δt)]), ndim)
    dx .= A * x + B
end

### Not sure about this step
p = (pinput, ps)

#### FluxDiffEq
n_ode = x -> neural_ode(DuDt, x, tspan, Euler(), dt = Δt, saveat = trange,
                        reltol = 1e-7, abstol = 1e-9, params = p)
opt = ADAM(0.1)
# loss_n_ode takes no arguments, so iterate over empty tuples
data = Iterators.repeated((), 100)
cb = () -> println(loss_n_ode())

function predict_n_ode()
    n_ode(u0)
end

loss_n_ode() = sum(abs2, train_x .- predict_n_ode())

Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
```
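Because the target is a discrete map, another option is to skip the ODE solver altogether and fit the map directly with Flux. This is a swapped-in approach, not DiffEqFlux: a minimal sketch assuming Flux ≥ 0.14 (explicit `Flux.setup` / `Flux.withgradient` / `Flux.update!` API), random placeholder data, and a one-step-ahead loss. `map_step` and `one_step_loss` are hypothetical helper names.

```julia
using Flux, Random

Random.seed!(0)
ndim, nu, N = 2, 10, 101
train_u = rand(Float32, nu, N)    # measured inputs, one column per time step (placeholder data)
train_x = rand(Float32, ndim, N)  # measured states (placeholder data)

# Network mapping an input column u_k to the ndim^2 + ndim entries of A(u_k) and B(u_k)
model = Chain(Dense(nu => ndim^2 + ndim, tanh))

# One step of the discrete map x_{k+1} = A(u_k) x_k + B(u_k)
function map_step(model, xk, uk)
    n = length(xk)
    v = model(uk)
    A = reshape(v[1:n^2], n, n)
    B = v[n^2+1:end]
    return A * xk .+ B
end

# One-step-ahead loss: predict x_{k+1} from the *measured* x_k at every step.
# The loop only rebinds `loss` (no array mutation), which Zygote can differentiate.
function one_step_loss(model, u, x)
    loss = 0f0
    for k in 1:size(u, 2)-1
        x̂ = map_step(model, x[:, k], u[:, k])
        loss += sum(abs2, x̂ .- x[:, k+1])
    end
    return loss
end

opt_state = Flux.setup(Adam(1f-2), model)
losses = Float32[]
for epoch in 1:100
    loss, grads = Flux.withgradient(m -> one_step_loss(m, train_u, train_x), model)
    Flux.update!(opt_state, model, grads[1])
    push!(losses, loss)
end
```

The one-step-ahead loss feeds the measured x_k into every step, which keeps gradients well behaved. A multi-step rollout (feeding each prediction back in) is closer to what the `neural_ode` call would optimize, but it can blow up if the learned A has eigenvalues larger than 1 in magnitude.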
Thanks in advance for your help,