Detecting Function Composition in Symbolics.jl

Motivation

I have a system of differential equations
\dot{U} = F(U,w)
where U\in \mathbb{R}^D are the dependent variables and w \in\mathbb{R}^J are the parameters. Using Symbolics.jl and ModelingToolkit.jl, I can quickly build (in-place) Julia functions for the following functions of U and w:
F(U,w), \; \nabla_UF(U,w), \; \nabla_wF(U,w), \; \text{and} \; \nabla_w[\nabla_UF](U,w)
I then have data \{u^{(m)}\}_{m=1}^M, where M\in \mathbb{N} is large. I don’t want to assume anything about F a priori, but I need to evaluate all four of these functions at every u^{(m)} for different w during an optimization routine. This is quickly becoming expensive…

Goal

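I would like to further speed up evaluating these functions by automatically detecting a decomposition
\nabla_UF(u^{(m)},w) = h(g(u^{(m)}), w)
where g depends only on U, so that g(u^{(m)}) can be computed once per data point instead of once for every new w.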

Concrete Example (Hindmarsh-Rose)

This example is not currently causing much of a slowdown, but larger systems are very slow. Hopefully it can help ground the discussion:
\dot{U} = F(U,w) = \left[\begin{array}{l} {w}_1 U_2+w_2 U_1^3+w_3 U_1^2+w_4 U_3 \\ w_5+w_6 U_1^2+w_7 U_2 \\ w_8 U_1+w_9+w_{10} U_3 \end{array}\right]
\nabla_U F(U,w) = \left[\begin{array}{ccc} 3 U_1^2 w_2 + 2U_1w_3 & {w}_1 & w_4 \\ 2U_1 w_6 & w_7 & 0 \\ w_8 & 0 & w_{10} \end{array}\right]
I want to programmatically recognize that
g(U) = \left[\begin{array}{c} 3U_1^2\\ 2U_1 \end{array}\right]
Then we can precompute \{y^{(m)}\}_{m=1}^M := \{g(u^{(m)})\}_{m=1}^M and then quickly compute
\begin{align*}\nabla_UF(u^{(m)},w) &= h(g(u^{(m)}),w) = h(y^{(m)},w) \\ &= \left[\begin{array}{ccc} y^{(m)}_1 w_2 + y^{(m)}_2 w_3 & w_1 & w_4 \\ y^{(m)}_2 w_6 & w_7 & 0 \\ w_8 & 0 & w_{10} \end{array}\right] \end{align*}
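Once g is known, the precompute-then-evaluate pattern for this example is plain Julia. A minimal sketch, assuming the data are stored as a vector of 3-vectors (random points stand in for real data here); the names `g`, `h!`, `jac_direct`, and `all_jacobians!` are hypothetical and only for illustration.

```julia
# g depends only on u: the two U-only subexpressions (3u₁², 2u₁) of ∇_U F.
g(u) = (3 * u[1]^2, 2 * u[1])

# In-place h!: fills the 3×3 Jacobian from a precomputed y = g(u) and parameters w.
function h!(J, y, w)
    J[1, 1] = y[1] * w[2] + y[2] * w[3]; J[1, 2] = w[1]; J[1, 3] = w[4]
    J[2, 1] = y[2] * w[6];               J[2, 2] = w[7]; J[2, 3] = 0
    J[3, 1] = w[8];                      J[3, 2] = 0;    J[3, 3] = w[10]
    return J
end

# Direct ∇_U F(u, w), used only to check the split version.
jac_direct(u, w) = [3u[1]^2*w[2] + 2u[1]*w[3]  w[1]  w[4];
                    2u[1]*w[6]                  w[7]  0.0;
                    w[8]                        0.0   w[10]]

# Hypothetical data set {u^(m)}: random points standing in for real data.
data = [randn(3) for _ in 1:1000]

# Precompute y^(m) = g(u^(m)) once; this does not depend on w.
Y = [g(u) for u in data]

# Each optimizer iteration with a new w only calls h!, never g.
function all_jacobians!(Js, Y, w)
    for (J, y) in zip(Js, Y)
        h!(J, y, w)
    end
    return Js
end

w = rand(10)
Js = [zeros(3, 3) for _ in data]
all_jacobians!(Js, Y, w)
```

Only `h!` runs inside the optimization loop, so the work done in g is paid once per data point rather than once per data point per w.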

Ideas on how to start

  • There should be a way to isolate each subexpression that contains only U and scalars. This requires knowing when operators are (or aren’t) associative, commutative, etc. In the example this would give [3U_1^2, 2U_1, 2U_1].
  • Then define a vector of the unique expressions. In the example: [3U_1^2, 2U_1].
  • Substitute y_i for the i-th unique expression in the original equation to define h(y,w).
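The three steps above can be sketched without any CAS by walking plain Julia `Expr` trees. This is a hypothetical, minimal illustration, not a Symbolics.jl API: it treats `*` as associative and commutative (so `3*U1^2*w2` splits into the U-only factor `3*U1^2` and `w2`), collects the unique U-only pieces into `exprs` (playing the role of g), and replaces each one by `y[i]` to get h. The helper names and the symbol names `U1` to `U3`, `w1` to `w10` are my own.

```julia
# Illustrative sketch on plain Julia Exprs (not Symbolics.jl terms).
const USYMS = Set([:U1, :U2, :U3])

# true if ex is built only from numbers and U symbols
is_u_only(ex) = ex isa Number || (ex isa Symbol && ex in USYMS) ||
    (ex isa Expr && ex.head == :call && all(is_u_only, ex.args[2:end]))

# true if ex mentions at least one U symbol (so it is not a bare constant)
has_u(ex) = (ex isa Symbol && ex in USYMS) ||
    (ex isa Expr && ex.head == :call && any(has_u, ex.args[2:end]))

# Replace each maximal U-only subexpression by y[i], recording unique ones in exprs.
function extract!(ex, table, exprs)
    if is_u_only(ex) && has_u(ex)
        i = get!(table, ex) do          # reuse the index of an identical expression
            push!(exprs, ex)
            length(exprs)
        end
        return :(y[$i])
    elseif ex isa Expr && ex.head == :call && ex.args[1] == :*
        # * is associative and commutative, so its U-only factors can be grouped.
        factors = ex.args[2:end]
        ufac, rest = filter(is_u_only, factors), filter(!is_u_only, factors)
        if any(has_u, ufac)
            grouped = length(ufac) == 1 ? ufac[1] : Expr(:call, :*, ufac...)
            return Expr(:call, :*, extract!(grouped, table, exprs),
                        (extract!(r, table, exprs) for r in rest)...)
        end
    end
    ex isa Expr || return ex            # symbols such as w2 or :+ stay as they are
    return Expr(ex.head, (extract!(a, table, exprs) for a in ex.args)...)
end

# ∇_U F for Hindmarsh-Rose, entries listed row by row.
entries = Any[:(3*U1^2*w2 + 2*U1*w3), :w1, :w4,
              :(2*U1*w6),             :w7, 0,
              :w8,                    0,   :w10]
Jex = permutedims(reshape(entries, 3, 3))

table, exprs = Dict{Any,Int}(), Any[]
Hex = map(e -> extract!(e, table, exprs), Jex)   # entries of h(y, w) as Exprs
println(exprs)                                    # unique U-only pieces, i.e. g(U)

# Build g and h as ordinary Julia functions from the Exprs.
gfun = eval(:((U1, U2, U3) -> [$(exprs...)]))
rows = [Expr(:row, Hex[i, :]...) for i in 1:3]
hfun = eval(:((y, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10) -> $(Expr(:vcat, rows...))))
```

This only groups factors of `*`: U-only terms inside a `+` are extracted one at a time rather than summed, and nothing is expanded or factored, so a general common-subexpression pass is more robust on larger systems.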

I am having a hard time getting started with the Symbolics API from what I can see in the documentation. Perhaps there is a way to do this in SymPy through SymPy.jl? Any help or starting points would be much appreciated!

This is just common subexpression elimination?

@ChrisRackauckas I think so! This is readily available in pysym (link goes to relevant documentation). It looks like others have used SymbolicUtils.jl and the (unmaintained) SymbolicCodegen.jl to do CSE; an example can be found in this show and tell.

My code already has symbolic expressions from Symbolics.jl. Is there a way to use the high-level module to:

  1. Find the common subexpressions
  2. Build Julia functions of each subexpression
  3. Substitute the expressions back into the original
  4. Build a function out of the final expression

Or should I bring in SymPy.jl?

If you use build_function, then you just add cse=true. We don’t do it by default, though, because according to @shashi there’s an unknown bug IIRC, but I don’t know what it is or how to trigger it. If you can help figure out what it is, we can make it the default (it’s a standard codegen optimization; even Julia will run CSE on the generated code).
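A minimal sketch of that suggestion for the Hindmarsh-Rose system, assuming a recent Symbolics.jl in which `build_function` accepts the `cse` and `expression` keywords (check against the installed version, given the caveat above about why `cse` is not on by default). The variable names `JU`, `jac_oop`, and `jac_ip!` are mine.

```julia
using Symbolics

# Symbolic state U (D = 3) and parameters w (J = 10) for Hindmarsh-Rose.
@variables U[1:3] w[1:10]
U, w = collect(U), collect(w)   # plain vectors of scalar symbolic variables

F = [w[1]*U[2] + w[2]*U[1]^3 + w[3]*U[1]^2 + w[4]*U[3],
     w[5] + w[6]*U[1]^2 + w[7]*U[2],
     w[8]*U[1] + w[9] + w[10]*U[3]]

JU = Symbolics.jacobian(F, U)   # ∇_U F as a 3×3 matrix of symbolic expressions

# cse = true asks the code generator to compute repeated subexpressions once
# per call; expression = Val{false} returns compiled (out-of-place, in-place) functions.
jac_oop, jac_ip! = build_function(JU, U, w; cse = true, expression = Val{false})

# Usage: fill a preallocated matrix at one data point u and parameter vector p.
u, p = [0.5, -1.0, 2.0], collect(1.0:10.0)
Jout = zeros(3, 3)
jac_ip!(Jout, u, p)
```

Leaving out `expression = Val{false}` returns the generated Julia `Expr`s instead, which is a quick way to inspect what CSE shared. Note that CSE works inside a single generated call: it avoids recomputing a repeated term such as `U[1]` powers within one evaluation, but it does not by itself cache g(u^{(m)}) across different w, which still needs the g/h split from the question.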
