I am working on implementing the ODELSTM cell (https://github.com/SciML/DiffEqFlux.jl/issues/422).
The cell expects hidden states and a tuple of a feature vector and the elapsed time between the observations (or a 2D feature array and a vector of lags for batched inputs).
For each observation, an ODE gets solved with a tspan of (0,elapsed).
For batched input, my plan is to solve an EnsembleProblem and set u0 and tspan for each trajectory inside the prob_func.
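To make the batching idea concrete, here is a minimal forward-only sketch of it. It is not the code from my branch: the dynamics `decay`, the arrays `h`, `elapsed`, and `p` are hypothetical stand-ins, and it assumes the `(prob, i, repeat)` signature for `prob_func` and that the `EnsembleSolution` exposes its trajectories in the field `u`. It does not involve Zygote, so it does not reproduce the error.

```julia
using OrdinaryDiffEq

# Hypothetical toy dynamics standing in for the cell's ODE: du/dt = -p .* u
decay(u, p, t) = -p .* u

# Hypothetical batch: 3 hidden states (one per column) and one elapsed time each
h = [1.0 2.0 3.0;
     4.0 5.0 6.0]
elapsed = [0.1, 0.5, 1.0]
p = [1.0, 2.0]

# Template problem; its u0 and tspan are placeholders overwritten by prob_func
base_prob = ODEProblem(decay, h[:, 1], (0.0, 1.0), p)

# Trajectory i starts from column i of the batch and integrates over (0, elapsed[i])
prob_func(prob, i, repeat) = remake(prob; u0 = h[:, i], tspan = (0.0, elapsed[i]))

ens_prob = EnsembleProblem(base_prob; prob_func = prob_func)
sol = solve(ens_prob, Tsit5(), EnsembleSerial();
            trajectories = size(h, 2), save_everystep = false)

# Collect the final state of every trajectory into a (features x batch) matrix
h_new = reduce(hcat, [s.u[end] for s in sol.u])
```

Each column of `h_new` is the hidden state of one observation after evolving it for that observation's own elapsed time.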
I get an error when I try to train a classifier, e.g. Chain(Recur(ODELSTMCell()), softmax), using Flux.train! (batchsize=32 in my example):
BoundsError: attempt to access (32,) at index [0]
in top-level scope at julia-pr/LTC.jl/src/ODELSTMbatched.jl:413
in start_opt at julia-pr/LTC.jl/src/ODELSTMbatched.jl:404
in train! at Flux/05b38/src/optimise/train.jl:78
in #train!#12 at Flux/05b38/src/optimise/train.jl:80
in macro expansion at Juno/n6wyj/src/progress.jl:119
in macro expansion at Flux/05b38/src/optimise/train.jl:82
in gradient at Zygote/chgvX/src/compiler/interface.jl:54
in at Zygote/chgvX/src/compiler/interface.jl:177
in at Zygote/chgvX/src/compiler/interface2.jl
in #14 at Flux/05b38/src/optimise/train.jl:83
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49
in at Zygote/chgvX/src/lib/lib.jl:175
in at Zygote/chgvX/src/compiler/interface2.jl
in #156 at julia-pr/LTC.jl/src/ODELSTMbatched.jl:404
in at Zygote/chgvX/src/compiler/interface2.jl
in lossf at julia-pr/LTC.jl/src/ODELSTMbatched.jl:366
in at Zygote/chgvX/src/compiler/interface2.jl
in broadcasted at base/broadcast.jl:1257
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49
in at Zygote/chgvX/src/lib/lib.jl:175
in #3828#back at ZygoteRules/6nssF/src/adjoint.jl:49
in at Zygote/chgvX/src/lib/broadcast.jl:140
in map at base/abstractarray.jl:2248
in collect at base/array.jl:686
in iterate at base/generator.jl:47
in at base/generator.jl:36
in #1057 at Zygote/chgvX/src/lib/broadcast.jl:140
in at Zygote/chgvX/src/compiler/interface2.jl
in Chain at Flux/05b38/src/layers/basic.jl:38
in at Zygote/chgvX/src/compiler/interface2.jl
in applychain at Flux/05b38/src/layers/basic.jl:36
in at Zygote/chgvX/src/compiler/interface2.jl
in Recur at Flux/05b38/src/layers/recurrent.jl:36
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49
in #145 at Zygote/chgvX/src/lib/lib.jl:175
in at Zygote/chgvX/src/compiler/interface2.jl
in ODELSTMCellB at julia-pr/LTC.jl/src/ODELSTMbatched.jl:342
in at Zygote/chgvX/src/compiler/interface2.jl
in CTRNNCellB at julia-pr/LTC.jl/src/ODELSTMbatched.jl:221
in at Zygote/chgvX/src/compiler/interface2.jl
in solve##kw at DiffEqBase/gLFRA/src/solve.jl:97
in #1681#back at ZygoteRules/6nssF/src/adjoint.jl:49
in at Zygote/chgvX/src/lib/lib.jl:175
in at Zygote/chgvX/src/compiler/interface2.jl
in #solve#459 at DiffEqBase/gLFRA/src/solve.jl:100
in at ZygoteRules/6nssF/src/adjoint.jl:49
in #145 at Zygote/chgvX/src/lib/lib.jl:175
in at Zygote/chgvX/src/compiler/interface2.jl
in __solve##kw at DiffEqBase/gLFRA/src/ensemble/basic_ensemble_solve.jl:103
in at Zygote/chgvX/src/compiler/interface2.jl
in #__solve#359 at DiffEqBase/gLFRA/src/ensemble/basic_ensemble_solve.jl:110
in at ZygoteRules/6nssF/src/adjoint.jl:49
in at DiffEqBase/gLFRA/src/zygote.jl:61
in collect at base/array.jl:686
in iterate at base/generator.jl:47
in at base/none
in getindex at base/tuple.jl:24
The error is raised inside EnsembleSolution_adjoint(p̄::AbstractArray{T,N}) where {T,N} (https://github.com/SciML/DiffEqBase.jl/blob/master/src/zygote.jl#L61): p̄ is a Vector, so size(p̄)[end-1] indexes the size tuple at position 0 and throws.
If I change it to size(p,1), it then fails because of Val(N-2), which becomes Val(-1) since N is 1 for a Vector.
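Both failures can be reproduced in plain Julia without any of the packages involved. The sketch below assumes only that the incoming gradient is a 32-element Vector, as the error message suggests; the variable name `p̄` mirrors the one in the adjoint, and the rest is illustrative.

```julia
# A gradient stored as a plain Vector with 32 entries, as in the batched case
p̄ = zeros(32)

dims = size(p̄)      # a 1-tuple holding the length, because p̄ is one-dimensional
N = ndims(p̄)        # the N in AbstractArray{T,N}; equals 1 for a Vector

# For a 1-tuple, `end` is 1, so `end-1` is 0 and the lookup is out of bounds
err = try
    dims[end-1]
catch e
    e
end

# N - 2 is negative for a Vector, which is where Val(-1) comes from
v = Val(N - 2)
```

Here `err` holds the same kind of BoundsError as the one at the bottom of the stack trace (getindex on a tuple), and `v` is the Val(-1) from the second attempt.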
My code so far: https://github.com/lungd/Flux.jl/blob/ODELSTM/WIP_odelstm_batched.jl
Any help would be appreciated!