DiffEqFlux with discrete time-dependent input

Hello everyone,

I would like to use DiffEqFlux to learn a discrete dynamic equation x(k+1) = A(u(k)) * x(k) + B(u(k)), where A and B are a time-varying matrix and vector that depend on the input u(k) at time step k. The sequence of inputs u is known (it consists of measurements of my system). I think I can use a simple Euler time scheme to recover this discrete dynamics in DiffEqFlux.
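(Concretely, if the underlying continuous model is ẋ = Ã(u) * x + B̃(u), one explicit Euler step of size Δt gives x(k+1) = x(k) + Δt * (Ã(u(k)) * x(k) + B̃(u(k))) = (I + Δt * Ã(u(k))) * x(k) + Δt * B̃(u(k)), so the discrete A and B correspond to I + Δt * Ã and Δt * B̃.)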

If we call ndim the size of the state vector, then we are trying to learn ndim*ndim (for A) + ndim (for B) components, all depending on u (with ndim = 2, that is 6 outputs per time step).

However, I don’t know how to incorporate the discrete input sequence u into the model. I think it should be passed in some way as a parameter.
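For instance, would something like the following closure pattern work? (A rough sketch with made-up names make_rhs / net; it assumes u(k) is held constant over the k-th time step.)

# Sketch: pass the known input sequence via a closure rather than as a trainable parameter.
# `net` maps u(k) to the ndim*ndim + ndim entries of A and B.
function make_rhs(net, inputs, Δt, ndim)
    function rhs!(dx, x, p, t)
        k = clamp(1 + floor(Int, t / Δt), 1, size(inputs, 2))  # column of u active at time t
        nn = net(inputs[:, k])
        A = reshape(nn[1:ndim*ndim], ndim, ndim)
        B = nn[ndim*ndim+1:end]
        dx .= A * x .+ B
    end
    return rhs!
end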

Toy problem with ndim = 2


using LinearAlgebra
using Flux, DiffEqFlux, DifferentialEquations

#### Create data

# Input u has 10 rows - think of it as 10 sensors sampled at 101 time steps
train_u = rand(10,101)
#### How should I pass the input sequence into the model?
pinput = Flux.params(train_u)

# State x
ndim = 2
train_x = rand(ndim,101)

# Create time span
Δt = 0.1
tspan = (0.0, 10.0)
trange = 0.0:Δt:10.0

# Create A and B
A = zeros(ndim,ndim)
B = zeros(ndim)


#### We want to learn a model ẋ = A(u)*x + B(u) where x is the state variable and u is a time-varying input sampled at discrete instants

# dudt will learn the components of A and B based on the input vector u 
dudt = Chain(Dense(10,ndim*ndim+ndim,tanh))

ps = Flux.params(dudt)

## function to construct A and B from the neural network
function state_dyn!(A, B, NN, ndim)
    A .= reshape(NN[1:ndim*ndim],(ndim, ndim))
    B .= NN[ndim*ndim+1:end]
    return A, B
end

#### Define dynamical equation
function DuDt(du,u,p,t)
    state_dyn!(A, B, dudt(pinput[:,1+ceil(Int, t/Δt)]), ndim)
    du .= A * u + B
end

### Not sure about this step
p = (pinput, ps)

#### FluxDiffEq
n_ode = x -> neural_ode(DuDt, x, tspan, Euler(), dt = Δt, saveat = trange, reltol = 1e-7, abstol = 1e-9, params = p)

opt = ADAM(0.1)
u0 = train_x[:,1]                  # initial state taken from the data
data = Iterators.repeated((), 100) # loss_n_ode takes no arguments; iterate 100 times
function predict_n_ode()
  n_ode(u0)
end
loss_n_ode() = sum(abs2,train_x .- predict_n_ode())

cb = () -> display(loss_n_ode())   # show the loss as training progresses
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)

Thanks in advance for your help,

If you want to learn a discrete-time model, there is no need for a neural ODE; just use standard Flux.
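For example, something like this (a rough, untested sketch with placeholder names; it unrolls the recurrence as a plain Julia loop and collects the states without in-place mutation, so Flux can differentiate through it):

using Flux

ndim = 2
model = Chain(Dense(10, ndim*ndim + ndim, tanh))

# Unroll x(k+1) = A(u(k)) * x(k) + B(u(k)), gathering the states in a
# vector and hcat-ing them at the end instead of writing into a buffer.
function predict(x0, u)
    xs = Any[x0]
    for k in 1:size(u, 2) - 1
        nn = model(u[:, k])
        A = reshape(nn[1:ndim*ndim], ndim, ndim)
        B = nn[ndim*ndim+1:end]
        push!(xs, A * xs[end] .+ B)
    end
    return hcat(xs...)
end

loss(u, x) = Flux.mse(predict(x[:, 1], u), x)

Training would then just be Flux.train!(loss, Flux.params(model), [(train_u, train_x)], ADAM()).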

Hello,

That’s what I’ve tried, but I don’t know how to handle the backprop to learn the parameters.
Here is the code that I wrote; I get an error during training:

using LinearAlgebra
using Statistics
using Flux, DiffEqFlux, DifferentialEquations

#### Create data

# Input u has 10 rows - think of it as 10 sensors sampled at 101 time steps
train_u = rand(Float32, (10,101))
# State x
ndim = 2
train_x = rand(Float32, (ndim,101))

# Set time interval
Δt = 0.1
tf = 10.0
tspan = (0.0, tf)
trange = 0.0:Δt:tf

# Create A and B
A = zeros(Float32, (ndim,ndim))
B = zeros(Float32, ndim)

#### We want to learn a model x(k+1) = A(u(k))*x(k) + B(u(k)) where x is the state variable and u is a time-varying input sampled at discrete instants

Model = Chain(Dense(10,ndim*ndim+ndim,tanh))

# function to construct A and B from the neural network
function state_dyn!(A, B, NN, ndim)
    A .= reshape(NN[1:ndim*ndim],(ndim, ndim))
    B .= NN[ndim*ndim+1:end]
    return A, B
end

# Euler_Forward returns the prediction xt given the sequence of input u
function Euler_Forward(x₀,tf, Δt, input, model, ndim)
    N = ceil(Int,tf/Δt)
    xt = zeros(Float32, ndim,N+1)
    A = zeros(Float32, ndim, ndim)
    B = zeros(Float32, ndim)
    
    # Set initial state
    xt[:,1] = x₀ 
    
    # Time marching
    for k = 1:N
        state_dyn!(A, B, model(input[:,k]).data, ndim)
        xt[:,k+1] .= A * xt[:,k] + B
    end
    
    return xt
end

# Define tools for optimization

opt = ADAM()

function loss(x, y)
    Flux.reset!(Model)
    l = Flux.mse(Euler_Forward(y[:,1], tf, Δt, x, Model, ndim), y)
    return l
end

data = zip(train_u, train_x)

params = Flux.params(Model)

Flux.train!(loss, params, data, opt)

Output of Flux.train!(loss, params, data, opt):

MethodError: no method matching getindex(::Float32, ::Colon, ::Int64)
Closest candidates are:
  getindex(::Number, !Matched::Integer...) at number.jl:82
  getindex(::Number) at number.jl:75
  getindex(::Number, !Matched::Integer) at number.jl:77
  ...

Stacktrace:
 [1] loss(::Float32, ::Float32) at ./In[164]:3
 [2] (::getfield(Flux.Optimise, Symbol("##15#21")){typeof(loss),Tuple{Float32,Float32}})() at /home/mat/.julia/packages/Flux/qXNjB/src/optimise/train.jl:72
 [3] gradient_(::getfield(Flux.Optimise, Symbol("##15#21")){typeof(loss),Tuple{Float32,Float32}}, ::Tracker.Params) at /home/mat/.julia/packages/Tracker/RRYy6/src/back.jl:97
 [4] #gradient#24(::Bool, ::Function, ::Function, ::Tracker.Params) at /home/mat/.julia/packages/Tracker/RRYy6/src/back.jl:164
 [5] gradient at /home/mat/.julia/packages/Tracker/RRYy6/src/back.jl:164 [inlined]
 [6] macro expansion at /home/mat/.julia/packages/Flux/qXNjB/src/optimise/train.jl:71 [inlined]
 [7] macro expansion at /home/mat/.julia/packages/Juno/TfNYn/src/progress.jl:133 [inlined]
 [8] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::Function, ::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM) at /home/mat/.julia/packages/Flux/qXNjB/src/optimise/train.jl:69
 [9] train!(::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM) at /home/mat/.julia/packages/Flux/qXNjB/src/optimise/train.jl:67
 [10] top-level scope at In[168]:1
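
The MethodError itself points at the data iterator: zip(train_u, train_x) iterates the two matrices element by element, so loss is called with a pair of Float32 scalars and y[:,1] fails. Each training sample should be the whole pair of matrices, along these lines (a sketch of the fix):

# Wrap the pair so one "sample" is the full input/state matrices;
# repeat it to run several passes.
data = Iterators.repeated((train_u, train_x), 200)
Flux.train!(loss, params, data, opt)

After that, two further issues would likely surface (assuming the Tracker-based Flux shown in the stack trace): model(input[:,k]).data strips the gradient tracking, so no parameter gradients flow, and writing into the preallocated xt with .= is in-place mutation, which Tracker cannot differentiate. Building the trajectory without mutation, as in the sketch above, avoids both.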