ModelingToolkit: using intermediate calculation results multiple times

Let’s say I have an expensive scalar that shows up multiple times in my set of differential equations. Is there a way to tell ModelingToolkit to hoist that calculation out and do it only once? If I use an intermediate calculation, it just seems to get copied everywhere it is used. For example, using cosine as a dummy expensive calculation:

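```julia
using ModelingToolkit   # uses the ModelingToolkit API of the time (`@derivatives`)

@parameters ω, t
@variables ψ₁(t), ψ₂(t)
@derivatives D'~t

expensive_f = cos(ω*t)   # cosine stands in for an expensive calculation

eqs = [D(ψ₁) ~ expensive_f*ψ₂,
       D(ψ₂) ~ expensive_f*ψ₁]

de = ODESystem(eqs)
generate_function(de)[2]   # [2] is the in-place version
```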

```
:((var"##MTIIPVar#335", var"##MTKArg#331", var"##MTKArg#332", var"##MTKArg#333")->begin
          @inbounds begin
                  let (ψ₁, ψ₂, ω, t) = (var"##MTKArg#331"[1], var"##MTKArg#331"[2], var"##MTKArg#332"[1], var"##MTKArg#333")
                      var"##MTIIPVar#335"[1] = cos(ω * t) * ψ₂
                      var"##MTIIPVar#335"[2] = cos(ω * t) * ψ₁
                  end
              end
          nothing
      end)
```

Looking through the `@code_llvm` output there are still two `cos` calls, but in the `@code_native` output there is only one. Perhaps the compiler is always clever enough to pick this up…

That’s common subexpression elimination (CSE), and indeed the compiler knows to do this only once.
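For reference, manual CSE here just means computing the shared term once into a local variable and reusing it. A minimal hand-written sketch of the in-place right-hand side from the first example, assuming the same argument order as the generated function (`du`, `u`, parameters, `t`); `rhs_hoisted!` is a hypothetical name, not something ModelingToolkit generates:

```julia
# Hypothetical hand-written in-place RHS for
#   D(ψ₁) ~ cos(ω*t)*ψ₂,  D(ψ₂) ~ cos(ω*t)*ψ₁
# with the shared (expensive) term hoisted into a local, i.e. manual CSE.
function rhs_hoisted!(du, u, p, t)
    ψ₁, ψ₂ = u[1], u[2]
    ω = p[1]
    c = cos(ω * t)      # evaluated once, reused by both equations
    du[1] = c * ψ₂
    du[2] = c * ψ₁
    return nothing
end

du = zeros(2)
rhs_hoisted!(du, [1.0, 2.0], [0.0], 1.0)   # ω = 0, so cos(ω*t) == 1
du   # [2.0, 1.0]
```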

BTW, we will have a way to force CSE in the near future (via “output variables”).

That sounds promising. Poking around, though, it seems that if `expensive_f` is changed to a complex exponential, the compiler doesn’t do CSE and calls `exp` twice.
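A dependency-free way to look at this outside ModelingToolkit might be the sketch below; `drive_twice` and `drive_once` are hypothetical names. Both return the same values, so comparing `@code_llvm drive_twice(1.0, 2.0, 3.0, 4.0)` with `@code_llvm drive_once(1.0, 2.0, 3.0, 4.0)` shows whether the duplicated complex-exponential work survived compilation.

```julia
using InteractiveUtils   # provides @code_llvm / @code_native outside the REPL

# Complex exponential written twice: relies on the compiler to remove the duplicate.
drive_twice(ω, t, a, b) = (exp(im * ω * t) * a, exp(im * ω * t) * b)

# Same computation with the common subexpression hoisted by hand.
function drive_once(ω, t, a, b)
    e = exp(im * ω * t)
    return (e * a, e * b)
end

@assert drive_twice(1.0, 2.0, 3.0, 4.0) == drive_once(1.0, 2.0, 3.0, 4.0)

# Inspect the generated IR to count the calls that remain:
# @code_llvm drive_twice(1.0, 2.0, 3.0, 4.0)
```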

Could you give an MWE? We can use that to figure out how to get it fixed in the compiler.

Of course. Here’s a notebook with some of the expressions printed out, but the gap can be summarized with:

```julia
using ModelingToolkit, BenchmarkTools
using MacroTools: postwalk

# repeat with 100 terms to see a much bigger gap
N = 100
@parameters ω, t
@variables ψ[1:N](t)
@derivatives D'~t

drive_amplitude = cos(ω*t)

eqs = D.(ψ) .~ drive_amplitude.*ψ

de = ODESystem(eqs)

println("Without manual CSE")
u = collect(range(0,1; length=N)); du = similar(u);
f = eval(generate_function(de)[2])   # in-place RHS exactly as generated
@btime f($du, $u, 5e9, 2.0)          # note: ω = 5e9 here but 2π*5 below

println("\n With manual CSE")
ex = generate_function(de)[2]
# replace every cos(ω*t) with a local variable...
ex = postwalk(x -> x == :(cos(ω*t)) ? :(drive_amplitude) : x, ex)
# ...and compute it once up front (hard-coded path into this particular generated Expr)
pushfirst!(ex.args[2].args[1].args[3].args[1].args[2].args, :(drive_amplitude = cos(ω*t)))
f = eval(ex)
@btime f($du, $u, 2π*5, 2.0)
```

```
Without manual CSE
  3.385 μs (0 allocations: 0 bytes)

 With manual CSE
  80.211 ns (0 allocations: 0 bytes)
```
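The `pushfirst!` step above relies on a hard-coded `args` path into the generated `Expr`, which breaks if the shape of the generated code changes. A sketch of a less fragile version of the same manual CSE, assuming the generated code binds its inputs in a `let` block as in the output shown earlier. `replace_subexpr`, `count_subexpr` and `hoist_into_let` are hypothetical helpers written with Base only (the replacement step does what `MacroTools.postwalk` does in the benchmark), not part of ModelingToolkit:

```julia
# Replace every occurrence of `target` in `ex` with the symbol `sym`.
replace_subexpr(ex, target, sym) =
    ex == target ? sym :
    ex isa Expr ? Expr(ex.head, map(a -> replace_subexpr(a, target, sym), ex.args)...) : ex

# Count occurrences of `target` in `ex`.
count_subexpr(ex, target) =
    ex == target ? 1 :
    ex isa Expr ? sum((count_subexpr(a, target) for a in ex.args); init = 0) : 0

# Find each `let` block and compute `target` once at the top of its body,
# so the hoisted assignment can see the variables bound by the `let`.
function hoist_into_let(ex, target, sym::Symbol)
    ex isa Expr || return ex
    if ex.head === :let
        bindings, body = ex.args
        newbody = Expr(:block, :($sym = $target), replace_subexpr(body, target, sym))
        return Expr(:let, bindings, newbody)
    end
    return Expr(ex.head, map(a -> hoist_into_let(a, target, sym), ex.args)...)
end

# A small stand-in for the generated function, with the same `let` structure.
ex = :((du, u, p, t) -> let (ψ₁, ψ₂, ω) = (u[1], u[2], p[1])
           du[1] = cos(ω * t) * ψ₂
           du[2] = cos(ω * t) * ψ₁
       end)

target = :(cos(ω * t))
h = hoist_into_let(ex, target, :c)   # cos(ω * t) now appears once, as `c = cos(ω * t)`
f = eval(h)
du = zeros(2)
Base.invokelatest(f, du, [1.0, 2.0], [0.0], 1.0)   # avoids world-age issues after eval
```

Searching for the `let` keeps the hoisted assignment inside the scope where `ω` and `t` are bound; only occurrences inside that `let` body are replaced.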

I think it’s because the compiler cannot prove that `cos` is pure? @sdanisch might recall the issue.

I guess it’s not about figuring out the purity (which seems to work) but about actually forwarding that information to the LICM (loop-invariant code motion) optimization:
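For context, loop-invariant code motion moves a computation whose inputs don’t change between iterations out of the loop, and the compiler can only do that for a call like `cos` if it knows the call is pure (no side effects). A minimal sketch of the transformation done by hand; `scale_naive!` and `scale_hoisted!` are hypothetical names:

```julia
# Loop-invariant call left inside the loop: hoisting it is up to the compiler,
# which needs to know that cos(ω * t) is pure to do so.
function scale_naive!(du, u, ω, t)
    for i in eachindex(u, du)
        du[i] = cos(ω * t) * u[i]
    end
    return du
end

# The same loop after manual LICM: the invariant call is computed once.
function scale_hoisted!(du, u, ω, t)
    c = cos(ω * t)
    for i in eachindex(u, du)
        du[i] = c * u[i]
    end
    return du
end

u = collect(range(0, 1; length = 100))
@assert scale_naive!(similar(u), u, 2π * 5, 2.0) == scale_hoisted!(similar(u), u, 2π * 5, 2.0)
```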

I guess we have a PR for this?