Hey folks. So I have trained a NeuralODE using DiffEqFlux, meaning that I have optimized the parameters of the neural network that defines the right-hand side of the ODE. I have a dataset that spans 31 years with yearly measurements, and I trained the model on the first 15 years. The next task is to see how well the NeuralODE extrapolates to the remaining 16 years, which it was not trained on.
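The post doesn't show how `tspan` and `tsteps` were defined for training. The sketch below shows one way the 15/16-year split could look. It uses random stand-in data with the same 6 × 31 shape, and `years`, `n_train`, `train_data` and `test_data` are hypothetical names. Nothing here comes from the original script.

```julia
# Hypothetical stand-in for the real dataset: 6 variables × 31 yearly measurements.
full_data = rand(Float32, 6, 31)
years = 1.0f0:1.0f0:31.0f0

n_train = 15
train_data = full_data[:, 1:n_train]        # years 1-15, used for fitting
test_data  = full_data[:, n_train+1:end]    # years 16-31, held out

tspan  = (first(years), years[n_train])     # training time span, (1.0f0, 15.0f0)
tsteps = years[1:n_train]                   # saveat points used during training
```

Keeping `years` as a single range makes it easy to reuse the same time grid later for the full 31-year solve.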
I am having trouble figuring out the code for this, or at least getting it to run. Here is a limited snippet. I did not include all of the NeuralODE setup, but I can add it if people think it will help; I just wanted to keep the code as minimal as possible.
```julia
dudt2 = FastChain(FastDense(6, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 50, tanh),
                  FastDense(50, 6))
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                          cb = callback_plot, maxiters=20)
extrapolation_problem = NeuralODE(x -> dudt2(x), full_data[:, 1], (1.0, 31.0), prob_neuralode.p)
res_extrapolation = solve(extrapolation_problem, Tsit5(), saveat=1.0)
```
Constructing `extrapolation_problem` seems to work, but calling `solve` on it generates this error message:
```
ERROR: MethodError: no method matching init(::NeuralODE{var"#18#19", Vector{Any}, Flux.var"#66#68"{var"#18#19"}, Vector{Float64}, Tuple{Tuple{Float64, Float64}, Vector{Float32}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; saveat=1.0)
Closest candidates are:
  init(::PDEProblem, ::SciMLBase.DEAlgorithm, ::Any...; kwargs...) at ~/.julia/packages/DiffEqBase/uclGA/src/solve.jl:138
  init(::SciMLBase.DEProblem, ::Any...; kwargs...) at ~/.julia/packages/DiffEqBase/uclGA/src/solve.jl:32
  init(::LinearProblem, ::Any...; kwargs...) at ~/.julia/packages/LinearSolve/xztQN/src/common.jl:70
Stacktrace:
 [1] solve(::Function, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, F)
   @ CommonSolve ~/.julia/packages/CommonSolve/alZRX/src/CommonSolve.jl:3
 [2] top-level scope
   @ ~/Dropbox/sandbox/julia_gend_univ/dev/neuralode_model/basic_node_6_group um.jl:187
```
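For reference, here is a minimal sketch of one way the extrapolation could be set up. It assumes the DiffEqFlux 1.x API used above (`FastChain`, `sciml_train`). In that API a `NeuralODE` is a callable layer rather than a `DEProblem`, which is why `solve`/`init` has no matching method. Its constructor takes `(model, tspan, alg; kwargs...)`, and it is evaluated as `node(u0, p)`. Note also that `x -> dudt2(x)` drops the parameter argument that a `FastChain` is called with (`dudt2(x, p)`).

A few things here are stand-ins rather than the original script:

- The network is shrunk to two layers.
- `full_data` is random stand-in data.
- `trained_p`, `extrapolation_node` and `tsteps_full` are hypothetical names.

The trained weights are normally returned in the optimization result (`result_neuralode.u`, or `result_neuralode.minimizer` in older releases) and may not be written back into `prob_neuralode.p`. Check this against the installed version.

```julia
using DiffEqFlux, OrdinaryDiffEq

# Hypothetical stand-in data: 6 state variables observed once a year for 31 years.
full_data = rand(Float32, 6, 31)

# Smaller network with the same input/output sizes as dudt2 in the post.
dudt2 = FastChain(FastDense(6, 50, tanh), FastDense(50, 6))

# Training setup on years 1-15 (training itself is omitted in this sketch).
tspan  = (1.0f0, 15.0f0)
tsteps = range(tspan[1], tspan[2], length = 15)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

# Placeholder: in the real script these would be the trained parameters,
# e.g. result_neuralode.u (result_neuralode.minimizer in older releases).
trained_p = prob_neuralode.p

# Rebuild the layer over the full 31-year span; the same network dudt2 is reused.
tsteps_full = range(1.0f0, 31.0f0, length = 31)
extrapolation_node = NeuralODE(dudt2, (1.0f0, 31.0f0), Tsit5(), saveat = tsteps_full)

# A NeuralODE is called like a function with (u0, p); it is not passed to solve.
res_extrapolation = extrapolation_node(full_data[:, 1], trained_p)
pred = Array(res_extrapolation)   # 6 × 31 matrix, one column per year
```

Because `saveat` includes the start of the span, the first column of `pred` is the initial condition. Columns 2 to 15 cover the training years, and columns 16 to 31 are the extrapolation.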
In the example above, `full_data` is the full 31-year dataset. I want to plot the performance of the NeuralODE on both the original 15 years of training data and the subsequent 16 years, hence the timespan `(1.0, 31.0)`. `prob_neuralode.p` is supposed to hold the parameters of the neural network `dudt2`.
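Once a 6 × 31 prediction is available, the training years and the held-out years can be scored separately, alongside the plot. This is a minimal sketch. `pred` and `full_data` are random stand-ins here, and `mse`, `train_idx` and `test_idx` are hypothetical names.

```julia
using Statistics

# Hypothetical stand-ins: observed data and a model prediction, both 6 × 31.
full_data = rand(6, 31)
pred = full_data .+ 0.1 .* randn(6, 31)

train_idx = 1:15    # years the model was fitted on
test_idx  = 16:31   # years held out for extrapolation

# Mean squared error over all variables and years in a segment.
mse(a, b) = mean(abs2, a .- b)

train_mse = mse(pred[:, train_idx], full_data[:, train_idx])
test_mse  = mse(pred[:, test_idx],  full_data[:, test_idx])
println("train MSE = $train_mse, extrapolation MSE = $test_mse")
```

Reporting the two numbers separately makes it clear whether the error grows once the model leaves the training window.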
Any suggestions for how to fix this?