Dear all,

First of all, thank you very much for the supportive environment here. As suggested, I have moved my question to a separate thread.

Unfortunately, I am not very familiar with either neural networks or the details of Julia, so I run into problems when trying to program a forecast based on a neural ODE.

I managed to modify the example here

https://docs.sciml.ai/DiffEqFlux/dev/examples/neural_ode/

so that it solves the logistic differential equation `du1/dt = a1*u1 + a2*u1^2`, a first-order, autonomous, nonlinear ODE. The optimization works well, and the comparison with the exact solution and its direction field looks good, so from the ODE-solving point of view the setup seems to be correct. I also made `Lux.Chain` take the mathematical basis from a separate function, an approach I found in another example:
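
For reference, here is the equation together with its closed-form solution (obtained with the standard substitution $v = 1/u_1$), written out with the values used in `LogDiffEqu` ($a_1 = 2.4$, $a_2 = -1.1$, $u_1(0) = 0.5$). The solution tends to the equilibrium $-a_1/a_2 = 2.4/1.1 \approx 2.18$, which gives a simple sanity check for any forecast beyond `tspan`:

$$
\frac{du_1}{dt} = a_1 u_1 + a_2 u_1^2,
\qquad
u_1(t) = \frac{a_1 u_1(0)\, e^{a_1 t}}{a_1 + a_2 u_1(0)\left(1 - e^{a_1 t}\right)}
= \frac{1.2\, e^{2.4 t}}{1.85 + 0.55\, e^{2.4 t}}
$$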

`math_law(x) = 2.4f0 .* x .- 1.1f0 .* x .^ 2`

`dudt2 = Lux.Chain(x -> math_law(x), …`

But now I cannot find a way to proceed with what may be a trivial problem: I want to make a simple forecast starting from the end of the interval `tspan`. The Lux documentation (Lux is one of the packages used in the code) did not help me, and everything I tried led to various error messages.
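
The forecasting idea can first be sketched on the known ODE, without the network: a forecast is a new integration that starts at `tspan[2]` from the last computed state. This is a minimal sketch assuming DifferentialEquations.jl; the horizon `4.0` and the names `logistic!`, `tspan_fc`, `tsteps_fc` are illustrative choices, and Float64 is used here (rather than the script's Float32) so that tight tolerances make sense:

```julia
using DifferentialEquations

# Right-hand side of the logistic equation, same coefficients as LogDiffEqu
function logistic!(du, u, p, t)
    du[1] = 2.4 * u[1] - 1.1 * u[1]^2
    return nothing
end

u0 = [0.5]
tspan = (0.0, 2.0)
opts = (reltol = 1e-8, abstol = 1e-8)

# Fit interval: integrate up to the end of tspan
sol = solve(ODEProblem(logistic!, u0, tspan), Tsit5(); opts...)
u_end = sol.u[end]                        # state at t = tspan[2]

# Forecast interval (hypothetical horizon): restart at tspan[2] from the last state
tspan_fc = (tspan[2], 4.0)
tsteps_fc = range(tspan_fc[1], tspan_fc[2], length = 20)
sol_fc = solve(ODEProblem(logistic!, u_end, tspan_fc), Tsit5();
               saveat = tsteps_fc, opts...)
forecast = Array(sol_fc)                  # 1 × 20 matrix, one column per time point
```

Because the state at `tspan[2]` fully determines the future of this ODE, restarting from `u_end` gives the same trajectory as one long integration over `(0.0, 4.0)`, up to solver tolerance.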

Put simply, my difficulty is (1) evaluating the NN (which, if I understood correctly, represents the right-hand side of the ODE) to get a value back, and (2) solving (integrating) the neural ODE for the forecast, using DiffEqFlux and Lux. This is where I kindly ask for your support.
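
Regarding evaluating the network: a minimal sketch, assuming Lux's calling convention that a model is applied as `model(x, ps, st)` and returns a tuple `(output, new_state)`. It rebuilds `dudt2` as in the script, but with freshly initialized (untrained) parameters so that it runs on its own; with the trained model, `result_neuralode2.u` would presumably take the place of `ps`. The input must be an array: a length-1 vector for one state, or a `1 × N` matrix for several states at once. The names `du_est`, `U`, `dU` are illustrative.

```julia
using Lux, Random

# Same construction as in the script: fixed feature transform, then two Dense layers
math_law(x) = 2.0f0 .* x .- 1.0f0 .* x .^ 2
dudt2 = Lux.Chain(x -> math_law(x), Lux.Dense(1, 50, tanh), Lux.Dense(50, 1))

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, dudt2)      # untrained parameters, for illustration only

# One state value as a length-1 vector; the output is the network's estimate of du/dt
u = Float32[0.5]
du_est, _ = dudt2(u, ps, st)

# Several states at once: one column per state (1 × N matrix in, 1 × N matrix out)
U = reshape(Float32[0.5, 1.0, 1.5], 1, :)
dU, _ = dudt2(U, ps, st)
```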

Many thanks in advance.

```
using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots
gr()
u0 = Float32[0.5] # initial condition u(t=0)
datasize = 30
tspan = (0.0f0, 2.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
#______________________________________________________________________________
# logistic differential equation LogDiffEqu
function LogDiffEqu(du, u, p, t)
    A = (2.4f0, -1.1f0)  # coefficients a1, a2 as Float32, matching u0
    du[1] = A[1] * u[1] + A[2] * u[1]^2
    return nothing
end
prob_LogDiffEqu = ODEProblem(LogDiffEqu, u0, tspan)
sol_LogDiffEqu = Array(solve(prob_LogDiffEqu, Tsit5(), saveat = tsteps))
#______________________________________________________________________________
# shaping information for the NN
math_law(x) = 2.0f0 .* x .- 1.0f0 .* x .^ 2 # note the "f0" suffix (Float32 literals); coefficients are preliminary
# Construct the layer
dudt2 = Lux.Chain(x -> math_law(x),Lux.Dense(1, 50, tanh),Lux.Dense(50, 1))
rng = Random.default_rng()
# Initialization
p, st = Lux.setup(rng, dudt2)
#_______________________________________________________________________________
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
#______________________________________________________________________________
function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end
#______________________________________________________________________________
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2.(sol_LogDiffEqu .- pred))
    return loss, pred
end
#______________________________________________________________________________
callback = function (p, l, pred; doplot = false)
    # println(l) # print the current loss
    # plot current prediction against data
    if doplot
        plt = Plots.scatter(tsteps, sol_LogDiffEqu[1, :], label = "data")
        Plots.scatter!(plt, tsteps, pred[1, :], label = "prediction")
        display(plot(plt))
    end
    return false
end
#______________________________________________________________________________
pinit = Lux.ComponentArray(p)
callback(pinit, loss_neuralode(pinit)...; doplot=true)
#______________________________________________________________________________
# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(optprob,ADAM(0.05),callback = callback,maxiters = 100)
# Retrain using the BFGS optimizer (Optim.jl, via OptimizationOptimJL)
optprob2 = remake(optprob,u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(optprob2,Optim.BFGS(initial_stepnorm=0.01),callback=callback,allow_f_increases = false)
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
#______________________________________________________________________________
```
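
Regarding the forecast with the neural ODE itself: a minimal sketch, assuming the DiffEqFlux/Lux calling convention already used in the script (`NeuralODE(model, tspan, alg, saveat = ...)`, called as `node(u0, p, st)` with the solution as the first element of the result). A second `NeuralODE` is built on a hypothetical forecast interval `(tspan[2], 4.0f0)` and started from the last predicted state. To stay self-contained it uses freshly initialized parameters converted with `ComponentArray` (from ComponentArrays.jl, an added dependency); with the trained model, one would pass `result_neuralode2.u` instead of `p_ca`. The names `node_fit`, `node_fc`, `tspan_fc`, `tsteps_fc` are illustrative.

```julia
using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays, Random

math_law(x) = 2.0f0 .* x .- 1.0f0 .* x .^ 2
dudt2 = Lux.Chain(x -> math_law(x), Lux.Dense(1, 50, tanh), Lux.Dense(50, 1))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, dudt2)
p_ca = ComponentArray(ps)          # stand-in for the trained parameters

u0 = Float32[0.5]
tspan = (0.0f0, 2.0f0)
tsteps = range(tspan[1], tspan[2], length = 30)

# Fit interval: same construction as prob_neuralode in the script
node_fit = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
pred = Array(node_fit(u0, p_ca, st)[1])        # 1 × 30

# Forecast interval (hypothetical horizon): start at tspan[2] from the last predicted state
tspan_fc = (tspan[2], 4.0f0)
tsteps_fc = range(tspan_fc[1], tspan_fc[2], length = 30)
node_fc = NeuralODE(dudt2, tspan_fc, Tsit5(), saveat = tsteps_fc)
u_end = pred[:, end]
forecast = Array(node_fc(u_end, p_ca, st)[1])  # 1 × 30
```

An alternative would be a single `NeuralODE` over `(0.0f0, 4.0f0)`, comparing only the first part with the data; restarting from the last state keeps the fitted interval and the forecast cleanly separated.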