Bridging MTK and JuMP

I am doing some basic thermal modelling in MTK, and I want to use the resulting discretized dynamics in a JuMP model.

I came up with the following PoC:

using Pkg
Pkg.activate(".")

import ModelingToolkit as MTK
using JuMP
using Ipopt
using UnicodePlots
using SeeToDee: Rk4

############################
# 1. ModelingToolkit model
############################

MTK.@parameters t Ta Ri Ro C Cw
D = MTK.Differential(t)

MTK.@variables T(t) Tw(t) u(t)

eqs = [
    C * D(T) ~ (Tw - T) / Ri + u,
    Cw * D(Tw) ~ (Ta - Tw) / Ro - (Tw - T) / Ri,
]

@named sys = MTK.ODESystem(eqs, t)
sys = MTK.mtkcompile(sys; inputs=[u])

############################
# 2. RHS function
############################

states = [T, Tw]
controls = [u]
params = [Ta, Ri, Ro, C, Cw]

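# After mtkcompile the equations have the form D(x) ~ rhs. This assumes they come
# out in the order [D(T), D(Tw)]; compare against MTK.unknowns(sys) to be sure.
rhs = [eq.rhs for eq in MTK.equations(sys)]

# Pass t as the last argument so f has the (x, u, p, t) signature that
# SeeToDee.Rk4 calls; [1] selects the out-of-place function.
f = MTK.build_function(rhs, states, controls, params, t;
    expression=Val(false))[1]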

############################
# 3. Discretization
############################

Δt = 1.0
f_disc = Rk4(f, Δt)

############################
# 4. Parameters
############################

p = [5.0, 1.0, 3.0, 10.0, 30.0]

############################
# 5. Optimization model
############################

N = 24
nx, nu = 2, 1

model = Model(Ipopt.Optimizer)

JuMP.@variables(model,
    begin
        X[1:nx, 1:N]
        U[1:nu, 1:N] >= 0
    end
)

############################
# 6. Initial state
############################

X0 = [18.0, 10.0]

@constraint(model, X[:, 1] .== X0)

############################
# 7. Dynamics
############################

@constraint(model, [k = 1:N-1],
    X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0)
)

############################
# 8. Objective
############################

@objective(model, Min,
    sum((X[1, k] - 21)^2 + 0.01 * U[1, k]^2 for k in 1:N)
)

############################
# 9. Solve
############################

optimize!(model)

############################
# 10. Plot
############################

T_vals = value.(X[1, :])
Tw_vals = value.(X[2, :])
u_vals = vec(value.(U))

println(lineplot(1:N, [T_vals Tw_vals],
    title="Temperature trajectory",
    name=["T air" "T wall"]))

println(lineplot(1:N, u_vals,
    title="Heating power",
    name="u"))

Right now this is a trivial example that could be discretized by hand, but the MWE will be expanded later.
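For reference, here is roughly what "by hand" could look like for this model. This is a derivation from the two equations in the script (not from MTK's output), with x = [T, Tw] and the ambient temperature Ta treated as a known constant input. The discrete-time formulas are the exact zero-order-hold discretization, i.e. they assume u is held constant over each step of length Δt, which matches how the script passes a single u per RK4 step.

```latex
\[
\dot{x} = A x + B u + E\,T_a, \qquad
x = \begin{bmatrix} T \\ T_w \end{bmatrix}, \quad
A = \begin{bmatrix}
  -\dfrac{1}{C R_i} & \dfrac{1}{C R_i} \\[1ex]
  \dfrac{1}{C_w R_i} & -\dfrac{1}{C_w}\left(\dfrac{1}{R_i} + \dfrac{1}{R_o}\right)
\end{bmatrix}, \quad
B = \begin{bmatrix} 1/C \\ 0 \end{bmatrix}, \quad
E = \begin{bmatrix} 0 \\ 1/(C_w R_o) \end{bmatrix}
\]
\[
x_{k+1} = A_d x_k + B_d u_k + E_d T_a, \qquad
A_d = e^{A\,\Delta t}, \quad
B_d = \left(\int_0^{\Delta t} e^{A\tau}\,d\tau\right) B, \quad
E_d = \left(\int_0^{\Delta t} e^{A\tau}\,d\tau\right) E
\]
```

Once the model grows nonlinear terms this closed form stops being available, which is where a generic integrator such as RK4 earns its keep.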

I am mainly looking for feedback on whether this is the right approach. It works great; the results are more or less what I would expect.

P.S. I know about the "Should you use JuMP?" page in the JuMP docs, but I am stubborn and want to use JuMP anyway (other parts of this project need it).
P.P.S. I also know about PSORLab/EOptInterface.jl on GitHub (an abstraction layer for optimizing equation-oriented/acausal models), but it seems a bit overkill for my use case.

Edit: I just realized that JuMP correctly recognizes this as a QP, so HiGHS can also be used 🙂
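The QP structure follows from the model being linear: one RK4 step of a linear ODE, with the input held constant over the step, is an affine map x_next = Ad*x + Bd*u + c. Every dynamics constraint is therefore linear and only the objective is quadratic. Below is a minimal Base Julia sketch to check this numerically. rk4_step and thermal_rhs are hypothetical stand-ins I wrote for SeeToDee.Rk4 and the MTK-generated f (not their actual source), using the parameter vector p = [Ta, Ri, Ro, C, Cw] from the script.

```julia
using LinearAlgebra

# Classic RK4 step, input held constant over the step. Hypothetical helper that
# follows the (x, u, p, t) convention from the thread; not SeeToDee's source code.
function rk4_step(f, x, u, p, t, h)
    k1 = f(x, u, p, t)
    k2 = f(x .+ (h / 2) .* k1, u, p, t + h / 2)
    k3 = f(x .+ (h / 2) .* k2, u, p, t + h / 2)
    k4 = f(x .+ h .* k3, u, p, t + h)
    return x .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Thermal model from the thread, written out by hand; p = [Ta, Ri, Ro, C, Cw].
function thermal_rhs(x, u, p, t)
    T, Tw = x
    Ta, Ri, Ro, C, Cw = p
    return [((Tw - T) / Ri + u[1]) / C,
            ((Ta - Tw) / Ro - (Tw - T) / Ri) / Cw]
end

p = [5.0, 1.0, 3.0, 10.0, 30.0]
h = 0.25
disc_step(x, u) = rk4_step(thermal_rhs, x, u, p, 0.0, h)

# Recover x_next = Ad*x + Bd*u + c by probing with the zero vector and unit vectors.
c  = disc_step([0.0, 0.0], [0.0])
Ad = hcat(disc_step([1.0, 0.0], [0.0]) - c, disc_step([0.0, 1.0], [0.0]) - c)
Bd = disc_step([0.0, 0.0], [1.0]) - c

# The step reproduces that affine map at an arbitrary point...
x, u = [18.0, 10.0], [2.5]
@assert disc_step(x, u) ≈ Ad * x + Bd * u[1] + c

# ...and for linear dynamics Ad is the 4th-order Taylor polynomial of exp(A*h).
Ta, Ri, Ro, C, Cw = p
A = [-1/(C*Ri)  1/(C*Ri);
     1/(Cw*Ri)  -(1/Ro + 1/Ri)/Cw]
M = A * h
@assert Ad ≈ I + M + M^2/2 + M^3/6 + M^4/24
```

If the RHS contained a nonlinear term (say a product of two states), superposition would fail in this check and the dynamics constraints would no longer be linear.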

See this example, which uses some built-in functions to generate the f function.

You can also have a look at the MTK docs, which detail how this works.

MTK also has built-in support for targeting JuMP for optimization, but the docs are not fantastic at the moment.

Thanks a lot! generate_control_function seems to be the way to go, then.

Here’s the new script I am working on:

using Pkg
Pkg.activate(".")

import ModelingToolkit as MTK
using JuMP
using HiGHS
using UnicodePlots
import SeeToDee as c2d

############################
# 1. ModelingToolkit model
############################

MTK.@parameters t Ta Ri Ro C Cw
D = MTK.Differential(t)

MTK.@variables T(t) Tw(t) u(t)

eqs = [
    C * D(T) ~ (Tw - T) / Ri + u,
    Cw * D(Tw) ~ (Ta - Tw) / Ro - (Tw - T) / Ri,
]

MTK.@named sys = MTK.ODESystem(eqs, t)

############################
# 2. RHS function
############################

# generate_control_function handles mtkcompile internally — do not call it beforehand.
# f_oop has signature (x, u, p, t) matching SeeToDee.Rk4's expected form.
(f_oop, _), x_sym, p_sym, io_sys = MTK.generate_control_function(sys, [u]; simplify=true)

# Build parameter vector in MTK's chosen p_sym ordering.
p = MTK.varmap_to_vars(Dict(Ta => 5.0, Ri => 1.0, Ro => 3.0, C => 10.0, Cw => 30.0), p_sym)

############################
# 3. Discretization
############################

Δt = 0.25
f_disc = c2d.Rk4(f_oop, Δt)

############################
# 4. Problem dimensions
############################

N = 48
nx = length(x_sym)
nu = 1

# State indices — robust to MTK's chosen x_sym ordering.
T_idx = findfirst(s -> isequal(s, T), x_sym)
Tw_idx = findfirst(s -> isequal(s, Tw), x_sym)

# Initial state vector in MTK's x_sym ordering.
X0 = MTK.varmap_to_vars(Dict(T => 18.0, Tw => 10.0), x_sym)

############################
# 5. Optimization model
############################

model = Model(HiGHS.Optimizer)

JuMP.@variables(model,
    begin
        X[1:nx, 1:N]
        U[1:nu, 1:N] >= 0
    end
)

############################
# 6. Constraints
############################

@constraint(model, X[:, 1] .== X0)

@constraint(model, [k = 1:N-1],
    X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0)
)

############################
# 7. Objective
############################

@objective(model, Min,
    sum((X[T_idx, k] - 21)^2 + 0.01 * U[1, k]^2 for k in 1:N)
)

############################
# 8. Solve
############################

@time optimize!(model)

############################
# 9. Plot
############################

T_vals = value.(X[T_idx, :])
Tw_vals = value.(X[Tw_idx, :])
u_vals = vec(value.(U))

println(lineplot(1:N, [T_vals Tw_vals],
    title="Temperature trajectory",
    name=["T air" "T wall"]))

println(lineplot(1:N, u_vals,
    title="Heating power",
    name="u"))

I would appreciate some feedback 🙂

One thing I am not entirely sure about is how to index the variables. I do:

T_idx = findfirst(s -> isequal(s, T), x_sym)

Is there a better way to do this mapping?
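Looking up each state with findfirst and isequal is a reasonable approach. One small hardening: findfirst returns nothing when the variable is not in x_sym, which only surfaces later as an indexing error far from the cause. A minimal Base Julia sketch of a hypothetical helper, state_index, that fails loudly instead; plain Symbols stand in for the MTK variables here.

```julia
# Hypothetical helper: position of `var` in `syms`, with a clear error if it is
# missing. `isequal(var)` is Base's curried predicate, i.e. s -> isequal(s, var).
function state_index(syms, var)
    i = findfirst(isequal(var), syms)
    i === nothing && throw(ArgumentError("$(var) not found in $(syms)"))
    return i
end

# Plain Symbols stand in for MTK's x_sym; the order need not match declaration order.
x_sym_demo = [:Tw, :T]
T_idx  = state_index(x_sym_demo, :T)
Tw_idx = state_index(x_sym_demo, :Tw)
```

With MTK the call would read state_index(x_sym, T), relying on the same isequal comparison the script already uses.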

Finally, I am quite confused that this part ‘just works’:

Δt = 0.25
f_disc = c2d.Rk4(f_oop, Δt)

@constraint(model, [k = 1:N-1],
    X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0)
)

Is this just JuMP being very good at tracing things? How does it know how to handle this SeeToDee.Rk4{ModelingToolkitBase.GeneratedFunctionWrapper?

This is related to this cool talk by @ChrisRackauckas, right? https://youtu.be/B9ymnQOQY3s 🙂

That’s fine.

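Yes, SeeToDee.Rk4{ModelingToolkitBase.GeneratedFunctionWrapper is a functor (a callable struct), so it behaves like any other function. JuMP is not tracing anything: when the functor is called with vectors of JuMP variables, the arithmetic inside it dispatches to JuMP's overloaded +, * and / methods, which build JuMP expressions (affine ones, for this linear model).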
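To make the "no tracing" point concrete, here is a minimal Base Julia sketch. Aff is a hypothetical, stripped-down stand-in for JuMP's AffExpr, and RK4Step is a hypothetical callable struct that mimics the idea behind SeeToDee.Rk4, not its source. The same functor runs unchanged on plain numbers and on symbolic affine "variables"; only the +, * and / methods that get dispatched differ.

```julia
# Toy affine expression: constant + sum of coeffs[v] * v. Hypothetical stand-in
# for JuMP's AffExpr, used only to show the operator-overloading mechanism.
struct Aff
    coeffs::Dict{Symbol,Float64}
    constant::Float64
end
Aff(v::Symbol) = Aff(Dict(v => 1.0), 0.0)

scale(a::Aff, s::Real) =
    Aff(Dict{Symbol,Float64}(k => s * c for (k, c) in a.coeffs), s * a.constant)

Base.:+(a::Aff, b::Aff) = Aff(mergewith(+, a.coeffs, b.coeffs), a.constant + b.constant)
Base.:+(a::Aff, s::Real) = Aff(copy(a.coeffs), a.constant + s)
Base.:+(s::Real, a::Aff) = a + s
Base.:-(a::Aff) = scale(a, -1)
Base.:-(a::Aff, b::Aff) = a + (-b)
Base.:-(a::Aff, s::Real) = a + (-s)
Base.:-(s::Real, a::Aff) = (-a) + s
Base.:*(s::Real, a::Aff) = scale(a, s)
Base.:*(a::Aff, s::Real) = scale(a, s)
Base.:/(a::Aff, s::Real) = scale(a, 1 / s)

# A callable struct ("functor") wrapping a continuous-time RHS f(x, u, p, t).
struct RK4Step{F}
    f::F
    h::Float64
end
function (s::RK4Step)(x, u, p, t)
    f, h = s.f, s.h
    k1 = f(x, u, p, t)
    k2 = f(x .+ (h / 2) .* k1, u, p, t + h / 2)
    k3 = f(x .+ (h / 2) .* k2, u, p, t + h / 2)
    k4 = f(x .+ h .* k3, u, p, t + h)
    return x .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Thermal model from the thread, written out by hand; p = [Ta, Ri, Ro, C, Cw].
function thermal_rhs(x, u, p, t)
    T, Tw = x
    Ta, Ri, Ro, C, Cw = p
    return [((Tw - T) / Ri + u[1]) / C,
            ((Ta - Tw) / Ro - (Tw - T) / Ri) / Cw]
end

f_disc = RK4Step(thermal_rhs, 0.25)
p = [5.0, 1.0, 3.0, 10.0, 30.0]

# Same functor, two argument types: plain numbers and symbolic affine "variables".
x_num  = f_disc([18.0, 10.0], [2.5], p, 0.0)
x_expr = f_disc([Aff(:T), Aff(:Tw)], [Aff(:u)], p, 0.0)

# Evaluating the symbolic result at the same point reproduces the numeric step.
evaluate(a::Aff, vals) = a.constant + sum(c * vals[k] for (k, c) in a.coeffs)
vals = Dict(:T => 18.0, :Tw => 10.0, :u => 2.5)
@assert all(evaluate.(x_expr, Ref(vals)) .≈ x_num)
```

In the real script, JuMP's VariableRef and AffExpr play the role of Aff, which is also why the dynamics constraints come out affine and the problem is recognized as a QP.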

If you want to see a more complex example use case of these input-output functions, see Home · LowLevelParticleFiltersMTK Documentation its implementation uses generate_control_function and handles DAEs etc.