Bridging MTK and JuMP

I am doing some basic thermal modelling in ModelingToolkit (MTK), and I want to use the resulting discretized dynamics in a JuMP model.

I came up with the following PoC:

using Pkg
Pkg.activate(".")

import ModelingToolkit as MTK
using JuMP
using Ipopt
using UnicodePlots
using SeeToDee: Rk4

############################
# 1. ModelingToolkit model
############################

MTK.@parameters t Ta Ri Ro C Cw
D = MTK.Differential(t)

MTK.@variables T(t) Tw(t) u(t)

eqs = [
    C * D(T) ~ (Tw - T) / Ri + u,
    Cw * D(Tw) ~ (Ta - Tw) / Ro - (Tw - T) / Ri,
]

MTK.@named sys = MTK.ODESystem(eqs, t)
sys = MTK.mtkcompile(sys; inputs=[u])

############################
# 2. RHS function
############################

states = [T, Tw]
controls = [u]
params = [Ta, Ri, Ro, C, Cw]

# Assumes MTK.equations(sys) is ordered like `states`; compare with MTK.unknowns(sys) to check.
rhs = [eq.rhs for eq in MTK.equations(sys)]

f = MTK.build_function(rhs, states, controls, params;
    expression=Val(false))[1]

############################
# 3. Discretization
############################

Δt = 1.0
f_disc = Rk4(f, Δt)

############################
# 4. Parameters
############################

p = [5.0, 1.0, 3.0, 10.0, 30.0]

############################
# 5. Optimization model
############################

N = 24
nx, nu = 2, 1

model = Model(Ipopt.Optimizer)

JuMP.@variables(model,
    begin
        X[1:nx, 1:N]
        U[1:nu, 1:N] >= 0
    end
)

############################
# 6. Initial state
############################

X0 = [18.0, 10.0]

@constraint(model, X[:, 1] .== X0)

############################
# 7. Dynamics
############################

@constraint(model, [k = 1:N-1],
    X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0)
)

############################
# 8. Objective
############################

@objective(model, Min,
    sum((X[1, k] - 21)^2 + 0.01 * U[1, k]^2 for k in 1:N)
)

############################
# 9. Solve
############################

optimize!(model)

############################
# 10. Plot
############################

T_vals = value.(X[1, :])
Tw_vals = value.(X[2, :])
u_vals = vec(value.(U))

println(lineplot(1:N, [T_vals Tw_vals],
    title="Temperature trajectory",
    name=["T air" "T wall"]))

println(lineplot(1:N, u_vals,
    title="Heating power",
    name="u"))

Right now it is a really trivial example that could be discretized by hand, but this MWE will be expanded on later.
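For reference, here is what "discretized by hand" would start from: the two equations in the script rearranged into state-space form, with the state ordered as [T, Tw]. This is just algebra on the equations above (dividing by C and Cw), not anything MTK prints.

$$
\frac{d}{dt}\begin{bmatrix} T \\ T_w \end{bmatrix}
=
\begin{bmatrix}
-\dfrac{1}{R_i C} & \dfrac{1}{R_i C} \\[2mm]
\dfrac{1}{R_i C_w} & -\dfrac{1}{C_w}\left(\dfrac{1}{R_i} + \dfrac{1}{R_o}\right)
\end{bmatrix}
\begin{bmatrix} T \\ T_w \end{bmatrix}
+
\begin{bmatrix} \dfrac{1}{C} \\[2mm] 0 \end{bmatrix} u
+
\begin{bmatrix} 0 \\[1mm] \dfrac{T_a}{R_o C_w} \end{bmatrix}
$$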

I am mainly looking for feedback on whether this is the right approach. It works well, and the results are more or less what I would expect.
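One way to sanity-check the discretization independently of MTK and SeeToDee is to write the same thermal model by hand and compare a classical RK4 step against the exact zero-order-hold step from a matrix exponential. The sketch below uses only LinearAlgebra. The helper names `rk4_step` and `exact_step` and the test input `u0 = 2.0` are made up for illustration, and the RK4 routine is a textbook version that holds u constant over the step, not SeeToDee's actual implementation.

```julia
using LinearAlgebra

# Parameter values copied from the script: p = [Ta, Ri, Ro, C, Cw]
Ta, Ri, Ro, C, Cw = 5.0, 1.0, 3.0, 10.0, 30.0
Δt = 1.0

# Hand-derived continuous-time model ẋ = A*x + B*u + d with x = [T, Tw]
a11 = -1 / (Ri * C)
a12 = 1 / (Ri * C)
a21 = 1 / (Ri * Cw)
a22 = -(1 / Ri + 1 / Ro) / Cw
A = [a11 a12; a21 a22]
B = [1 / C, 0.0]
d = [0.0, Ta / (Ro * Cw)]

rhs(x, u) = A * x + B * u + d

# One classical RK4 step with the input held constant over the step
function rk4_step(x, u, h)
    k1 = rhs(x, u)
    k2 = rhs(x + (h / 2) * k1, u)
    k3 = rhs(x + (h / 2) * k2, u)
    k4 = rhs(x + h * k3, u)
    return x + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
end

# Exact step for constant input, via the exponential of the augmented matrix
# [A b; 0 0], where b = B*u + d
function exact_step(x, u, h)
    n = length(x)
    M = zeros(n + 1, n + 1)
    M[1:n, 1:n] = A
    M[1:n, end] = B * u + d
    E = exp(M * h)
    return E[1:n, 1:n] * x + E[1:n, end]
end

x0 = [18.0, 10.0]   # same initial state as X0 in the script
u0 = 2.0            # hypothetical heating power, chosen only for this check

x_rk4 = rk4_step(x0, u0, Δt)
x_exact = exact_step(x0, u0, Δt)
err = norm(x_rk4 - x_exact)

println("RK4 step:   ", x_rk4)
println("Exact step: ", x_exact)
println("Difference: ", err)
```

If the difference is small for the chosen Δt, RK4 is a reasonable discretization for this model. Calling `f_disc(x0, [u0], p, 0.0)` from the script should give a comparable result, assuming the state order matches.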

ps. I know about the "Should you use JuMP?" page in the JuMP documentation, but I am stubborn and I want to use JuMP (I have other things that need JuMP in this project).
pps. I also know about PSORLab/EOptInterface.jl on GitHub (an abstraction layer for optimizing equation-oriented/acausal models), but it seems a bit overkill for my use case.

edit: Just realized this is correctly inferred by JuMP to be a QP! So HiGHS can also be used :)
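A short derivation of why the problem stays a QP, assuming the RK4 scheme holds u constant over each step (which is my understanding of how fixed-step integrators of this kind are normally used, not something checked against the SeeToDee source). Write the model as ẋ = A x + b_k, where x = [T, Tw] and b_k collects the constant terms: the contribution of u_k and of Ta. Expanding the four RK4 stages for this affine right-hand side gives

$$
x_{k+1} = \Phi\, x_k + \Gamma\, b_k,
\qquad
\Phi = \sum_{j=0}^{4} \frac{(\Delta t\, A)^j}{j!},
\qquad
\Gamma = \Delta t \sum_{j=0}^{3} \frac{(\Delta t\, A)^j}{(j+1)!}
$$

Since b_k is affine in u_k and the parameters are fixed numbers, each dynamics constraint is linear in the decision variables X and U. Combined with the quadratic objective, that is a QP. This stops being true as soon as the model has a nonlinearity in the states or inputs, or when a parameter becomes a decision variable.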