ODE-LSTM implementation

I am working on implementing the ODELSTM cell (https://github.com/SciML/DiffEqFlux.jl/issues/422).

The cell expects hidden states and a tuple of a feature vector and the elapsed time between the observations (or a 2D feature array and a vector of lags for batched inputs).
For each observation, an ODE gets solved with a tspan of (0,elapsed).

For batched input, I was going to solve an EnsembleProblem, setting u0 and tspan for each sample inside the prob_func.
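A minimal forward-pass sketch of that idea, in case it helps the discussion. The dynamics `f`, the array `X`, the vector `lags` and the sizes are hypothetical stand-ins, not the code from my repository, and nothing here is differentiated with Zygote:

```julia
using OrdinaryDiffEq

# Hypothetical toy dynamics standing in for the cell's ODE right-hand side.
f(u, p, t) = -p .* u

X = rand(Float32, 4, 32)            # hypothetical batch: 4 hidden units, 32 samples
lags = rand(Float32, 32) .+ 0.1f0   # hypothetical elapsed times, one per sample

# Base problem; u0 and tspan get overwritten per trajectory in prob_func.
base = ODEProblem(f, X[:, 1], (0f0, lags[1]), 1f0)

# For trajectory i, start from sample i and integrate over (0, lags[i]).
prob_func(prob, i, repeat) = remake(prob; u0 = X[:, i], tspan = (0f0, lags[i]))

ens = EnsembleProblem(base; prob_func = prob_func)
sol = solve(ens, Tsit5(), EnsembleSerial(); trajectories = size(X, 2))

# Collect the final state of every trajectory into a features × batch matrix.
h = reduce(hcat, [s.u[end] for s in sol.u])
```

This only shows how u0 and tspan are set per sample; the error below appears once the gradient is taken through the ensemble solve.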

I get an error when I try to train a classifier (e.g. Chain(Recur(ODELSTMCell()),softmax)) using Flux.train! (batchsize=32 in my example):

BoundsError: attempt to access (32,)
  at index [0]
in top-level scope at julia-pr/LTC.jl/src/ODELSTMbatched.jl:413
in start_opt at julia-pr/LTC.jl/src/ODELSTMbatched.jl:404
in train! at Flux/05b38/src/optimise/train.jl:78 
in #train!#12 at Flux/05b38/src/optimise/train.jl:80
in macro expansion at Juno/n6wyj/src/progress.jl:119 
in macro expansion at Flux/05b38/src/optimise/train.jl:82 
in gradient at Zygote/chgvX/src/compiler/interface.jl:54
in  at Zygote/chgvX/src/compiler/interface.jl:177
in  at Zygote/chgvX/src/compiler/interface2.jl
in #14 at Flux/05b38/src/optimise/train.jl:83 
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49 
in  at Zygote/chgvX/src/lib/lib.jl:175
in  at Zygote/chgvX/src/compiler/interface2.jl
in #156 at julia-pr/LTC.jl/src/ODELSTMbatched.jl:404 
in  at Zygote/chgvX/src/compiler/interface2.jl
in lossf at julia-pr/LTC.jl/src/ODELSTMbatched.jl:366 
in  at Zygote/chgvX/src/compiler/interface2.jl
in broadcasted at base/broadcast.jl:1257 
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49 
in  at Zygote/chgvX/src/lib/lib.jl:175
in #3828#back at ZygoteRules/6nssF/src/adjoint.jl:49 
in  at Zygote/chgvX/src/lib/broadcast.jl:140
in map at base/abstractarray.jl:2248 
in collect at base/array.jl:686
in iterate at base/generator.jl:47 
in  at base/generator.jl:36
in #1057 at Zygote/chgvX/src/lib/broadcast.jl:140 
in  at Zygote/chgvX/src/compiler/interface2.jl
in Chain at Flux/05b38/src/layers/basic.jl:38 
in  at Zygote/chgvX/src/compiler/interface2.jl
in applychain at Flux/05b38/src/layers/basic.jl:36 
in  at Zygote/chgvX/src/compiler/interface2.jl
in Recur at Flux/05b38/src/layers/recurrent.jl:36 
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49 
in #145 at Zygote/chgvX/src/lib/lib.jl:175 
in  at Zygote/chgvX/src/compiler/interface2.jl
in ODELSTMCellB at julia-pr/LTC.jl/src/ODELSTMbatched.jl:342 
in  at Zygote/chgvX/src/compiler/interface2.jl
in CTRNNCellB at julia-pr/LTC.jl/src/ODELSTMbatched.jl:221 
in  at Zygote/chgvX/src/compiler/interface2.jl
in solve##kw at DiffEqBase/gLFRA/src/solve.jl:97 
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49 
in  at Zygote/chgvX/src/lib/lib.jl:175
in  at Zygote/chgvX/src/compiler/interface2.jl
in #solve#459 at DiffEqBase/gLFRA/src/solve.jl:100 
in  at ZygoteRules/6nssF/src/adjoint.jl:49
in #145 at Zygote/chgvX/src/lib/lib.jl:175 
in  at Zygote/chgvX/src/compiler/interface2.jl
in __solve##kw at DiffEqBase/gLFRA/src/ensemble/basic_ensemble_solve.jl:103 
in  at Zygote/chgvX/src/compiler/interface2.jl
in #__solve#359 at DiffEqBase/gLFRA/src/ensemble/basic_ensemble_solve.jl:110 
in  at ZygoteRules/6nssF/src/adjoint.jl:49
in  at DiffEqBase/gLFRA/src/zygote.jl:61
in collect at base/array.jl:686 
in iterate at base/generator.jl:47 
in  at base/none
in getindex at base/tuple.jl:24

Inside EnsembleSolution_adjoint(p̄::AbstractArray{T,N}) where {T,N} (https://github.com/SciML/DiffEqBase.jl/blob/master/src/zygote.jl#L61), p̄ is a Vector, so size(p̄) is the 1-tuple (32,) and size(p̄)[end-1] indexes it at 0, which raises the BoundsError above.
If I change it to size(p,1), then it fails because of Val(N-2), which becomes Val(-1) for N = 1.
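The indexing problem can be reproduced in plain Julia without any of the packages involved. This is only an illustration of the two expressions; `pbar` is a hypothetical stand-in for p̄, and the snippet does not reproduce the adjoint itself:

```julia
# Hypothetical stand-in for p̄: a Vector of length 32, so N = 1.
pbar = zeros(32)

dims = size(pbar)            # a 1-tuple: (32,)

# end-1 evaluates to 0 for a 1-tuple, so this index is out of bounds.
err = try
    dims[end-1]
catch e
    e
end
println(err isa BoundsError)

# With N = 1, Val(N-2) is Val(-1).
v = Val(ndims(pbar) - 2)
println(v)
```

Both expressions only make sense when p̄ has at least two dimensions.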

My code so far: https://github.com/lungd/Flux.jl/blob/ODELSTM/WIP_odelstm_batched.jl

Any help would be appreciated!

Did you see this tutorial on fitting drift and diffusion of an SDE via ensembles? https://diffeqflux.sciml.ai/dev/examples/optimization_sde/ Is how you’re handling the parameters different?
