How to improve structural_simplify performance?

ModelingToolkit's structural_simplify can be very slow on bigger systems. What can I do to improve its performance? Should I write my equations in a certain way, should I split up big equations into smaller ones, or should I use certain keyword arguments?

I guess we need an MWE to be able to comment on this question.

Here is my MWE. I profiled structural_simplify with ProfileView and saw that a lot of time is spent in alias elimination: ModelingToolkit/ZOG3I/src/systems/alias_elimination.jl:46, alias_elimination!
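For reference, a rough way to see where structural_simplify spends its time without a GUI is Julia's built-in `Profile` standard library (ProfileView's `@profview` shows the same kind of samples as a flame graph). This is a minimal sketch on a hypothetical toy system (a damped spring with the illustrative names `x`, `v`, `ext`, `f`), not the MWE itself. It calls structural_simplify once so that compilation is not profiled, repeats the call to collect enough samples, and prints a flat profile sorted by sample count. On a real model, a dominant hotspot such as `alias_elimination!` would be expected near the top of that list.

```julia
using ModelingToolkit, Profile
using ModelingToolkit: t_nounits as t, D_nounits as D

# Hypothetical toy system: a damped spring with two intermediate variables
@variables x(t) v(t) ext(t) f(t)
eqs = [D(x) ~ v,
       ext ~ x - 1.0,                 # spring extension
       f ~ -100.0 * ext - 2.0 * v,    # spring + damper force
       D(v) ~ f]
@named sys = ODESystem(eqs, t)

structural_simplify(sys)              # warm-up call so compilation is not profiled
Profile.clear()
@profile for _ in 1:20
    structural_simplify(sys)          # repeat the call to collect enough samples
end
# Flat listing, most-sampled frames first; hide rarely hit frames
Profile.print(format = :flat, sortedby = :count, mincount = 10)
```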

MWE
# MWE of tether system creation and simplification
using ModelingToolkit, LinearAlgebra
using ModelingToolkit: t_nounits as t, D_nounits as D

struct Point
    type::Symbol
    position::Union{Vector{Float64}, Nothing}
    velocity::Union{Vector{Float64}, Nothing}
    mass::Union{Float64, Nothing}
    force::Vector{Float64}
end

struct Segment
    points::Tuple{Int, Int}
    l0::Union{Float64, Nothing}
    stiffness::Float64
    damping::Float64
end

struct Pulley
    segments::Tuple{Int, Int}
    sum_length::Float64
end

function main()
    """
    Add or remove points, segments and pulleys to these lists to configure your system
    """
    # protune-speed-system.jpg
    points = [
        Point(:fixed,  [0, 0, 0],  zeros(3), nothing, zeros(3)),  # Fixed point
        Point(:quasi_static, [-1, 0, 0], zeros(3), 1.0, zeros(3)),
        Point(:quasi_static, [-2, 0, 0], zeros(3), 1.0, zeros(3)),
        Point(:quasi_static, [-3, 0, 0], zeros(3), 1.0, zeros(3)),
        Point(:dynamic, [-4, 0, 0], zeros(3), 100.0, [0., 0., 0.]),
    ]

    stiffness = 614600
    damping = 4730
    segments = [
        Segment((1, 2), norm(points[1].position - points[2].position), stiffness, damping),
        Segment((2, 3), norm(points[2].position - points[3].position), stiffness, damping),
        Segment((3, 4), norm(points[3].position - points[4].position), stiffness, damping),
        Segment((4, 5), norm(points[4].position - points[5].position), stiffness, damping),
    ]

    pulleys = []

    g_earth::Vector{Float64} = [0.0, 0.0, -9.81] # gravitational acceleration     [m/s²]
    l0 = 50                                      # initial tether length             [m]
    c_spring = 614600                            # unit spring constant              [N]
    rel_compression_stiffness = 0.01             # relative compression stiffness    [-]
    damping = 473                                # unit damping constant            [Ns]

    function calc_initial_state(points, segments, pulleys)
        POS0 = zeros(3, length(points))
        VEL0 = zeros(3, length(points))
        L0 = zeros(length(pulleys))
        V0 = zeros(length(pulleys))
        for i in eachindex(points)
            POS0[:, i] .= points[i].position
            VEL0[:, i] .= points[i].velocity
        end
        for i in eachindex(pulleys)
            L0[i] = segments[pulleys[i].segments[1]].l0
            V0[i] = 0.0
        end
        POS0, VEL0, L0, V0
    end

    function calc_spring_forces(pos::AbstractMatrix{T}, vel, pulley_l0) where T
        # loop over all segments to calculate spring forces
        spring_force = zeros(T, length(segments))
        spring_force_vec = zeros(T, 3, length(segments))
        segment = zeros(T, 3)
        unit_vector = zeros(T, 3)
        rel_vel = zeros(T, 3)
        for (tether_idx, tether) in enumerate(segments)
            found = false
            for (pulley_idx, pulley) in enumerate(pulleys)
                if tether_idx == pulley.segments[1] # each tether should only be part of one pulley
                    l0 = pulley_l0[pulley_idx]
                    found = true
                    break
                elseif tether_idx == pulley.segments[2]
                    l0 = pulley.sum_length - pulley_l0[pulley_idx]
                    found = true
                    break
                end
            end
            if !found
                l0 = tether.l0
            end
            p1, p2 = tether.points[1], tether.points[2]

            segment         .= pos[:, p2] .- pos[:, p1]
            len              = norm(segment)
            unit_vector     .= segment ./ len
            rel_vel         .= vel[:, p1] .- vel[:, p2]
            spring_vel       = rel_vel ⋅ unit_vector
            spring_force[tether_idx]          = (tether.stiffness * tether.l0 * (len - l0) - tether.damping * tether.l0 * spring_vel)
            spring_force_vec[:, tether_idx]  .= spring_force[tether_idx] .* unit_vector
        end
        return spring_force_vec, spring_force
    end

    function calc_acc(pos::AbstractMatrix{T}, vel, pulley_l0) where T
        spring_force_vec, spring_force = calc_spring_forces(pos, vel, pulley_l0)
        
        pulley_acc = zeros(T, length(pulleys))
        for (pulley_idx, pulley) in enumerate(pulleys)
            M = 3.1
            pulley_force = spring_force[pulley.segments[1]] - spring_force[pulley.segments[2]]
            pulley_acc[pulley_idx] = pulley_force / M
        end

        acc = zeros(T, 3, length(points))
        force = zeros(T, 3)
        for (point_idx, point) in enumerate(points)
            if point.type === :fixed
                acc[:, point_idx] .= 0.0
            else
                force .= 0.0
                for (j, tether) in enumerate(segments)
                    if point_idx in tether.points
                        inverted = tether.points[2] == point_idx
                        if inverted
                            force .-= spring_force_vec[:, j]
                        else
                            force .+= spring_force_vec[:, j]
                        end
                    end
                end
                force .+= point.force
                acc[:, point_idx] .= force ./ point.mass .+ g_earth
            end
        end
        return acc, pulley_acc
    end

    function model()
        @parameters rel_compression_stiffness = rel_compression_stiffness
        @variables begin
            pos(t)[1:3, eachindex(points)]
            vel(t)[1:3, eachindex(points)]
            acc(t)[1:3, eachindex(points)]
            force(t)[1:3, eachindex(points)]

            pulley_force(t)[eachindex(pulleys)]
            pulley_l0(t)[eachindex(pulleys)] # first tether length in pulley
            pulley_vel(t)[eachindex(pulleys)]
            pulley_acc(t)[eachindex(pulleys)]

            segment_vec(t)[1:3, eachindex(segments)]
            unit_vector(t)[1:3, eachindex(segments)]
            l_spring(t), c_spring(t), damping(t), m_tether_particle(t)
            len(t)[eachindex(segments)]
            l0(t)[eachindex(segments)]
            rel_vel(t)[1:3, eachindex(segments)]
            spring_vel(t)[eachindex(segments)]
            spring_force(t)[eachindex(segments)]
            spring_force_vec(t)[1:3, eachindex(segments)] # spring force from spring p1 to spring p2
        end

        POS0, VEL0, L0, V0 = calc_initial_state(points, segments, pulleys)

        defaults = Pair{Num, Real}[]
        guesses = Pair{Num, Real}[]
        eqs = [
            D(pulley_l0) ~ pulley_vel
            D(pulley_vel) ~ pulley_acc - 10pulley_vel
        ]

        for (point_idx, point) in enumerate(points)
            if point.type === :fixed
                eqs = [
                    eqs
                    pos[:, point_idx] ~ point.position
                    vel[:, point_idx] ~ zeros(3)
                ]
            elseif point.type === :dynamic
                eqs = [
                    eqs
                    D(pos[:, point_idx]) ~ vel[:, point_idx]
                    D(vel[:, point_idx]) ~ acc[:, point_idx]
                ]
                defaults = [
                    defaults
                    [pos[j, point_idx] => POS0[j, point_idx] for j in 1:3]
                    [vel[j, point_idx] => 0 for j in 1:3]
                ]
            elseif point.type === :quasi_static
                eqs = [
                    eqs
                    acc[:, point_idx] ~ zeros(3)
                    vel[:, point_idx] ~ zeros(3)
                ]
                guesses = [
                    guesses
                    [pos[j, point_idx] => POS0[j, point_idx] for j in 1:3]
                    [vel[j, point_idx] => 0 for j in 1:3]
                ]
            else
                println("wrong type")
            end
        end

        defaults = [
            defaults
            pulley_l0 => L0
            pulley_vel => V0
        ]

        eqs = [
            eqs
            vec(acc) .~ vec(calc_acc(pos, vel, pulley_l0)[1])
            vec(pulley_acc) .~ vec(calc_acc(pos, vel, pulley_l0)[2])
        ]
        
        eqs = reduce(vcat, Symbolics.scalarize.(eqs))
        @named sys = ODESystem(eqs, t)
        @time sys = structural_simplify(sys)
        sys, pos, vel, defaults, guesses
    end
    model()
end

sys, pos, vel, defaults, guesses = main()
nothing

Long equations are killing structural_simplify performance. This becomes clear from the revised MWE, in which I generate many short equations instead of a few very large ones. With this change, structural_simplify is almost 10x faster.
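The restructuring can be sketched on a toy scale: instead of inlining one long right-hand side, each sub-expression gets its own unknown, the way the revised MWE introduces `segment`, `len`, `unit_vector` and so on. Both systems below describe the same hypothetical damped spring (the names are illustrative, not from the MWE); the second one uses the intermediates `ext` and `f`. This only illustrates the structure. Whether it is faster for a particular model has to be measured, and a system this small is far too tiny to show any difference.

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables x(t) v(t) ext(t) f(t)

# Variant 1: the whole force expression is inlined into one equation
eqs_long = [D(x) ~ v,
            D(v) ~ -100.0 * (x - 1.0) - 2.0 * v]

# Variant 2: the same model, with each sub-expression given its own unknown
eqs_short = [D(x) ~ v,
             ext ~ x - 1.0,               # spring extension
             f ~ -100.0 * ext - 2.0 * v,  # spring + damper force
             D(v) ~ f]

@named long_sys = ODESystem(eqs_long, t)
@named short_sys = ODESystem(eqs_short, t)

simp_long = structural_simplify(long_sys)
simp_short = structural_simplify(short_sys)

# Both reduce to the states x and v; in the short variant the
# intermediates are moved to the observed equations
println("unknowns: ", length(unknowns(simp_long)), " vs ", length(unknowns(simp_short)))
println("observed in short variant: ", length(observed(simp_short)))
```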

Revised MWE with more, shorter equations
# MWE of tether system creation and simplification
using ModelingToolkit, LinearAlgebra
using ModelingToolkit: t_nounits as t, D_nounits as D

struct Point
    type::Symbol
    position::Union{Vector{Float64}, Nothing}
    velocity::Union{Vector{Float64}, Nothing}
    mass::Union{Float64, Nothing}
    force::Vector{Float64}
end

struct Segment
    points::Tuple{Int, Int}
    l0::Union{Float64, Nothing}
    stiffness::Float64
    damping::Float64
end

struct Pulley
    segments::Tuple{Int, Int}
    sum_length::Float64
end

function main()
    """
    Add or remove points, segments and pulleys to these lists to configure your system
    """
    # protune-speed-system.jpg
    points = [
        Point(:fixed,  [0, 0, 0],  zeros(3), nothing, zeros(3)),  # Fixed point
        Point(:quasi_static, [-1, 0, 0], zeros(3), 1.0, zeros(3)),
        Point(:quasi_static, [-2, 0, 0], zeros(3), 1.0, zeros(3)),
        Point(:quasi_static, [-3, 0, 0], zeros(3), 1.0, zeros(3)),
        Point(:dynamic, [-4, 0, 0], zeros(3), 100.0, [0., 0., 0.]),
    ]

    stiffness = 614600
    damping = 4730
    segments = [
        Segment((1, 2), norm(points[1].position - points[2].position), stiffness, damping),
        Segment((2, 3), norm(points[2].position - points[3].position), stiffness, damping),
        Segment((3, 4), norm(points[3].position - points[4].position), stiffness, damping),
        Segment((4, 5), norm(points[4].position - points[5].position), stiffness, damping),
    ]

    pulleys = []

    g_earth::Vector{Float64} = [0.0, 0.0, -9.81] # gravitational acceleration     [m/s²]
    l0 = 50                                      # initial tether length             [m]
    c_spring = 614600                            # unit spring constant              [N]
    rel_compression_stiffness = 0.01             # relative compression stiffness    [-]
    damping = 473                                # unit damping constant            [Ns]

    function calc_initial_state(points, segments, pulleys)
        POS0 = zeros(3, length(points))
        VEL0 = zeros(3, length(points))
        L0 = zeros(length(pulleys))
        V0 = zeros(length(pulleys))
        for i in eachindex(points)
            POS0[:, i] .= points[i].position
            VEL0[:, i] .= points[i].velocity
        end
        for i in eachindex(pulleys)
            L0[i] = segments[pulleys[i].segments[1]].l0
            V0[i] = 0.0
        end
        POS0, VEL0, L0, V0
    end

    function calc_spring_forces(eqs, pos, vel, pulley_l0)
        # loop over all segments to calculate spring forces
        @variables begin
            spring_force(t)[eachindex(segments)]
            spring_force_vec(t)[1:3, eachindex(segments)]
            segment(t)[1:3, eachindex(segments)]
            unit_vector(t)[1:3, eachindex(segments)]
            rel_vel(t)[1:3, eachindex(segments)]
            len(t)[eachindex(segments)]
            spring_vel(t)[eachindex(segments)]
        end
        for (segment_idx, tether) in enumerate(segments)
            found = false
            for (pulley_idx, pulley) in enumerate(pulleys)
                if segment_idx == pulley.segments[1] # each tether should only be part of one pulley
                    l0 = pulley_l0[pulley_idx]
                    found = true
                    break
                elseif segment_idx == pulley.segments[2]
                    l0 = pulley.sum_length - pulley_l0[pulley_idx]
                    found = true
                    break
                end
            end
            if !found
                l0 = tether.l0
            end
            p1, p2 = tether.points[1], tether.points[2]

            eqs = [
                eqs
                segment[:, segment_idx] ~ pos[:, p2] - pos[:, p1]
                len[segment_idx]        ~ norm(segment[:, segment_idx])
                unit_vector[:, segment_idx] ~ segment[:, segment_idx] / len[segment_idx]
                rel_vel[:, segment_idx]     ~ vel[:, p1] .- vel[:, p2]
                spring_vel[segment_idx]     ~ rel_vel[:, segment_idx] ⋅ unit_vector[:, segment_idx]
                spring_force[segment_idx]   ~ (tether.stiffness * tether.l0 * (len[segment_idx] - l0) - tether.damping * tether.l0 * spring_vel[segment_idx])
                spring_force_vec[:, segment_idx]  ~ spring_force[segment_idx] * unit_vector[:, segment_idx]
            ]
        end
        return eqs, spring_force_vec, spring_force
    end

    function calc_acc(eqs, pos, vel, pulley_l0)
        eqs, spring_force_vec, spring_force = calc_spring_forces(eqs, pos, vel, pulley_l0)
        
        @variables pulley_acc(t)[eachindex(pulleys)]
        @variables pulley_force(t)[eachindex(pulleys)]
        for (pulley_idx, pulley) in enumerate(pulleys)
            M = 3.1
            eqs = [
                eqs
                pulley_force[pulley_idx] ~ spring_force[pulley.segments[1]] - spring_force[pulley.segments[2]]
                pulley_acc[pulley_idx] ~ pulley_force[pulley_idx] / M
            ]
        end

        @variables acc(t)[1:3, eachindex(points)]
        @variables force(t)[1:3, eachindex(points)]
        for (point_idx, point) in enumerate(points)
            if point.type === :fixed
                eqs = [
                    eqs
                    acc[:, point_idx] ~ zeros(3)
                ]
            else
                F = zeros(Num, 3)
                for (j, tether) in enumerate(segments)
                    if point_idx in tether.points
                        inverted = tether.points[2] == point_idx
                        if inverted
                            F .-= spring_force_vec[:, j]
                        else
                            F .+= spring_force_vec[:, j]
                        end
                    end
                end
                eqs = [
                    eqs
                    force[:, point_idx] ~ F
                    acc[:, point_idx] ~ force[:, point_idx] / point.mass + g_earth
                ]
            end
        end
        return eqs, acc, pulley_acc
    end

    function model()
        @parameters rel_compression_stiffness = rel_compression_stiffness
        @variables begin
            pos(t)[1:3, eachindex(points)]
            vel(t)[1:3, eachindex(points)]
            acc(t)[1:3, eachindex(points)]
            force(t)[1:3, eachindex(points)]

            pulley_force(t)[eachindex(pulleys)]
            pulley_l0(t)[eachindex(pulleys)] # first tether length in pulley
            pulley_vel(t)[eachindex(pulleys)]
            pulley_acc(t)[eachindex(pulleys)]
        end

        POS0, VEL0, L0, V0 = calc_initial_state(points, segments, pulleys)

        defaults = Pair{Num, Real}[]
        guesses = Pair{Num, Real}[]
        eqs = [
            D(pulley_l0) ~ pulley_vel
            D(pulley_vel) ~ pulley_acc - 10pulley_vel
        ]

        for (point_idx, point) in enumerate(points)
            if point.type === :fixed
                eqs = [
                    eqs
                    pos[:, point_idx] ~ point.position
                    vel[:, point_idx] ~ zeros(3)
                ]
            elseif point.type === :dynamic
                eqs = [
                    eqs
                    D(pos[:, point_idx]) ~ vel[:, point_idx]
                    D(vel[:, point_idx]) ~ acc[:, point_idx]
                ]
                defaults = [
                    defaults
                    [pos[j, point_idx] => POS0[j, point_idx] for j in 1:3]
                    [vel[j, point_idx] => 0 for j in 1:3]
                ]
            elseif point.type === :quasi_static
                eqs = [
                    eqs
                    acc[:, point_idx] ~ zeros(3)
                    vel[:, point_idx] ~ zeros(3)
                ]
                guesses = [
                    guesses
                    [pos[j, point_idx] => POS0[j, point_idx] for j in 1:3]
                    [vel[j, point_idx] => 0 for j in 1:3]
                ]
            else
                println("wrong type")
            end
        end

        defaults = [
            defaults
            pulley_l0 => L0
            pulley_vel => V0
        ]

        eqs, acc, pulley_acc = calc_acc(eqs, pos, vel, pulley_l0)
        
        eqs = reduce(vcat, Symbolics.scalarize.(eqs))
        @named sys = ODESystem(eqs, t)
        @time sys = structural_simplify(sys)
        sys, pos, vel, defaults, guesses
    end
    model()
end

sys, pos, vel, defaults, guesses = main()
nothing


But on the bigger problem (not the MWE), structural_simplify performance is still bad: 533 seconds for a system that ends up with 191 equations. @ChrisRackauckas, are there any plans to improve this?

julia> @time sys = structural_simplify(sys)
533.625644 seconds (464.60 M allocations: 18.985 GiB, 0.35% gc time, 1.52% compilation time: 36% of which was recompilation)
julia> sys
Model sys:
Equations (191):
  191 standard: see equations(sys)
Unknowns (191): see unknowns(sys)
  (pos(t))[1, 9]
  (pos(t))[2, 9]
  (pos(t))[3, 9]
  (vel(t))[1, 9]
  (vel(t))[2, 9]
  ⋮
Parameters (13): see parameters(sys)
  aero_kite_moment_b [defaults to [0.0346076, 31.6809, 0.0212127]]
  moment_dist [defaults to [0.0350009, 0.039114, 0.0231755, 0.00880667, -0.0116542, -0.0407203, -0.0507423, -0.0718568, -0.13384, -0.0962616  …  -0.127448, -0.101881, -0.0725175, -0.0515411, -0.0423484, -0.0135425, 0.00530942, 0.0204002, 0.0382887, 0.0347389]]
  set_values(t) [defaults to [0.0, 0.0, 0.0]]
  aero_kite_force_b [defaults to [-10.4389, -0.126359, 108.13]]
  stiffness_frac [defaults to 0.1]
  ⋮
Observed (2496): see observed(sys)
julia> length(sys.eqs)
2687
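Note that 191 + 2496 = 2687, so the number reported by `length(sys.eqs)` appears to include the observed equations as well. To report sizes like these, or the "Simplifying 45 eqs to 15 eqs" lines later in the thread, one option is to compare `equations` of the unsimplified system with `equations` and `observed` of the simplified one. A minimal sketch on a hypothetical toy system (illustrative names, not the real model):

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

# Hypothetical toy system standing in for the real model
@variables x(t) v(t) ext(t) f(t)
eqs = [D(x) ~ v,
       ext ~ x - 1.0,
       f ~ -100.0 * ext - 2.0 * v,
       D(v) ~ f]
@named sys = ODESystem(eqs, t)
simp = structural_simplify(sys)

n_before   = length(equations(sys))    # equations handed to structural_simplify
n_after    = length(equations(simp))   # equations left for the solver
n_observed = length(observed(simp))    # equations turned into explicit assignments
println("Simplifying $n_before eqs to $n_after eqs ($n_observed observed)")
```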

Any comments from one of the MTK developers?

I ran Bart's two examples and got the following results:

Results of second run:

include("bench1.jl")
2.280316 seconds (22.52 M allocations: 923.465 MiB, 4.77% gc time)
Simplifying 45 eqs to 15 eqs

include("bench2.jl")
0.159549 seconds (1.31 M allocations: 53.773 MiB, 12.25% gc time)
Simplifying 117 eqs to 15 eqs

These are the times for executing structural_simplify() only, on a Ryzen 7950X CPU running Linux. Using 1 or 16 threads made no difference.

It is indeed strange that simplifying a system of 117 equations is much faster than simplifying a system of 45 equations. Perhaps caching plays a role here, because in both cases these are the timings of the second run?
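One way to separate compilation from the actual simplification work is to time a warm-up call and a later call in the same session. The sketch below, again on a hypothetical toy system, additionally builds a fresh system for the timed call via the illustrative helper `build_toy`, as an extra precaution against anything being reused from the first call's system object. It measures wall-clock time with `@elapsed`; BenchmarkTools would give more robust statistics.

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

# Hypothetical helper: builds a new, unsimplified toy system on each call
function build_toy()
    @variables x(t) v(t) ext(t) f(t)
    eqs = [D(x) ~ v,
           ext ~ x - 1.0,
           f ~ -100.0 * ext - 2.0 * v,
           D(v) ~ f]
    @named sys = ODESystem(eqs, t)
    return sys
end

# First call in a fresh session: includes compiling the MTK/Symbolics code paths
t_first = @elapsed structural_simplify(build_toy())

# Timed run on a freshly built system: mostly the simplification itself
sys = build_toy()
t_run = @elapsed structural_simplify(sys)

println("first call: ", round(t_first, digits = 3), " s, later call: ", round(t_run, digits = 3), " s")
```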

Some improvements are coming. Others will require the sequel to JuliaSimCompiler.


I thought I'd give JuliaSimCompiler.jl a try, but it fails to precompile. This seems to be a known bug that has not been fixed for months: Can't precompile in Julia v1.11.1 · Issue #6 · JuliaComputing/JuliaSimIssues · GitHub

Strangely enough, with Julia 1.10 precompilation also fails, but with a different error.

The sequel is a different repo.

Is the sequel publicly available? If so, where can I find it?

Not yet. We are in the middle of a transition: the older JuliaSimCompiler needs to stay on older LLVM/MTK versions if you're going to use it, but the new stuff currently needs an unreleased version of Julia. We're working through getting this going. There are also other performance improvements coming to Symbolics and MTK, so there's a lot of performance work going on. I don't want to give a date for any of it, because it's more like 40 separate things to do, each of which will be big or small depending on the kind of code, so it's impossible to tell when your code will do better.


Exciting, looking forward to trying these improvements!

Which versions of Julia and MTK are the latest ones that work with JuliaSimCompiler?