In this code, we are checking whether DataDrivenDiffEq.jl can accurately recover the equations of a time-dependent system of differential equations. To test this, we define a known system, solve it, add noise to the solution, and see whether the package gives the original equations back. We get a good fit, but the recovered equations contain extra terms. Here is the code, followed by the output.
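```julia
using DataDrivenDiffEq, ModelingToolkit, OrdinaryDiffEq
using DataDrivenSparse, LinearAlgebra, StableRNGs, Plots

rng = StableRNG(1000)  # seeded RNG so the added noise is reproducible

# Work on preventing overfitting
# True system: dx/dt = 2xy, dy/dt = 1
function f(u, p, t)
    x, y = u
    dx = 2.0 * x * y
    dy = 1.0
    return [dx, dy]
end

# Setting initial conditions
u0 = [1.0; 0]
tspan = (0.0, 2.0)
dt = 0.0001

# Solving ODE problem
prob = ODEProblem(f, u0, tspan)
sol = solve(prob, Tsit5(), saveat=dt)

# Adding Gaussian noise (standard deviation 0.2) to the solution
X = sol[:, :] + 0.2 .* randn(rng, size(sol))
ts = sol.t

# Data-driven problem: states and derivatives are estimated from the noisy data
# with a Gaussian collocation kernel
prob = ContinuousDataDrivenProblem(X, ts, GaussianKernel())

# Candidate library: monomials in u[1], u[2] up to degree 2
@variables u[1:2]
u = collect(u)
# polynomial_basis(u, 2) should already contain 1, u[1] and u[2],
# so appending u repeats the linear terms
h = Num[polynomial_basis(u, 2); u]
basis = Basis(h, u)

# 80/20 train/test split, shuffled, in batches of 25 samples
sampler = DataProcessing(split=0.8, shuffle=true, batchsize=25, rng=rng)
# Sparsity thresholds for STLSQ, from 1e-10 up to 1
lambdas = exp10.(-10:0.1:0)
opt = STLSQ(lambdas)
# digits=2 rounds the recovered parameters to two decimals
res = solve(prob, basis, opt, options=DataDrivenCommonOptions(data_processing=sampler, digits=2))

system = get_basis(res)
params = get_parameter_map(system)

# Displaying results
println(system)
println(params)
display(plot(plot(prob), plot(res), layout=(1, 2)))
```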
```
Model #basis#311 with 2 equations
States : u[1] u[2]
Parameters : 6
Independent variable: t
Equations
Differential(t)(u[1]) = p₁ + p₃*(u[2]^2) + p₄*u[1] + p₅*u[2] + p₂*u[1]*u[2]
Differential(t)(u[2]) = p₆
Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}[p₁ => -1.38, p₂ => 1.38, p₃ => -1.6, p₄ => 1.27, p₅ => 1.17, p₆ => 0.99]
```
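To put numbers on how far the recovered model is from the true one, here is a small plain-Julia sketch (no packages needed) that evaluates both right-hand sides for u[1] at points on the noise-free trajectory. Solving dy/dt = 1 and dx/dt = 2xy by hand with x(0) = 1, y(0) = 0 gives y = t and x = exp(t²). The coefficients are the rounded values from the parameter map printed above; the function names `true_dx` and `recovered_dx` are hypothetical, chosen only for illustration.

```julia
# Hypothetical helper: true right-hand side for u[1] used to generate the data
true_dx(x, y) = 2.0 * x * y

# Hypothetical helper: recovered right-hand side for u[1],
# using the rounded parameters p₁..p₅ from the printed parameter map
function recovered_dx(x, y)
    p1, p2, p3, p4, p5 = -1.38, 1.38, -1.6, 1.27, 1.17
    return p1 + p3 * y^2 + p4 * x + p5 * y + p2 * x * y
end

# Compare both along the analytic trajectory x(t) = exp(t^2), y(t) = t
for t in 0.0:0.5:2.0
    x, y = exp(t^2), t
    diff = recovered_dx(x, y) - true_dx(x, y)
    println("t = $t: true = $(round(true_dx(x, y), digits=3)), ",
            "recovered = $(round(recovered_dx(x, y), digits=3)), ",
            "difference = $(round(diff, digits=3))")
end
```

A small difference at these points does not by itself show that the individual coefficients are right, since the comparison only covers the single trajectory that was sampled.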
Ideally, we should get back a system of equations that looks exactly like the one we put in, but the recovered model contains extra terms. To try to get a more accurate system, we have decreased dt and limited the candidate terms to a maximum polynomial degree of 2. Is there anything else we should do to get a more accurate result?
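For reference, a degree-2 polynomial library in two states has six candidate terms per equation, while the true system uses only one term in each equation (x*y for du[1]/dt and the constant 1 for du[2]/dt), so the sparse regression has to drive the other coefficients to zero. The plain-Julia sketch below only enumerates those monomials; `monomial_exponents` and `term_name` are hypothetical helpers, not DataDrivenDiffEq.jl functions, and their ordering need not match the package's `polynomial_basis`.

```julia
# Hypothetical helper: exponent pairs (i, j) of the monomials x^i * y^j
# with total degree at most maxdeg, ordered by degree
function monomial_exponents(maxdeg::Int)
    exps = Tuple{Int,Int}[]
    for d in 0:maxdeg, i in d:-1:0
        push!(exps, (i, d - i))   # x^i * y^(d - i)
    end
    return exps
end

# Hypothetical helper: render an exponent pair as a readable term
function term_name((i, j))
    parts = String[]
    i > 0 && push!(parts, i == 1 ? "x" : "x^$i")
    j > 0 && push!(parts, j == 1 ? "y" : "y^$j")
    return isempty(parts) ? "1" : join(parts, "*")
end

terms = term_name.(monomial_exponents(2))
println(terms)
```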