Motivation
I have a system of differential equations
\dot{U} = F(U,w)
where U\in \mathbb{R}^D are the dependent variables and w \in\mathbb{R}^J are parameters. Using Symbolics.jl and ModelingToolkit.jl, I can quickly build (in-place) Julia functions for the following functions of U and w:
F(U,w), \; \nabla_UF(U,w), \;\nabla_wF(U,w), and \nabla_w[\nabla_UF](U,w)
I then have data \{u^{(m)}\}_{m=1}^M, where M\in \mathbb{N} is large. I don’t want to assume anything about F a priori, but during an optimization routine I need to evaluate all four of these functions at every u^{(m)} for many different values of w. This is quickly becoming expensive…
Goal
I would like to speed up evaluating these functions further by automatically detecting a decomposition of the form
\nabla_UF(u^{(m)},w) = h(g(u^{(m)}), w),
where g depends only on U.
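If it helps narrow things down: in the special case where F is affine in w, which the Hindmarsh-Rose example below happens to be, write F(U,w) = F_0(U) + \sum_{j=1}^{J} w_j F_j(U). Differentiating gives
\nabla_U F(U,w) = \nabla_U F_0(U) + \sum_{j=1}^{J} w_j \, \partial_{w_j}[\nabla_U F](U), \qquad \partial_{w_j}[\nabla_U F](U) = \nabla_U F_j(U)
so the coefficients of the w_j do not depend on w. The non-constant entries of \nabla_U F_0 and of \nabla_w[\nabla_U F] (one of the four functions already generated) are therefore natural candidates for the components of g. For Hindmarsh-Rose, F_0 = 0 and those entries are exactly 3U_1^2 and 2U_1. This shortcut does not cover F that is nonlinear in w.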
Concrete Example (Hindmarsh Rose)
This example currently isn’t causing too much of a slowdown, but larger systems are very slow. Hopefully it helps ground the discussion:
\dot{U} = F(U,w) = \left[\begin{array}{l} {w}_1 U_2+w_2 U_1^3+w_3 U_1^2+w_4 U_3 \\
w_5+w_6 U_1^2+w_7 U_2 \\
w_8 U_1+w_9+w_{10} U_3
\end{array}\right]
\nabla_U F(U,w) = \left[\begin{array}{ccc} 3 U_1^2 w_2 + 2U_1w_3 & {w}_1 & w_4 \\
2U_1 w_6 & w_7 & 0 \\
w_8 & 0 & w_{10}
\end{array}\right]
I want to programmatically recognize that
g(U) = \left[\begin{array}{c}
3U_1^2\\ 2U_1
\end{array}\right]
Then we can precompute \{y^{(m)}\}_{m=1}^M := \{g(u^{(m)})\}_{m=1}^M once and, for each new w, quickly compute
\begin{align*}\nabla_UF(u^{(m)},w) &= h(g(u^{(m)}),w) = h(y^{(m)},w) \\
&= \left[\begin{array}{ccc} y^{(m)}_1 w_2 + y_2^{(m)}w_3 & {w}_1 & w_4 \\
y_2^{(m)} w_6 & w_7 & 0 \\
w_8 & 0 & w_{10}
\end{array}\right]
\end{align*}
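As a concrete check of this precompute-then-reuse idea, here is a minimal NumPy sketch for the Hindmarsh-Rose Jacobian, with g and h written out by hand from the matrices above. It assumes the data are stored as an (M, 3) array with one u^{(m)} per row; the names `g_hr`, `h_hr`, and `jacobian_direct` are hypothetical, and the random data is only a synthetic stand-in.

```python
import numpy as np

# Hypothetical helper names; g and h are transcribed by hand from the example.

def g_hr(u):
    """U-only factors: y1 = 3*U1^2, y2 = 2*U1.
    u has shape (M, 3), one data point per row; returns shape (M, 2)."""
    u1 = u[:, 0]
    return np.stack([3.0 * u1**2, 2.0 * u1], axis=1)

def h_hr(y, w):
    """Assemble all M Jacobians from precomputed y (M, 2) and w (length 10).
    Python indices are 0-based, so w[0] is w_1 and w[9] is w_10."""
    J = np.zeros((y.shape[0], 3, 3))
    J[:, 0, 0] = y[:, 0] * w[1] + y[:, 1] * w[2]
    J[:, 0, 1] = w[0]
    J[:, 0, 2] = w[3]
    J[:, 1, 0] = y[:, 1] * w[5]
    J[:, 1, 1] = w[6]
    J[:, 2, 0] = w[7]
    J[:, 2, 2] = w[9]
    return J

def jacobian_direct(u, w):
    """Reference: evaluate grad_U F(u^(m), w) from scratch at every data point."""
    u1 = u[:, 0]
    J = np.zeros((u.shape[0], 3, 3))
    J[:, 0, 0] = 3.0 * u1**2 * w[1] + 2.0 * u1 * w[2]
    J[:, 0, 1] = w[0]
    J[:, 0, 2] = w[3]
    J[:, 1, 0] = 2.0 * u1 * w[5]
    J[:, 1, 1] = w[6]
    J[:, 2, 0] = w[7]
    J[:, 2, 2] = w[9]
    return J

rng = np.random.default_rng(0)
u_data = rng.standard_normal((1000, 3))  # synthetic stand-in for {u^(m)}
Y = g_hr(u_data)                         # computed once, before the optimization loop

for _ in range(5):                       # each new w only needs an evaluation of h
    w = rng.standard_normal(10)
    assert np.allclose(h_hr(Y, w), jacobian_direct(u_data, w))
```

For a system this small the saving is modest; the payoff grows when g is expensive relative to h.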
Ideas on how to start
- There should be a way to isolate each subexpression that contains only U and numeric constants. This would need to account for which operators are (or aren’t) associative, commutative, etc. In the example this would give [3U_1^2, 2U_1, 2U_1].
- Then define a vector of the unique expressions. In the example: [3U_1^2, 2U_1].
- Substitute y_i for the corresponding expression in \nabla_UF to define h(y,w).
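The three steps above can be sketched in SymPy, the Python library that SymPy.jl wraps. This is a minimal sketch, not an existing SymPy routine: the helper `split_jacobian` is hypothetical, and it assumes that after `sp.expand` each term of a Jacobian entry factors into a U-only part times a w-dependent part. Instead of matching subexpressions syntactically, it groups terms by their w-dependent factor using `Expr.as_independent`, which avoids most of the associativity/commutativity bookkeeping because SymPy already stores sums and products in canonical form.

```python
import sympy as sp

def split_jacobian(J, U, w):
    """Hypothetical helper (not a SymPy API): write every entry of J as
    sum_k c_k(U) * m_k(w) and replace each distinct non-constant c_k(U)
    by a fresh symbol y_i. Returns (g, ys, H) such that substituting
    ys -> g in H gives back J. Terms that do not split into a U-only
    factor times a w factor (e.g. sin(U1*w1)) stay in H untouched."""
    g, ys, lookup = [], [], {}

    def placeholder(expr):
        # Reuse the same y_i for structurally identical U-only expressions.
        if expr not in lookup:
            y = sp.Symbol(f"y_{len(g) + 1}")
            lookup[expr] = y
            g.append(expr)
            ys.append(y)
        return lookup[expr]

    def rewrite(entry):
        groups = {}  # w-dependent factor m_k(w) -> accumulated U-only factor c_k(U)
        for term in sp.Add.make_args(sp.expand(entry)):
            # Split a product into (factor free of w, factor containing w).
            u_part, w_part = term.as_independent(*w, as_Add=False)
            groups[w_part] = groups.get(w_part, 0) + u_part
        out = sp.S.Zero
        for w_part, u_part in groups.items():
            if u_part.free_symbols & set(U):  # pure numbers such as 1 are kept
                u_part = placeholder(u_part)
            out += u_part * w_part
        return out

    H = J.applyfunc(rewrite)
    return g, ys, H

# Hindmarsh-Rose example from above (w[0] is w_1, ..., w[9] is w_10).
U = sp.symbols("U1:4")   # U1, U2, U3
w = sp.symbols("w1:11")  # w1, ..., w10
F = sp.Matrix([
    w[0]*U[1] + w[1]*U[0]**3 + w[2]*U[0]**2 + w[3]*U[2],
    w[4] + w[5]*U[0]**2 + w[6]*U[1],
    w[7]*U[0] + w[8] + w[9]*U[2],
])
J = F.jacobian(list(U))
g, ys, H = split_jacobian(J, U, w)
print(g)  # should contain 3*U1**2 and 2*U1 (order follows SymPy's term ordering)
print(H)  # the same Jacobian written in terms of y_1, y_2 and w only
```

Grouping by the w factor rather than by individual terms means that, for example, w_2 U_1 + w_2 U_2 produces a single y = U_1 + U_2 instead of two. The resulting g and H could then be turned into numeric functions (for instance with `sp.lambdify`) or ported back to Symbolics.jl.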
I am having a hard time getting started with the Symbolics API from what I can find in the documentation. Perhaps there is a way to do this in SymPy through SymPy.jl? Any help or starting points would be greatly appreciated!