# How can I pre-train the network? Let me explain: training the neural ODE is very slow if I don't give the network any prior information. My ODE and exogenous variables are the ones shown in this file. How can I incorporate them into my model structure to speed up training?
using DifferentialEquations, Plots, Flux,Optim, DiffEqFlux, DataInterpolations,Random, ComponentArrays, Lux
using Optimization, OptimizationOptimisers, OptimizationOptimJL,OptimizationNLopt
rng = Random.default_rng()
using CSV
using DataFrames
using Plots
using Flux
using Statistics: mean,std,median,quantile
using DiffEqFlux, Optimization, OptimizationOptimJL,Plots
using ComponentArrays, Lux, DiffEqFlux, Optimization, OptimizationPolyalgorithms, DifferentialEquations, Plots
using DiffEqFlux: group_ranges
using StatsPlots
# Load the measurements. Columns are used by position below:
# 1 = time stamps, 2 = gas flow, 3 = registered temperature.
df = CSV.read("real_values_ex_00_fixed_time_step.csv", DataFrame)
# Keep every `subsample`-th row to shorten the training trajectory.
subsample = 10
registered_gas_flow = df[1:subsample:end, 2]
Registered_Temperature = df[1:subsample:end, 3]
tsteps = df[1:subsample:end, 1]
# Linear interpolant of the exogenous gas flow, evaluated inside the ODE via ext_flow.
gas_flow = LinearInterpolation(registered_gas_flow, tsteps);
"""
    ext_flow(t)

Return the gas flow at time `t`, evaluated through the global `gas_flow`
interpolant (which must exist before the first call).
"""
ext_flow(t) = gas_flow(t)
"""
    water_temp(t)

Return the registered temperature at time `t`, evaluated through the global
`Temperature_h20` interpolant (which must exist before the first call).
"""
water_temp(t) = Temperature_h20(t)
"""
    RC!(du, u, p, t)

In-place right-hand side of the one-state model used to generate `ode_data`:

    du[1] = (B * P(t) - C + D * (20 - u[1])) / A,    with p = (A, B, C, D)

where `P(t) = ext_flow(t)` is the interpolated exogenous gas flow.
"""
function RC!(du, u, p, t)
    a, b, c, d = p
    flow = ext_flow(t)
    numerator = b * flow - c + d * (20 - u[1])
    du[1] = numerator / a
end
# Initial state of the single-state model.
u0 = [20.0]
# Integration window (presumably seconds, matching the recording — TODO confirm).
# Float64 to match u0/p: with Float32 time the resolution near t = 3600 is only
# ~2.4e-4, which can make tightly-toleranced solves abort with dt <= dtmin.
tspan = (0.0, 3600.0)
# Physical parameters [A, B, C, D] of RC!.
p = [66.896, 50e6, 100, 0.20]
# Unpacked into globals for convenience; RC! unpacks its own copy from `p`.
A, B, C, D = p
# `gas_flow` is already built from the same data when the CSV is loaded, so only
# the temperature interpolant is created here.
Temperature_h20 = LinearInterpolation(Registered_Temperature, tsteps);
prob = ODEProblem(RC!, u0, tspan, p)
# Reference trajectory at the sub-sampled measurement times.
ode_data = Array(solve(prob, Tsit5(), saveat = tsteps, reltol = 1e-8, abstol = 1e-8))
# Hybrid ("universal") differential equation: the known RC! physics plus a learned
# correction. Giving the model what is already known (the mechanistic right-hand
# side and the exogenous gas flow P(t)) is what speeds up training: the network
# only has to learn the residual instead of the whole dynamics.
#
# Network: 2 inputs [u, P(t)] (scaled) -> 1 output (correction to du/dt), because
# the state is 1-dimensional (u0 = [20.0]). The former Dense(2, 50) -> Dense(50, 2)
# network was fed the 1-element state alone (DimensionMismatch), and its 2 outputs
# could not match the single row of `ode_data`.
dudt2 = Lux.Chain(
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 1))

# Typical input magnitudes, so the network sees O(1) values; eps() avoids a
# division by zero for an all-zero column.
temp_scale = max(maximum(abs, ode_data), eps())
flow_scale = max(maximum(abs, registered_gas_flow), eps())

p, st = Lux.setup(rng, dudt2)
# Zero the output layer so the hybrid model starts out equal to the physics model;
# this replaces a separate pre-training step. (`layer_2` is Lux's default key for
# the second layer of an unnamed Chain.)
# NOTE(review): `ode_data` is simulated with this same physics, so the initial loss
# is already close to zero; fit against measured temperatures to learn a correction.
fill!(p.layer_2.weight, 0)
fill!(p.layer_2.bias, 0)

"""
    make_ude_rhs(nn, nn_state, phys, u_scale, P_scale)

Return an out-of-place ODE right-hand side `(u, θ, t) -> du` computing
`(B*P - C + D*(20 - u[1]))/A + nn([u[1]/u_scale, P/P_scale])`, where
`phys = (A, B, C, D)`, `P = ext_flow(t)` and `θ` are the network parameters.
The physics term mirrors `RC!` but is written without mutation so that Zygote
can differentiate through it.
"""
function make_ude_rhs(nn, nn_state, phys, u_scale, P_scale)
    A, B, C, D = phys
    return function (u, θ, t)
        P = ext_flow(t)
        physics = (B * P - C + D * (20 - u[1])) / A
        nn_out, _ = nn([u[1] / u_scale, P / P_scale], θ, nn_state)
        return [physics + first(nn_out)]
    end
end

# `prob.p` still holds the physical parameters [A, B, C, D]; the global `p` now
# holds the network parameters.
ude_rhs = make_ude_rhs(dudt2, st, prob.p, temp_scale, flow_scale)
prob_neuralode = ODEProblem(ude_rhs, u0, tspan, p)

"""
    predict_neuralode(p)

Solve the hybrid model with network parameters `p` and return the solution at
`tsteps` as a matrix with one row per state (the layout of `ode_data`).
"""
function predict_neuralode(p)
    return Array(solve(remake(prob_neuralode; p = p), Tsit5(); saveat = tsteps))
end
"""
    loss_neuralode(p)

Return `(loss, pred)`. `loss` is the sum of squared errors between `ode_data` and
the prediction `pred = predict_neuralode(p)`; `pred` is returned so the callback
can plot it.

If the solver stops early (e.g. unstable parameters during training), `pred` has
fewer columns than `ode_data`. Rather than throwing a `DimensionMismatch` that
aborts the whole optimization, the loss is `Inf` for that parameter set.
"""
function loss_neuralode(p)
    pred = predict_neuralode(p)
    if size(pred) != size(ode_data)
        return Inf, pred
    end
    return sum(abs2, ode_data .- pred), pred
end
# Optimization callback: print the current loss and optionally plot the current
# prediction against `ode_data`. Plotting is off by default; pass `doplot = true`
# to see the figures. Always returns `false`, so optimization is never halted.
callback = function (p, l, pred; doplot = false)
    println(l)
    doplot || return false
    fig = scatter(tsteps, ode_data[1, :]; label = "data")
    scatter!(fig, tsteps, pred[1, :]; label = "prediction")
    display(plot(fig))
    return false
end
# Training driver: a coarse Adam fit followed by a BFGS refinement.
# Wrap the Lux parameter NamedTuple in a ComponentArray so the optimizers can treat it
# as a flat parameter vector.
pinit = ComponentArray(p)
# Plot the prediction at the initial parameters, before any training.
callback(pinit, loss_neuralode(pinit)...; doplot=true)
# Use Optimization.jl, with Zygote reverse-mode AD for the gradients.
adtype = Optimization.AutoZygote()
# The second argument of the objective is Optimization.jl's hyperparameter slot;
# it is unused here.
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
# Stage 1: 300 Adam iterations with learning rate 0.05.
result_neuralode = Optimization.solve(optprob,
Adam(0.05),
callback = callback,
maxiters = 300)
# Stage 2: restart from the Adam result and refine with BFGS
# (`allow_f_increases` is an Optim.jl option).
optprob2 = remake(optprob,u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(optprob2,
Optim.BFGS(initial_stepnorm=0.01),
callback=callback,
allow_f_increases = false)
# Plot the final fit.
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)