Hello everyone,
I would like to use DiffEqFlux to learn a discrete dynamic equation x(k+1) = A * x(k) + B, where A(u(k)) and B(u(k)) are a time-varying matrix and vector that depend on the input u(k) at time step k. The input sequence u is known (it consists of measurements of my system). I think I can use a simple explicit Euler scheme to obtain this discrete dynamics in DiffEqFlux.
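For reference, one forward-Euler step of size Δt applied to the continuous model ẋ = A(u) x + B(u), with A and B evaluated at the current input u_k, gives:

```latex
\begin{aligned}
x_{k+1} &= x_k + \Delta t\,\bigl(A(u_k)\,x_k + B(u_k)\bigr) \\
        &= \bigl(I + \Delta t\,A(u_k)\bigr)\,x_k + \Delta t\,B(u_k)
\end{aligned}
```

So learning A and B in the continuous model with `Euler()` and `dt = Δt` is equivalent to learning the discrete map x(k+1) = Ã x(k) + B̃ with Ã = I + Δt A and B̃ = Δt B; the two parameterizations differ only by this rescaling.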
If ndim is the size of the state vector, then we are trying to learn ndim*ndim (for A) + ndim (for B) components as functions of u.
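With ndim = 2 this is 2·2 + 2 = 6 numbers per time step, which matches the output width of `Dense(10, ndim*ndim+ndim, tanh)` in the code below. A small sketch of how such an output vector splits into A and B; the vector `y` is an arbitrary stand-in for one network output. Note that Julia's `reshape` is column-major, so the first ndim entries fill the first column of A:

```julia
# Split one output vector of length ndim^2 + ndim into A (ndim×ndim) and B (ndim).
# `y` is an arbitrary placeholder for a network output.
ndim = 2
nout = ndim * ndim + ndim          # 4 entries for A + 2 for B = 6
y = collect(1.0:nout)              # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

A = reshape(y[1:ndim*ndim], ndim, ndim)  # column-major: first column filled first
B = y[ndim*ndim+1:end]

# A == [1.0 3.0; 2.0 4.0]
# B == [5.0, 6.0]
```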
However, I don’t know how to incorporate the discrete input sequence u into the model. I think it has to be passed as a parameter in some way.
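One way to think about it, sketched in plain Julia without Flux or DiffEqFlux so it runs on its own: treat the input sequence as fixed data that the model receives as an argument (or closes over), and keep only the network weights as trainable parameters. `TinyDense`, `AB` and `rollout` are hypothetical names standing in for `Dense`, `state_dyn!` and the solver, and the update applies the discrete map x(k+1) = A(u_k) x(k) + B(u_k) directly instead of going through an ODE solver:

```julia
using Random

# Hypothetical stand-in for `Dense(nu, nout, tanh)`; only W and b are trainable.
struct TinyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end
(l::TinyDense)(u) = tanh.(l.W * u .+ l.b)

# Turn one network output into A(u) (column-major reshape) and B(u).
function AB(layer::TinyDense, u, ndim)
    y = layer(u)
    A = reshape(y[1:ndim*ndim], ndim, ndim)
    B = y[ndim*ndim+1:end]
    return A, B
end

# Roll out x(k+1) = A(u_k) * x(k) + B(u_k) over the columns of U.
# U is plain data passed as an argument, not a trainable parameter.
# States are collected without in-place array mutation.
function rollout(layer::TinyDense, x0, U, ndim)
    xs = [x0]
    for k in 1:size(U, 2)-1
        A, B = AB(layer, U[:, k], ndim)
        xs = vcat(xs, [A * xs[end] + B])
    end
    return reduce(hcat, xs)   # ndim × N matrix, column k is x(k)
end

Random.seed!(0)
ndim, nu, N = 2, 10, 101
nout = ndim * ndim + ndim
layer = TinyDense(0.1 .* randn(nout, nu), zeros(nout))
U = rand(nu, N)          # known input sequence (placeholder for measurements)
X_meas = rand(ndim, N)   # placeholder for measured states

# The loss closes over U and X_meas; only `layer` would change during training.
loss(layer) = sum(abs2, X_meas .- rollout(layer, X_meas[:, 1], U, ndim))

X = rollout(layer, X_meas[:, 1], U, ndim)
```

Carried over to Flux, the same idea would presumably mean reading `train_u` inside the model (as `DuDt` does with `pinput`) and handing only `Flux.params(dudt)` to the optimizer. The rollout avoids mutating arrays in place because Zygote does not support array mutation.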
Toy problem with ndim = 2
using LinearAlgebra
using Flux, DiffEqFlux, DifferentialEquations
#### Create data
# Input u has 10 rows (think of them as 10 sensors) sampled at 101 time steps, t = 0.0:0.1:10.0
train_u = rand(10,101)
#### How should I pass the input sequence to the model?
pinput = Flux.params(train_u)
# State x
ndim = 2
train_x = rand(ndim,101)
# Create time span
Δt = 0.1
tspan = (0.0, 10.0)
trange = 0.0:Δt:10.0
# Create A and B
A = zeros(ndim,ndim)
B = zeros(ndim)
#### We want to learn a model ẋ = A(u)*x + B(u), where x is the state variable and u is a time-varying input sampled at discrete instants
# dudt will learn the components of A and B based on the input vector u
dudt = Chain(Dense(10,ndim*ndim+ndim,tanh))
ps = Flux.params(dudt)
## Fill A and B in place from the network output: the first ndim*ndim entries form A (column-major), the remaining ndim entries form B
function state_dyn!(A, B, NN, ndim)
A .= reshape(NN[1:ndim*ndim],(ndim, ndim))
B .= NN[ndim*ndim+1:end]
return A, B
end
#### Define dynamical equation
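function DuDt(du,u,p,t)
# here u is the ODE state (x above), not the sensor input;
# round(Int, t/Δt) + 1 picks the input column of the current Euler step (ceil can skip a sample when t drifts slightly above a multiple of Δt)
state_dyn!(A, B, dudt(pinput[:, round(Int, t/Δt) + 1]), ndim)
du .= A * u + B
end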
### Not sure about this step
p = (pinput, ps)
#### FluxDiffEq
# Euler() is a fixed-step method: dt sets the step size and reltol/abstol are not used
n_ode = x->neural_ode(DuDt,x,tspan,Euler(),dt = Δt, saveat=trange, params = p)
opt = ADAM(0.1)
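# loss_n_ode() takes no arguments (it reads train_x directly), so train! only needs an empty tuple per iteration
data = Iterators.repeated((), 100)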
u0 = train_x[:, 1]  # initial state: first measured state
function predict_n_ode()
n_ode(u0)
end
loss_n_ode() = sum(abs2,train_x .- predict_n_ode())
cb = () -> println(loss_n_ode())  # print the loss during training
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)
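The lookup of the current input sample from the solver time `t` deserves a closer look. Assuming the solver advances time by repeatedly adding Δt, `t` drifts slightly away from exact multiples of Δt (for example, 0.1 + 0.1 + 0.1 is 0.30000000000000004), and `ceil(Int, t/Δt)` then jumps to the next sample. A minimal plain-Julia comparison of the two lookups; `input_index` is a hypothetical helper name:

```julia
# Mimic a fixed-step integrator accumulating time: t_{k+1} = t_k + Δt.
Δt = 0.1
nsteps = 100
ts = zeros(nsteps + 1)
for k in 1:nsteps
    ts[k+1] = ts[k] + Δt
end

# Nearest-sample lookup: insensitive to tiny floating-point drift in t.
input_index(t, Δt) = round(Int, t / Δt) + 1

idx_round = [input_index(t, Δt) for t in ts]     # 1, 2, ..., 101
idx_ceil  = [1 + ceil(Int, t / Δt) for t in ts]  # can jump ahead by one sample
```

With rounding, every step maps to its own input column; with `ceil`, the fourth step (t ≈ 0.3) already reads column 5.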
Thanks in advance for your help,