Here’s the code in a single file.
It might take me a while to actually understand what’s going on, but there are a few red flags that stand out. You should essentially never need to use `eval` or `invokelatest`, and the mix of DifferentialEquations and Symbolics with JuMP can cause issues. This code is also very far from what the C++ is doing, so it’s no wonder that there is a performance difference.
How long does their Ipopt code take to run? And is the RAM issue here because of the solve or during the setup? Asked a different way, if you remove `optimize!`, is there still an issue?
using DifferentialEquations
using Symbolics
using Plots
using Parameters
using JuMP
using Ipopt
using StaticArrays
using Interpolations
using IOCapture
using JLD2
"""
    VarAssimMeta

Immutable description of a variational data-assimilation problem.

# Fields
- `vf`: the vector-field function; the outer constructor calls it as `vf(u, p, f)`.
- `vf_strs`: one string per vector-field component, as produced by the outer
  constructor from the symbolic result of `vf`.
- `D`: number of entries in the state vector `u`.
- `P`: number of entries in the parameter vector `p`.
- `observed_dims`: state indices that enter the data-misfit part of the objective.
- `driven_dims`: indices describing the forcing inputs; only its length is used in
  this file.
- `L`: set to `length(observed_dims)` by the outer constructor.
- `I`: set to `length(driven_dims)` by the outer constructor.
- `lb_states`, `ub_states`: per-state lower and upper bounds.
- `lb_params`, `ub_params`: per-parameter lower and upper bounds.
"""
struct VarAssimMeta
    vf::Function
    vf_strs::Vector{String}
    D::Int
    P::Int
    observed_dims::Vector{Int}
    driven_dims::Vector{Int}
    L::Int
    I::Int
    lb_states::Vector
    ub_states::Vector
    lb_params::Vector
    ub_params::Vector
end
"""
    VarAssimMeta(vf, D, P, observed_dims, driven_dims,
                 lb_states, ub_states, lb_params, ub_params)

Convenience constructor that derives the remaining `VarAssimMeta` fields.

`L` and `I` are taken from the lengths of `observed_dims` and `driven_dims`. The
vector field is evaluated once on symbolic arrays `u[1:D]`, `p[1:P]` and `f[1:I]`,
and each component of the result is stored as its `repr` string in `vf_strs`.
"""
function VarAssimMeta(
    vf::Function,
    D::Int,
    P::Int,
    observed_dims::Vector{Int},
    driven_dims::Vector{Int},
    lb_states::Vector,
    ub_states::Vector,
    lb_params::Vector,
    ub_params::Vector,
)
    n_observed = length(observed_dims)
    n_driven = length(driven_dims)

    # The symbol names `u`, `p` and `f` must not be renamed: `jump_nl_substitute!`
    # later rewrites the textual patterns "u[i]", "p[i]" and "f[i]".
    Symbolics.@variables u[1:D], p[1:P], f[1:n_driven]

    symbolic_rhs = vf(u, p, f)
    rhs_strings = repr.(symbolic_rhs)

    return VarAssimMeta(
        vf,
        rhs_strings,
        D,
        P,
        observed_dims,
        driven_dims,
        n_observed,
        n_driven,
        lb_states,
        ub_states,
        lb_params,
        ub_params,
    )
end
"""
    jump_nl_substitute!(model, eqstr, prec, urec, frec)

Register one vector-field component on `model` as a nonlinear expression evaluated at
every column of `urec` / `frec`.

`eqstr` is the textual form of the component, written in terms of `p[i]`, `u[i]` and
`f[i]`. Those references are rewritten to `opt_p[i]`, `opt_u[i,j]` and `drive[i,j]`,
the result is wrapped in an `@NLexpression` over `j = 1:npts`, and the generated code
is evaluated with `prec`, `urec` and `frec` bound to those three names.

# Arguments
- `model`: the JuMP model that receives the expressions.
- `eqstr`: string form of one component of the vector field.
- `prec`: parameter container, indexed as `prec[i]`.
- `urec`: state container, indexed as `urec[i, j]`.
- `frec`: forcing container, indexed as `frec[i, j]`.

# Returns
The expressions reshaped to `(1, npts)`, where `npts` is the number of columns of
`frec`.

# Throws
- `DimensionMismatch` if `urec` and `frec` have different numbers of columns.
"""
function jump_nl_substitute!(model, eqstr, prec, urec, frec)
    npts = size(frec, 2)
    # Explicit check rather than `@assert`: assertions may be disabled, and callers
    # get a clearer error type for mismatched inputs.
    size(urec, 2) == npts || throw(
        DimensionMismatch(
            "urec has $(size(urec, 2)) columns but frec has $npts; they must match",
        ),
    )

    # All pairs are passed to a single `replace` call; the order (p, u, f) is kept.
    substitutions = (
        r"p\[(\d+)\]" => s"opt_p[\1]",
        r"u\[(\d+)\]" => s"opt_u[\1,j]",
        r"f\[(\d+)\]" => s"drive[\1,j]",
    )
    neweqstr = replace(eqstr, substitutions...)

    model_lambda_str = """(model, opt_p, opt_u, drive) -> @NLexpression(
    model,
    [j=1:$npts],
    $neweqstr);
    """

    # NOTE(review): this evaluates generated source text, so `eqstr` must come from a
    # trusted origin (in this file it is the `repr` of a symbolic expression).
    model_lambda = eval(Meta.parse(model_lambda_str))

    # The lambda was defined by `eval` during this call, so it has to be invoked
    # through `invokelatest` to be callable from the running function.
    per_point = Base.invokelatest(model_lambda, model, prec, urec, frec)
    return reshape(per_point, (1, npts))
end
"""
    AbstractVADiscretization

Supertype of the discretization tags used to select a model-building method.
"""
abstract type AbstractVADiscretization end

"""
    SimpsonHermite <: AbstractVADiscretization

Field-less tag type; `gen_dspe_jump_model` dispatches on it.
"""
struct SimpsonHermite <: AbstractVADiscretization
end
"""
    gen_dspe_jump_model(vam::VarAssimMeta, disc::SimpsonHermite, timepoints, data, drive_frozen)

Build, without solving, the JuMP model for the problem described by `vam` on the grid
`timepoints`.

With `N = length(timepoints) - 1` intervals the model contains:

- variables `opt_x` (`D x (N+1)`, states at the grid points), `opt_xm` (`D x N`,
  states at the interval midpoints), `opt_k` (`2N+1` values bounded to `[0, 1]`) and
  `opt_p` (`P` parameters);
- for every state and interval, one Simpson residual and one Hermite midpoint
  residual, both constrained to zero;
- the state and parameter bounds stored in `vam`;
- an objective made of the squared misfit to `data` on the observed states plus the
  sum of `opt_k[j]^2`, divided by `2N+1`.

# Arguments
- `vam`: problem description; `vf_strs`, `D`, `P`, `observed_dims` and the four bound
  vectors are read.
- `disc`: only used for dispatch.
- `timepoints`: grid times; must not be empty.
- `data`: matrix whose odd columns `1:2:(2N+1)` are compared with `opt_x` and whose
  even columns `2:2:2N` are compared with `opt_xm`; needs at least `2N+1` columns.
- `drive_frozen`: matrix with exactly `2N+1` columns; `opt_k'` is appended as an
  extra last row to form the forcing passed to the vector field.

# Returns
The assembled `Model`; no optimizer is attached.

# Throws
- `ArgumentError` if `timepoints` is empty.
- `DimensionMismatch` if `data` or `drive_frozen` do not have the required number of
  columns.
"""
function gen_dspe_jump_model(vam::VarAssimMeta, disc::SimpsonHermite, timepoints, data, drive_frozen)
    isempty(timepoints) && throw(ArgumentError("timepoints must not be empty"))
    N = length(timepoints) - 1
    ncols = 2 * N + 1

    # Fail early with a clear message instead of part-way through model construction.
    size(drive_frozen, 2) == ncols || throw(
        DimensionMismatch(
            "drive_frozen has $(size(drive_frozen, 2)) columns, expected $ncols (2N + 1)",
        ),
    )
    size(data, 2) >= ncols || throw(
        DimensionMismatch(
            "data has $(size(data, 2)) columns, need at least $ncols (2N + 1)",
        ),
    )

    dt = diff(timepoints)
    vf_strs = vam.vf_strs
    D, P = vam.D, vam.P
    observed_dims = vam.observed_dims
    lb_states, ub_states = vam.lb_states, vam.ub_states
    lb_params, ub_params = vam.lb_params, vam.ub_params

    model = Model()
    JuMP.@variables(model, begin
        opt_x[1:D, 1:(N+1)]
        opt_xm[1:D, 1:N]
        0.0 <= opt_k[1:(2*N+1)] <= 1.0
        opt_p[1:P]
    end)

    # The optimized `opt_k` becomes the last forcing row, after the fixed rows.
    drive = vcat(drive_frozen, opt_k')

    # Odd columns belong to grid points, even columns to interval midpoints.
    data_x = data[:, 1:2:(2*N+1)]
    data_xm = data[:, 2:2:(2*N)]
    drive_x = drive[:, 1:2:(2*N+1)]
    drive_xm = drive[:, 2:2:(2*N)]

    println("Generating timepoint and midpoint vector field subexpressions")
    F = vcat([jump_nl_substitute!(model, eqstr, opt_p, opt_x, drive_x) for eqstr in vf_strs]...)
    Fm = vcat([jump_nl_substitute!(model, eqstr, opt_p, opt_xm, drive_xm) for eqstr in vf_strs]...)

    println("Generating Simpson residual subexpressions.")
    S = @NLexpression(
        model,
        [i = 1:D, j = 1:N],
        (opt_x[i, j] + (1 / 6) * dt[j] * F[i, j]) + (2 / 3) * dt[j] * Fm[i, j] +
        (-opt_x[i, j+1] + (1 / 6) * dt[j] * F[i, j+1])
    )

    println("Generating Hermite residual subexpressions.")
    H = @NLexpression(
        model,
        [i = 1:D, j = 1:N],
        ((1 / 2) * opt_x[i, j] + (dt[j] / 8) * F[i, j]) +
        ((1 / 2) * opt_x[i, j+1] - (dt[j] / 8) * F[i, j+1]) - opt_xm[i, j]
    )

    @NLconstraint(model, [i = 1:D, j = 1:N], S[i, j] == 0.0)
    @NLconstraint(model, [i = 1:D, j = 1:N], H[i, j] == 0.0)

    # Every time slice of a given state shares that state's bounds.
    for i = 1:D
        set_lower_bound.(opt_x[i, :], lb_states[i])
        set_lower_bound.(opt_xm[i, :], lb_states[i])
        set_upper_bound.(opt_x[i, :], ub_states[i])
        set_upper_bound.(opt_xm[i, :], ub_states[i])
    end
    for i = 1:P
        set_lower_bound(opt_p[i], lb_params[i])
        set_upper_bound(opt_p[i], ub_params[i])
    end

    # NOTE(review): `data_x` / `data_xm` are indexed with the state index `i` taken
    # from `observed_dims`, not with the position of `i` inside `observed_dims`. This
    # requires `data` to have a row for every observed state index - confirm this is
    # the intended layout when `observed_dims` is not `1:L`.
    @NLobjective(
        model,
        Min,
        (
            sum((opt_x[i, j] - data_x[i, j])^2 for i in observed_dims, j = 1:(N+1)) +
            sum((opt_xm[i, j] - data_xm[i, j])^2 for i in observed_dims, j = 1:N) +
            sum(opt_k[j]^2 for j = 1:(2*N+1))
        ) / (2 * N + 1)
    )

    println("Model generation complete!")
    return model
end
# Parameter vector handed to `hhvf_oop` through `hh_prob`. Each entry is labelled with
# the name `hhvf_oop` gives it when unpacking the vector (same order).
gtp = [
    1.0,    # Cm
    -54.4,  # EL
    -77.0,  # EK
    50.0,   # ENa
    0.3,    # gL
    20.0,   # gK
    120.0,  # gNa
    -40.0,  # θm
    15.0,   # σm
    0.1,    # τm0
    0.4,    # τm1
    0.0,    # τm2
    -60.0,  # θh
    -15.0,  # σh
    1.0,    # τh0
    7.0,    # τh1
    0.0,    # τh2
    -55.0,  # θn
    30.0,   # σn
    1.0,    # τn0
    5.0,    # τn1
    0.0,    # τn2
];
"""
    lorenz(u, p, t)

Out-of-place right-hand side of the Lorenz system with the constants `10.0`, `28.0`
and `8 / 3` hard-coded. `p` and `t` are accepted but not used. Reads `u[1]`, `u[2]`
and `u[3]` and returns the three derivatives as an `SA[...]` static array.
"""
function lorenz(u, p, t)
    x, y, z = u[1], u[2], u[3]
    dx = 10.0 * (y - x)
    dy = x * (28.0 - z) - y
    dz = x * y - (8 / 3) * z
    return SA[dx, dy, dz]
end
"""
    scale_timeseries(ts, minv, maxv)

Affinely rescale `ts` so that its smallest value maps to `minv` and its largest value
maps to `maxv`.

If all entries of `ts` are equal the input span is zero and the division by it
propagates into the result (`NaN` for floating-point input).
"""
function scale_timeseries(ts, minv, maxv)
    lo, hi = extrema(ts)
    span_out = maxv - minv
    span_in = hi - lo
    return minv .+ span_out * (ts .- lo) / span_in
end
# Build the forcing signal `lorfunc` from a Lorenz trajectory.
# Integration horizon and step-size setting for the Lorenz run.
lor_tmax = 80.0
lor_dt = 0.005
# Initial condition as a static array (matches the `SA[...]` returned by `lorenz`).
lor_u0 = SA[-1.31; 0.8; 19.77]
lor_tspan = (0.0, lor_tmax)
lor_prob = ODEProblem(lorenz, lor_u0, lor_tspan)
# NOTE(review): with `adaptive = true`, `dt` is presumably only the initial step size -
# confirm against the solver documentation.
lor_sol = solve(lor_prob, RK4(), adaptive = true, dt = lor_dt);
# Interpolate the first solution component, rescaled to [-20, 20], over the solution
# times rescaled to [0, 300] (the same value as `hh_tmax` defined later in this file).
# NOTE(review): a numeric `extrapolation_bc` is presumably the value returned outside
# the knot range - confirm against the Interpolations.jl documentation.
lorfunc = linear_interpolation(scale_timeseries(lor_sol.t, 0.0, 300.0), scale_timeseries(lor_sol[1, :], -20.0, 20.0), extrapolation_bc = 0.0);
# Gain multiplying the argument of every `tanh` in `nlss` and `nltc`.
const tanhgain = 1.0;

"""
    nlss(V, θ, σ)

Sigmoid `0.5 * (1 + tanh(tanhgain * (V - θ) / σ))`; it equals `0.5` at `V == θ`.
"""
@inline function nlss(V, θ, σ)
    return 0.5 * (1 + tanh(tanhgain * (V - θ) / σ))
end

"""
    nltc(V, θ, σ, τ0, τ1, τ2)

Return `τ0 + τ1 * (1 - s^2) + τ2 * (1 + s)` with `s = tanh(tanhgain * (V - θ) / σ)`.
"""
@inline function nltc(V, θ, σ, τ0, τ1, τ2)
    s = tanh(tanhgain * (V - θ) / σ)
    return τ0 + τ1 * (1 - s^2) + τ2 * (1 + s)
end
"""
    hhvf_oop(u, p, t)

Out-of-place right-hand side for the four-state model `u = (V, m, h, n)`.

`p` is a pair `(pvec, Iapp)`: `pvec` holds 22 scalar parameters in the order unpacked
below, and `Iapp` is a callable evaluated at `t` and added to the `V` equation.
Returns `[dV, dm, dh, dn]` as a freshly allocated vector.
"""
function hhvf_oop(u, p, t)
    pvec, Iapp = p
    Cm, EL, EK, ENa, gL, gK, gNa, θm, σm, τm0, τm1, τm2, θh, σh, τh0, τh1, τh2, θn, σn, τn0, τn1, τn2 = pvec
    V, m, h, n = u

    # The three conductance terms of the V equation, kept as written products.
    term_na = gNa * m * m * m * h * (ENa - V)
    term_k = gK * n * n * n * n * (EK - V)
    term_l = gL * (EL - V)
    dV = (1 / Cm) * (term_na + term_k + term_l + Iapp(t))

    # Each gating variable relaxes towards nlss(V, ...) at the rate 1 / nltc(V, ...).
    relax(x, θ, σ, τ0, τ1, τ2) = (nlss(V, θ, σ) - x) / nltc(V, θ, σ, τ0, τ1, τ2)
    dm = relax(m, θm, σm, τm0, τm1, τm2)
    dh = relax(h, θh, σh, τh0, τh1, τh2)
    dn = relax(n, θn, σn, τn0, τn1, τn2)

    return [dV, dm, dh, dn]
end
# Simulate `hhvf_oop` to produce the trajectory that is sampled as data further down.
hh_tmax = 300.0
# NOTE(review): `hh_saveat` is defined here but is not passed to `solve` below (nor
# used in the visible remainder of the file), so `hh_sol.t` holds whatever time points
# the solver chose. Confirm whether `saveat = hh_saveat` was intended.
hh_saveat = 0.04
# Initial state, in the order (V, m, h, n) in which `hhvf_oop` unpacks `u`.
hh_u0 = [
-68.24221681836171
0.056029230048653705
0.7700232861002139
0.3402655968929933
];
# The parameter object is the pair (gtp, lorfunc), matching `pvec, Iapp = p` in
# `hhvf_oop`: `gtp` supplies the scalar parameters, `lorfunc` the applied input.
hh_prob = ODEProblem(hhvf_oop, hh_u0, (0.0, hh_tmax), (gtp, lorfunc))
hh_sol = solve(hh_prob, RK4(), adaptive = true);
"""
    hhvf_assim_controlled(u, p, f)

Vector field in the `vf(u, p, f)` form expected by the `VarAssimMeta` constructor,
which evaluates it on symbolic arrays.

- `u = (V, m, h, n)` is the state.
- `p` holds 23 parameters in the order unpacked below.
- `f = (I, data, k)`: `I` is divided by `Isa` and added inside the `V` equation,
  and `k * (data - V)` is added to it as a coupling term.

Returns `[dV, dm, dh, dn]`.
"""
function hhvf_assim_controlled(u, p, f)
    Cm, Isa, EL, EK, ENa, gL, gK, gNa, θm, σm, τm0, τm1, τm2, θh, σh, τh0, τh1, τh2, θn, σn, τn0, τn1, τn2 = p
    I, data, k = f
    V, m, h, n = u

    # Conductance terms and scaled input of the V equation, kept as written.
    term_na = gNa * m * m * m * h * (ENa - V)
    term_k = gK * n * n * n * n * (EK - V)
    term_l = gL * (EL - V)
    scaled_input = I / Isa
    coupling = k * (data - V)
    dV = (1 / Cm) * (term_na + term_k + term_l + scaled_input) + coupling

    # Each gating variable relaxes towards nlss(V, ...) at the rate 1 / nltc(V, ...).
    relax(x, θ, σ, τ0, τ1, τ2) = (nlss(V, θ, σ) - x) / nltc(V, θ, σ, τ0, τ1, τ2)
    dm = relax(m, θm, σm, τm0, τm1, τm2)
    dh = relax(h, θh, σh, τh0, τh1, τh2)
    dn = relax(n, θn, σn, τn0, τn1, τn2)

    return [dV, dm, dh, dn]
end
# Assemble and solve the assimilation problem.

# Parameter bounds, in the order in which `hhvf_assim_controlled` unpacks `p`
# (Cm, Isa, EL, EK, ENa, gL, gK, gNa, θm, σm, τm0, τm1, τm2, θh, σh, τh0, τh1, τh2,
# θn, σn, τn0, τn1, τn2). An entry with identical lower and upper bound leaves that
# parameter no freedom.
plower = [1.0, 0.0005, -100.0, -100.0, 0.0, 0.0, 0.0, 0.0, -100.0, 0.01, 0.01, 0.01, 0.0, -100.0, -100.0, 0.01, 0.01, 0.0, -100.0, 0.01, 0.01, 0.01, 0.0]
pupper = [1.0, 1.0, 0.0, 0.0, 100.0, 100.0, 1000.0, 1500.0, -0.01, 100.0, 10.0, 10.0, 0.0, -0.01, -0.01, 10.0, 10.0, 0.0, -0.01, 100.0, 10.0, 10.0, 0.0]
# State bounds, in the order (V, m, h, n).
slower = [-150.0, 0.0, 0.0, 0.0]
supper = [150.0, 1.0, 1.0, 1.0];

# 4 states, 23 parameters, state 1 observed, three forcing inputs.
vam = VarAssimMeta(hhvf_assim_controlled, 4, 23, [1], [1, 2, 3], slower, supper, plower, pupper)

# The grid is taken from the time points of the simulated solution.
timepoints = hh_sol.t
N = length(timepoints) - 1

# Interval midpoints, then grid points and midpoints interleaved:
# t1, mid1, t2, mid2, ..., tN, midN, t(N+1)  (2N + 1 entries).
midt = (hh_sol.t[1:end-1] + hh_sol.t[2:end]) / 2.0;
allt = vcat(vec(hcat(hh_sol.t[1:end-1], midt)'), hh_sol.t[end]);

# One column per sample time. `reduce(hcat, ...)` is used instead of splatting the
# samples into `hcat`, which would pass one argument per time point.
middatamat = reduce(hcat, hh_sol.(midt));
alldatamat = reduce(hcat, hh_sol.(allt));

# Only the first state is kept as observation (row range keeps it a 1-row matrix).
obsdatamat = alldatamat[1:1, :];
drive = lorfunc.(allt)';

# NOTE(review): the noise is drawn without a fixed seed, so runs are not repeatable.
noise = 0.1 * randn(2N + 1);
data = obsdatamat + noise';

# Fixed forcing rows: the input signal and the noisy observation. The model builder
# appends its own `opt_k` row, giving the (I, data, k) triple the vector field reads.
drive_frozen = vcat(drive, data);

model = gen_dspe_jump_model(vam, SimpsonHermite(), timepoints, data, drive_frozen)
set_optimizer(model, Ipopt.Optimizer)
optimize!(model)