Extrapolation of a NeuralODE using `DiffEqFlux`

Hey folks. I have trained a NeuralODE using DiffEqFlux, meaning that I optimized the parameters of a neural network that serves as the right-hand side of the ODE. My dataset spans 31 years with yearly measurements, and I trained the model on the first 15 years. The next task is to see how well the NeuralODE extrapolates to the subsequent 16 years that it was not trained on.

I am having trouble figuring out the code for this, or at least getting it to run. Here is a minimal snippet. I did not include all of the NeuralODE setup, but I can add it if people think it will help. I just wanted to keep the code as short as possible.

```julia
# Neural network for the right-hand side: 6 state variables in, 6 derivatives out
dudt2 = FastChain(FastDense(6, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 6))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

# Train on the first 15 years
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                          cb = callback_plot, maxiters=20)

# Attempt to extrapolate over all 31 years
extrapolation_problem = NeuralODE(x -> dudt2(x), full_data[:, 1], (1.0, 31.0), prob_neuralode.p)

# This call throws the MethodError shown next
res_extrapolation = solve(extrapolation_problem, Tsit5(), saveat=1.0)
```

Constructing `extrapolation_problem` seems to work, but calling `solve` on it produces this error message:

```
ERROR: MethodError: no method matching init(::NeuralODE{var"#18#19", Vector{Any}, Flux.var"#66#68"{var"#18#19"}, Vector{Float64}, Tuple{Tuple{Float64, Float64}, Vector{Float32}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; saveat=1.0)
Closest candidates are:
  init(::PDEProblem, ::SciMLBase.DEAlgorithm, ::Any...; kwargs...) at ~/.julia/packages/DiffEqBase/uclGA/src/solve.jl:138
  init(::SciMLBase.DEProblem, ::Any...; kwargs...) at ~/.julia/packages/DiffEqBase/uclGA/src/solve.jl:32
  init(::LinearProblem, ::Any...; kwargs...) at ~/.julia/packages/LinearSolve/xztQN/src/common.jl:70
Stacktrace:
 [1] solve(::Function, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, F)
   @ CommonSolve ~/.julia/packages/CommonSolve/alZRX/src/CommonSolve.jl:3
 [2] top-level scope
   @ ~/Dropbox/sandbox/julia_gend_univ/dev/neuralode_model/basic_node_6_group um.jl:187
```

In the example above, `full_data` is the full 31-year dataset. I want to plot the performance of the NeuralODE on both the original 15 years of training data and the subsequent 16 years, hence the timespan `(1.0, 31.0)`. `prob_neuralode.p` is supposed to hold the parameters of the neural network `dudt2`.

Any suggestions for how to fix this?

Why not

```julia
extrapolation_problem = ODEProblem((u,p,t)->dudt2(u), full_data[:, 1], tspan, result_neuralode.minimizer)
res_extrapolation = solve(extrapolation_problem, Tsit5(), saveat=1.0)
```

?

?

Hey @ChrisRackauckas, thanks, yes. I totally missed the `(u, p, t)` part. I tried your code and had to make one small adjustment: I changed `dudt2(u)` to `dudt2(u, p)`, because a `FastChain` takes its parameters as an explicit second argument. Now it all works.
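To make the signature issue concrete: `ODEProblem` expects a right-hand side of the form `f(u, p, t)`, while a `FastChain` is called as `model(u, p)`. The sketch below is a minimal stand-in that runs with only OrdinaryDiffEq installed; `linear_model` is a hypothetical placeholder (a plain matrix product, not a neural network) and the parameter values are made up, but the wrapper is the same pattern as `(u,p,t)->dudt2(u,p)`.

```julia
using OrdinaryDiffEq

# Hypothetical stand-in for a FastChain: like dudt2, it takes the state `u`
# and a flat parameter vector `p` as separate arguments.
linear_model(u, p) = reshape(p, 2, 2) * u

# ODEProblem wants f(u, p, t); this wrapper ignores `t` and forwards `p`,
# so the parameters stored in the problem reach the model on every call.
rhs(u, p, t) = linear_model(u, p)

u0 = [1.0, 0.0]
p = [-0.1, 0.0, 0.0, -0.2]   # hypothetical parameters (column-major 2x2 matrix)
prob = ODEProblem(rhs, u0, (1.0, 31.0), p)

# saveat = 1.0 saves at t = 1.0, 2.0, ..., 31.0 (one point per "year")
sol = solve(prob, Tsit5(), saveat = 1.0)
```

Swapping `linear_model` for `dudt2` and `p` for `result_neuralode.minimizer` gives back the working call from this thread.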

The working code looks like this:

```julia
extrapolation_problem = ODEProblem((u,p,t)->dudt2(u,p), full_data[:, 1], (1.0, 31.0), result_neuralode.minimizer)
res_extrapolation = solve(extrapolation_problem, Tsit5(), saveat=1.0)
```
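Because the solution is saved once per year starting at t = 1.0, the first 15 saved columns line up with the training years and the remaining 16 with the held-out years. A minimal sketch for scoring the two windows separately, assuming `full_data` is a 6×31 matrix with one column per year (as `full_data[:, 1]` suggests); `window_errors` and `n_train` are hypothetical names, and mean squared error is just one possible metric.

```julia
# Hypothetical helper: split a yearly prediction matrix into the training
# window (first n_train columns) and the extrapolation window (the rest),
# and return the mean squared error against the data for each window.
function window_errors(pred::AbstractMatrix, data::AbstractMatrix; n_train::Int = 15)
    size(pred) == size(data) || throw(DimensionMismatch("pred and data must have the same size"))
    train = 1:n_train
    extrap = (n_train + 1):size(data, 2)
    # Mean of squared differences over all state variables in the given columns
    mse(idx) = sum(abs2, pred[:, idx] .- data[:, idx]) / (size(pred, 1) * length(idx))
    return (train = mse(train), extrapolation = mse(extrap))
end
```

With the solution from the working code, this would be called as `window_errors(Array(res_extrapolation), full_data)`; the same index ranges can be used to distinguish the two windows when plotting.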

Thanks again for all of your help.