Hi,
I am solving some differential equations and noticed discrepancies in the convergence of the RK4 solver. I measure the error as the maximum absolute difference between the solver's solution and the analytical result. To check the method itself, I coded the RK4 method from scratch. With my implementation the error stays small as dt decreases, whereas with the built-in solver the error grows.
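Written out, the error measure described above is the following. The notation is introduced here only for illustration: $u_i$ is the numerical solution at time step $t_i$, and $u(t)$ is the analytical solution.

$$E(\mathrm{d}t) = \max_i \left| u_i - u(t_i) \right|$$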
Here is an MWE of the method I am using, applied to a much simpler problem.
# ODE is \dot{u} = velocity
# solution is u(t) = velocity*t
using DifferentialEquations # provides ODEProblem, solve and RK4
# using DataFrames # not needed for this MWE
# using Plots # not needed for this MWE
u0 = Float32[0.0] # initial conditions
tspan = (0.0f0, 1.0f4) # time range
model_params = [] # RHS is not parameterized
velocity = 0.05 # RHS value
function RHS(u, model_params, t)
    return [velocity]
end
## Custom RK4 solver
function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u + dt/2 * k1, p, t + dt/2)
    k3 = f(u + dt/2 * k2, p, t + dt/2)
    k4 = f(u + dt * k3, p, t + dt)
    return u + dt/6 * (k1 + 2*k2 + 2*k3 + k4)
end
function rk4_solve(f, u0, tspan, p, dt)
    t0, tf = tspan
    t = t0
    u = u0
    ts = [t]
    us = [u]
    while t < tf
        u = rk4_step(f, u, p, t, dt)
        t += dt
        push!(ts, t)
        push!(us, u)
    end
    return ts, us
end
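# The problem definition was missing from the snippet; this is the assumed standard form
prob = ODEProblem(RHS, u0, tspan, model_params)
dt = [100, 10, 1, 0.1, 0.01]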
error_julia_rk4 = []
error_custom_rk4 = []
for dts in dt
    ## Solution and error from custom solver
    custom_tsteps, custom_sol = rk4_solve(RHS, u0, tspan, model_params, dts)
    sol_phi = [u[1] for u in custom_sol]
    x = sol_phi .- custom_tsteps * velocity
    error_custom = maximum(abs.(x))
    ## Solution and error from julia solver
    julia_sol = solve(prob, RK4(), dt = dts, adaptive = false)
    y = julia_sol[1, :] .- julia_sol.t * velocity
    error_julia = maximum(abs.(y))
    push!(error_julia_rk4, error_julia)
    push!(error_custom_rk4, error_custom)
end
The output I get is the following:
julia> error_julia_rk4
5-element Vector{Any}:
0.0
0.0
0.048553466796875
0.40500488281247726
2.3095886230468636
julia> error_custom_rk4
5-element Vector{Any}:
0.0
0.0
1.2207031261368684e-5
3.0517578125e-5
3.662109378410605e-5
I’m not sure what I’m missing in my code. I would appreciate any help and suggestions.
Here are the package versions I’m using:
[0c46a032] DifferentialEquations v7.6.0
[91a5bcdd] Plots v1.39.0
[a93c6f00] DataFrames v1.3.6
Please let me know if I can provide more information.