Bridging MTK and JuMP

I am doing some basic thermal modelling in MTK, and I want to use the resulting discretized dynamics in a JuMP model.

I came up with the following PoC:

using Pkg
Pkg.activate(".")

import ModelingToolkit as MTK
using JuMP
using Ipopt
using UnicodePlots
using SeeToDee: Rk4

############################
# 1. ModelingToolkit model
############################

MTK.@parameters t Ta Ri Ro C Cw
D = MTK.Differential(t)

MTK.@variables T(t) Tw(t) u(t)

eqs = [
    C * D(T) ~ (Tw - T) / Ri + u,
    Cw * D(Tw) ~ (Ta - Tw) / Ro - (Tw - T) / Ri,
]

MTK.@named sys = MTK.ODESystem(eqs, t)
sys = MTK.mtkcompile(sys; inputs=[u])

############################
# 2. RHS function
############################

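# rhs[i] is the derivative of the lhs of equation i, so `states` must list the
# unknowns in the same order as MTK.equations(sys), and `params` must match `p` below.
states = [T, Tw]
controls = [u]
params = [Ta, Ri, Ro, C, Cw]

rhs = [eq.rhs for eq in MTK.equations(sys)]

# SeeToDee.Rk4 calls f(x, u, p, t), so t is passed as a fourth argument.
# build_function returns (out-of-place, in-place); [1] selects the out-of-place one.
f = MTK.build_function(rhs, states, controls, params, t;
    expression=Val(false))[1]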

############################
# 3. Discretization
############################

Δt = 1.0
f_disc = Rk4(f, Δt)

############################
# 4. Parameters
############################

p = [5.0, 1.0, 3.0, 10.0, 30.0]  # Ta, Ri, Ro, C, Cw (same order as params)

############################
# 5. Optimization model
############################

N = 24
nx, nu = 2, 1

model = Model(Ipopt.Optimizer)

JuMP.@variables(model,
    begin
        X[1:nx, 1:N]
        U[1:nu, 1:N] >= 0
    end
)

############################
# 6. Initial state
############################

X0 = [18.0, 10.0]

@constraint(model, X[:, 1] .== X0)

############################
# 7. Dynamics
############################

@constraint(model, [k = 1:N-1],
    X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0)
)

############################
# 8. Objective
############################

@objective(model, Min,
    sum((X[1, k] - 21)^2 + 0.01 * U[1, k]^2 for k in 1:N)
)

############################
# 9. Solve
############################

optimize!(model)

############################
# 10. Plot
############################

T_vals = value.(X[1, :])
Tw_vals = value.(X[2, :])
u_vals = vec(value.(U))

println(lineplot(1:N, [T_vals Tw_vals],
    title="Temperature trajectory",
    name=["T air" "T wall"]))

println(lineplot(1:N, u_vals,
    title="Heating power",
    name="u"))

Right now it is a really trivial example that could be discretized by hand, but this MWE will be expanded on later.
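For this model the hand discretization is short enough to write out (my derivation, not from the thread). Dividing each equation by its capacitance gives a linear state-space form, and one RK4 step of size Δt, with u held constant over the step as f_disc(X[:, k], U[:, k], p, 0.0) does, reduces to a truncated matrix exponential:

$$
\dot{x} = A x + B u + E\,T_a, \qquad x = \begin{bmatrix} T \\ T_w \end{bmatrix}
$$

$$
A = \begin{bmatrix} -\dfrac{1}{C R_i} & \dfrac{1}{C R_i} \\[4pt] \dfrac{1}{C_w R_i} & -\dfrac{1}{C_w}\left(\dfrac{1}{R_i} + \dfrac{1}{R_o}\right) \end{bmatrix}, \qquad
B = \begin{bmatrix} \dfrac{1}{C} \\[4pt] 0 \end{bmatrix}, \qquad
E = \begin{bmatrix} 0 \\[4pt] \dfrac{1}{C_w R_o} \end{bmatrix}
$$

$$
x_{k+1} = \Phi x_k + \Gamma \left(B u_k + E\,T_a\right), \qquad Z = \Delta t\,A
$$

$$
\Phi = I + Z + \tfrac{1}{2}Z^2 + \tfrac{1}{6}Z^3 + \tfrac{1}{24}Z^4, \qquad
\Gamma = \Delta t\left(I + \tfrac{1}{2}Z + \tfrac{1}{6}Z^2 + \tfrac{1}{24}Z^3\right)
$$

Φ and Γ agree with the exact zero-order-hold discretization (e^Z and the integral of e^{sA} over [0, Δt]) through terms of order Δt⁴, so the RK4 model approaches the exact discrete model as Δt shrinks.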

I am mainly looking for feedback on whether this is the right approach. It works well, and the results are roughly what I would expect.

P.S. I know about the "Should you use JuMP?" page in the JuMP docs, but I am stubborn and want to use JuMP (other parts of this project need it).
P.P.S. I also know about PSORLab/EOptInterface.jl on GitHub ("An abstraction layer for optimizing equation-oriented/acausal models"), but it seems a bit overkill for my use case.

Edit: Just realized JuMP correctly recognizes this as a QP! So HiGHS can also be used 🙂
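The QP structure follows from the model being linear: with numeric parameters, one RK4 step is an affine map of the state and input, so every dynamics constraint is linear and only the objective is quadratic. A minimal sketch to check this numerically, assuming SeeToDee's Rk4(f, Δt) is called as f_disc(x, u, p, t) as in the script; the NamedTuple `p`, the hand-written `thermal` RHS and the helpers `g` and `e` are illustrative, not part of either script:

```julia
using SeeToDee: Rk4

# Thermal model RHS with the (x, u, p, t) signature Rk4 expects; x = [T, Tw].
# p is a NamedTuple here for readability (the scripts use a plain vector).
thermal(x, u, p, t) = [((x[2] - x[1]) / p.Ri + u[1]) / p.C,
                       ((p.Ta - x[2]) / p.Ro - (x[2] - x[1]) / p.Ri) / p.Cw]

p = (Ta = 5.0, Ri = 1.0, Ro = 3.0, C = 10.0, Cw = 30.0)
f_disc = Rk4(thermal, 0.25)
g(x, u) = f_disc(x, u, p, 0.0)            # one discrete step, fixed p and t

e(i) = [j == i ? 1.0 : 0.0 for j in 1:2]  # unit vector in R^2

# If the step is affine, x_next = Ad*x + Bd*u + c, so probing it with zeros
# and unit vectors recovers Ad, Bd and the constant offset c (driven by Ta).
c  = g(zeros(2), zeros(1))
Ad = hcat(g(e(1), zeros(1)) - c, g(e(2), zeros(1)) - c)
Bd = reshape(g(zeros(2), [1.0]) - c, 2, 1)

# The affine model reproduces the actual step at an arbitrary point.
x, u = [18.0, 10.0], [3.0]
@assert g(x, u) ≈ Ad * x + Bd * u + c
```

If the model later becomes nonlinear (say, temperature-dependent resistances), this check would fail, the dynamics constraints would no longer be linear, and HiGHS would stop being an option; Ipopt, as in the first script, would be needed again.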


See this example, which uses some built-in functions to generate the f function.

You can also have a look at the MTK docs, which detail how this works.

MTK also has built-in support for targeting JuMP for optimization, but the docs are not fantastic at the moment.


Thanks a lot! generate_control_function seems to be the way to go, then.

Here’s the new script I am working on:

using Pkg
Pkg.activate(".")

import ModelingToolkit as MTK
using JuMP
using HiGHS
using UnicodePlots
import SeeToDee as c2d

############################
# 1. ModelingToolkit model
############################

MTK.@parameters t Ta Ri Ro C Cw
D = MTK.Differential(t)

MTK.@variables T(t) Tw(t) u(t)

eqs = [
    C * D(T) ~ (Tw - T) / Ri + u,
    Cw * D(Tw) ~ (Ta - Tw) / Ro - (Tw - T) / Ri,
]

MTK.@named sys = MTK.ODESystem(eqs, t)

############################
# 2. RHS function
############################

# generate_control_function handles mtkcompile internally — do not call it beforehand.
# f_oop has signature (x, u, p, t) matching SeeToDee.Rk4's expected form.
(f_oop, _), x_sym, p_sym, io_sys = MTK.generate_control_function(sys, [u]; simplify=true)

# Build parameter vector in MTK's chosen p_sym ordering.
p = MTK.varmap_to_vars(Dict(Ta => 5.0, Ri => 1.0, Ro => 3.0, C => 10.0, Cw => 30.0), p_sym)

############################
# 3. Discretization
############################

Δt = 0.25
f_disc = c2d.Rk4(f_oop, Δt)

############################
# 4. Problem dimensions
############################

N = 48
nx = length(x_sym)
nu = 1

# State indices — robust to MTK's chosen x_sym ordering.
T_idx = findfirst(s -> isequal(s, T), x_sym)
Tw_idx = findfirst(s -> isequal(s, Tw), x_sym)

# Initial state vector in MTK's x_sym ordering.
X0 = MTK.varmap_to_vars(Dict(T => 18.0, Tw => 10.0), x_sym)

############################
# 5. Optimization model
############################

model = Model(HiGHS.Optimizer)

JuMP.@variables(model,
    begin
        X[1:nx, 1:N]
        U[1:nu, 1:N] >= 0
    end
)

############################
# 6. Constraints
############################

@constraint(model, X[:, 1] .== X0)

@constraint(model, [k = 1:N-1],
    X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0)
)

############################
# 7. Objective
############################

@objective(model, Min,
    sum((X[T_idx, k] - 21)^2 + 0.01 * U[1, k]^2 for k in 1:N)
)

############################
# 8. Solve
############################

@time optimize!(model)

############################
# 9. Plot
############################

T_vals = value.(X[T_idx, :])
Tw_vals = value.(X[Tw_idx, :])
u_vals = vec(value.(U))

println(lineplot(1:N, [T_vals Tw_vals],
    title="Temperature trajectory",
    name=["T air" "T wall"]))

println(lineplot(1:N, u_vals,
    title="Heating power",
    name="u"))

I would appreciate some feedback 🙂

One thing I am not entirely sure about is how to index the variables. I do:

T_idx = findfirst(s -> isequal(s, T), x_sym)

Is there a better way to do this mapping?

Finally, I am quite confused that this part ‘just works’:

Δt = 0.25
f_disc = c2d.Rk4(f_oop, Δt)

@constraint(model, [k = 1:N-1],
    X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0)
)

Is this just JuMP being very good at tracing things? How does it know how to handle this SeeToDee.Rk4{ModelingToolkitBase.GeneratedFunctionWrapper?

Is this related to this cool talk by @ChrisRackauckas? https://youtu.be/B9ymnQOQY3s 🙂

That’s fine.
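One small hardening of the findfirst lookup, sketched as a suggestion rather than anything from the thread: resolve all indices in one place and fail loudly, since findfirst returns nothing for a missing symbol and the mistake would otherwise surface later as a confusing indexing error. The helper name `state_indices` is hypothetical, and plain Symbols stand in for MTK variables so the sketch runs on its own:

```julia
# Hypothetical helper: positions of the `wanted` symbols inside `syms`, compared
# with `isequal` (as in the original findfirst call); errors if one is missing.
function state_indices(syms, wanted)
    map(wanted) do s
        i = findfirst(x -> isequal(x, s), syms)
        i === nothing && error("$(s) not found in $(syms)")
        i
    end
end

# Plain Symbols standing in for MTK's chosen x_sym ordering.
x_sym = [:Tw, :T]
T_idx, Tw_idx = state_indices(x_sym, (:T, :Tw))
```

With MTK this would presumably read `T_idx, Tw_idx = state_indices(x_sym, (T, Tw))`, assuming isequal matches the symbolic variables the same way it does in the script's findfirst call.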

Yes, SeeToDee.Rk4{ModelingToolkitBase.GeneratedFunctionWrapper is a functor (a callable struct), so it behaves like any other function.
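To make the functor answer concrete, here is a minimal sketch with a hypothetical `RK4Step` struct written for illustration; it mimics the idea but is not SeeToDee's actual implementation. Strictly speaking JuMP is not tracing anything: the step is ordinary generic Julia code, and when it is called with JuMP variables, JuMP's methods for +, * and / build affine expressions eagerly as the code runs.

```julia
using JuMP

# Hypothetical callable struct mimicking an RK4 discretizer (not SeeToDee's source):
# it stores the dynamics f(x, u, p, t) and a step size h; instances are called like functions.
struct RK4Step{F}
    f::F
    h::Float64
end

# One RK4 step with u held constant; only generic +, * and / are used.
function (s::RK4Step)(x, u, p, t)
    f, h = s.f, s.h
    k1 = f(x, u, p, t)
    k2 = f(x .+ (h / 2) .* k1, u, p, t + h / 2)
    k3 = f(x .+ (h / 2) .* k2, u, p, t + h / 2)
    k4 = f(x .+ h .* k3, u, p, t + h)
    return x .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Thermal model from the thread, x = [T, Tw]; p is a NamedTuple for readability.
thermal(x, u, p, t) = [((x[2] - x[1]) / p.Ri + u[1]) / p.C,
                       ((p.Ta - x[2]) / p.Ro - (x[2] - x[1]) / p.Ri) / p.Cw]

p = (Ta = 5.0, Ri = 1.0, Ro = 3.0, C = 10.0, Cw = 30.0)
f_disc = RK4Step(thermal, 0.25)

# Called on plain numbers, the functor returns numbers...
x_num = f_disc([18.0, 10.0], [2.0], p, 0.0)

# ...and called on JuMP variables it returns a vector of AffExpr,
# which @constraint accepts exactly as in the scripts.
model = Model()
@variable(model, X[1:2, 1:3])
@variable(model, U[1:1, 1:3] >= 0)
@constraint(model, [k = 1:2], X[:, k+1] .== f_disc(X[:, k], U[:, k], p, 0.0))
```

The same mechanism is presumably why the MTK-generated function inside SeeToDee.Rk4 also accepts JuMP variables: the generated code is plain Julia arithmetic on whatever inputs it receives.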


If you want to see a more complex use case of these input-output functions, see the LowLevelParticleFiltersMTK documentation; its implementation uses generate_control_function and handles DAEs etc.
