How to search the AST for specific types and functions?


#1

Motivation

I am writing a simple differential equation solver using the finite element method. While I have done this in the past, I often find that there is a big chasm between the mathematical form and the implemented code, specifically in the “assembly” of the linear system of equations. I wind up with one piece of code for the Laplace problem, another for linear elasticity, etc. Of course, large sections of code are reused, but I am now looking to bridge that “last mile” of connectivity. A large part of the inspiration for this comes from the FEniCS project (fenicsproject.org).

I can illustrate this with an example. The math is simple enough, but it is not essential to follow it in detail. In the parlance of the finite element method, say I am looking for a function u that satisfies the following integral equation for all “test” functions v:

\int \nabla v \cdot \nabla u \; dV = 0

If we approximate u and v as linear combinations of a finite set of basis functions, we can solve for the coefficients of this basis so as to satisfy the above equation.

So the equation above specifies a sort of “recipe” for assembling the linear system of equations.
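To make the “recipe” concrete, here is the standard Galerkin step written out, assuming u is expanded in basis functions φ_J and v is taken to be each basis function φ_I in turn (the symbols φ_J, u_J and K_{IJ} are introduced here only for illustration):

```latex
u \approx \sum_J u_J \,\varphi_J, \qquad v = \varphi_I
\quad\Longrightarrow\quad
\sum_J \underbrace{\left( \int \nabla \varphi_I \cdot \nabla \varphi_J \; dV \right)}_{K_{IJ}} u_J = 0
\qquad \text{for every } I
```

In index notation the integrand is \partial_k \varphi_I \, \partial_k \varphi_J with k summed, which is exactly the dot product that the repeated index in the proposed @assemble syntax encodes.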

Going from math to code

I may be wrong, but I feel like this is a job that can be achieved via metaprogramming. I have defined types for u and v (in the parlance of the finite element method, TrialFunction and TestFunction, respectively). I have also defined the action of operators like \nabla in terms of how they manipulate my types.

I also have the basic infrastructure for solving such a problem in place. That is, given a mesh discretization, I can compute everything necessary, such as my basis functions and their derivatives, Gaussian quadrature, etc. I know the exact form of a Julia expression that I can evaluate on one “element” of my mesh to obtain its contribution to the linear system.

Now I want to write a (hopefully?) simple macro that can take my math expression and turn it into a corresponding Julia expression.

Ideally, I would like the macro to behave something like this for the example above:

@assemble ∇(v)[k]∇(u)[k]

would give me

K[I,J] = basis[I][k]*basis[J][k]

(The repeated index k in the @assemble statement just tells me that the operation is a dot product; this is similar to the @tensor macro from TensorOperations.)
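Spelled out with explicit loops, the repeated index in K[I,J] = basis[I][k]*basis[J][k] is just a sum over k. A minimal sketch, assuming basis[I] holds the (constant) gradient of basis function I for linear shape functions on the reference triangle; the data are made up for illustration, and area and quadrature weights are omitted, just as in the target expression:

```julia
# Assumed example data: gradients of the linear shape functions on the
# reference triangle, φ1 = 1 - x - y, φ2 = x, φ3 = y.
basis = [[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]]

n = length(basis)
K = zeros(n, n)
for I in 1:n, J in 1:n
    # K[I,J] = basis[I][k] * basis[J][k], with the repeated index k summed
    K[I, J] = sum(basis[I][k] * basis[J][k] for k in eachindex(basis[I]))
end

K   # [2.0 -1.0 -1.0; -1.0 1.0 0.0; -1.0 0.0 1.0]
```

For constant gradients this still has to be multiplied by the element area (1/2 on the reference triangle) before being added into the global matrix.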

Now, after all this long-winded chit-chat, here are some more specific questions:

Questions

  • Is a macro the way to go for something like this? If yes…
  • Linearity: Is there a way that I can convert :(u(v + w)) into :(uv + uw)? This would allow me to write more general equations as input without having to simplify them by hand first. From the mathematical side, this is also essential, so it would also act as a good check that the input is valid. As a simple example, the \nabla operator is linear, so if I encounter a call to + inside a call to ∇, then I can switch their order. How can I do such a thing in my AST?
  • My variables u and v from above have specific types. The fields of these types tell me, for example, what order of derivatives I should use in my calculations. What is the most effective way of searching through the AST of my expression above to find instances of my types?

Apologies if my question is a little too long-winded, but I thought some context might help :slight_smile:


#2

Not a macro specifically, but working with expressions themselves seems appropriate in this case.

You may find Espresso.jl helpful in this case:

using Espresso

pattern_src = :(_f(_a + _b))        # rewrite expressions similar to this
pattern_dst = :(_f(_a) + _f(_b))   # to look like this
rewrite(:(u(v + w)), pattern_src, pattern_dst)
# ==> :(u(v) + u(w))

(Use rewrite_all() to apply the rule to all sub-expressions recursively; Espresso.jl also provides other expression utilities.)
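Without a package, the same kind of recursive rewrite can be sketched in plain Julia. This is a minimal sketch, assuming the linear operators appear as plain calls such as ∇(...) or div(...); the function name distribute and the LINEAR_OPS list are hypothetical and not part of Espresso.jl:

```julia
# Hypothetical list of operators treated as linear: f(a + b) == f(a) + f(b).
# These are just symbols in the AST; :div here is not Base.div.
const LINEAR_OPS = (:∇, :div)

# Recursively push linear operators through sums: ∇(a + b + c) -> ∇(a) + ∇(b) + ∇(c)
function distribute(ex)
    ex isa Expr || return ex                  # symbols and literals stay as they are
    args = map(distribute, ex.args)           # rewrite sub-expressions first (bottom-up)
    if ex.head == :call && length(args) == 2 && args[1] in LINEAR_OPS
        op, inner = args
        if inner isa Expr && inner.head == :call && inner.args[1] == :+
            # apply the operator to every term of the sum, then re-sum
            return Expr(:call, :+, (distribute(Expr(:call, op, t)) for t in inner.args[2:end])...)
        end
    end
    return Expr(ex.head, args...)
end

distribute(:(∇(u + v)))                 # :(∇(u) + ∇(v))
distribute(:(dot(∇(u + v + w), ∇(z))))  # :(dot(∇(u) + ∇(v) + ∇(w), ∇(z)))
```

Calls to functions not listed in LINEAR_OPS are left untouched, so the list doubles as a record of which operators the input language treats as linear.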

Type information, however, goes beyond purely symbolic manipulation. In Espresso.jl I used expression graphs for this kind of thing:

ex = :(v * w + (log.(v) + exp.(w)))
g = ExGraph(ex; v=rand(3,3), w=rand(3,3));  # v and w are "example values" used to calculate types of all vars in graph
evaluate!(g)                                 # evaluate graph so that each node gets a value
println(g)
# ==>
# ExGraph
#   ExNode{input}(v = v | <Array{Float64,2}>)
#   ExNode{input}(w = w | <Array{Float64,2}>)
#   ExNode{call}(tmp394 = v * w | <Array{Float64,2}>)
#   ExNode{bcast}(tmp395 = log.(v) | <Array{Float64,2}>)
#   ExNode{bcast}(tmp396 = exp.(w) | <Array{Float64,2}>)
#   ExNode{call}(tmp397 = tmp395 + tmp396 | <Array{Float64,2}>)
#   ExNode{call}(tmp398 = tmp394 + tmp397 | <Array{Float64,2}>)

Not sure if it works with your TrialFunction and TestFunction types, but you can try it.

You can then iterate over nodes to find variables of the type in question:

for nd in g
    println(typeof(getvalue(nd)))
end

The expression x[k] * y[k] looks more like a scalar product of vectors; even given type information, it’s unclear whether it should be translated to:

K[i, j] = x[i, k] * y[k, j]    # matrix multiplication
K[i, j] = x[i, k] * y[j, k]
K[i, j] = x[k, i] * y[k, j]
K[i] = x[k] * y[k]            # summation over columns/rows

Usually, the exact meaning of an operation in Einstein notation is defined by both the right-hand and the left-hand side of an expression. This also makes it extremely hard to compose expressions with more than one operation on the RHS.
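A minimal illustration of why the left-hand side matters (x and y are arbitrary example vectors): the right-hand side x[k] * y[k] is identical in both lines below, and only whether k also appears on the left decides between a sum and an elementwise product.

```julia
x = [1, 2, 3]
y = [4, 5, 6]

# s = x[k] * y[k]      k appears only on the right, so it is summed (dot product)
s = sum(x[k] * y[k] for k in eachindex(x, y))

# z[k] = x[k] * y[k]   k also appears on the left, so it is free (elementwise product)
z = [x[k] * y[k] for k in eachindex(x, y)]

s   # 32
z   # [4, 10, 18]
```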

Espresso.jl used to work with Einstein notation in versions up to v3.0.0 (search for EinGraph). Also Einsum.jl might be helpful, though both options are currently outdated and unmaintained.


#3

Thanks for your reply.

Yes, that does seem to do what I want! I will check it out with a few examples.

This was part of the reason why I decided to define types for TrialFunction and TestFunction. After all the calculus operations like \nabla and div have been applied, the fields of my TrialFunction tell me the tensor rank (and how many derivatives of my basis functions to take). Moreover, I figured it would be reasonable to require the input equation to be written in exact summation notation, so that there can be no ambiguity.

So if the user input has:

@assemble ∇(v)[k]∇(u)[k]

the user is implying that \nabla(v) is rank 1, and my code can ideally verify this by studying typeof(∇(v)).
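If the rank is part of the type itself, typeof(∇(v)) carries it and the check becomes a comparison on the type. A minimal sketch under that assumption; Argument, tensorrank and checkrank are hypothetical names, not the poster's actual TrialFunction/TestFunction definitions:

```julia
# Hypothetical form argument whose tensor rank R is a type parameter.
struct Argument{R}
    name::Symbol
end

Argument(name::Symbol) = Argument{0}(name)   # a plain u or v is a scalar (rank 0)

# Each ∇ raises the rank by one, so the new rank shows up in typeof(∇(v)).
∇(a::Argument{R}) where {R} = Argument{R + 1}(a.name)

tensorrank(::Argument{R}) where {R} = R

# Verify that a factor such as ∇(v)[k] has the rank its indices imply.
function checkrank(a::Argument, expected::Integer)
    tensorrank(a) == expected ||
        throw(ArgumentError("$(a.name) has rank $(tensorrank(a)), expected $expected"))
    return true
end

v = Argument(:v)
typeof(∇(v))        # Argument{1}
checkrank(∇(v), 1)  # true
```

Putting the rank in the type (rather than in a field) also lets later code dispatch on it, e.g. a method for Argument{1} that differs from one for Argument{2}.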

Also, I figured that I wouldn’t bother trying to actually convert the index notation into code, but rather use the @tensor macro from the TensorOperations package for that purpose.

Which also raises another question: Can a macro return a macro?

Thanks for your time!


#4

This is not possible in a macro. Macros only know how the expression is “spelled”, because they execute right after parsing — they don’t have any idea of what values symbols are bound to or what types they have.

If you need type information, then you need something that executes at runtime (or, in rare cases, at compile time via generated functions).
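A small illustration of the generated-function route (array_rank is a hypothetical name): inside an @generated body, the argument names are bound to the argument types, so the body can inspect types and return an expression, but it never sees the runtime values.

```julia
# In the body, `x` is the argument's type (e.g. Matrix{Float64}), not its value.
@generated function array_rank(x)
    N = x <: AbstractArray ? ndims(x) : 0   # decided once per argument type
    return :($N)                            # the expression that actually runs
end

array_rank(1.0)          # 0
array_rank([1, 2, 3])    # 1
array_rank(zeros(2, 2))  # 2
```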

I don’t think you want macros at all, here. I think you maybe want a function like assemble(∇(v)*∇(u)), where ∇(v) and ∇(u) construct symbolic objects from test-function objects v and u, * multiplies them to construct another symbolic object, and assemble takes this symbolic structure and runs some computation on it to assemble your matrix.
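A minimal sketch of that design, with every name hypothetical (TestFn, TrialFn, Grad, Product and this assemble are illustrations, not an existing API): ∇ and * only build a small typed tree, and assemble picks the computation by dispatching on the tree's type.

```julia
# Hypothetical symbolic types: they record *what* to compute, not values.
struct TestFn end
struct TrialFn end

struct Grad{T}
    arg::T
end

struct Product{A,B}
    left::A
    right::B
end

∇(f::Union{TestFn,TrialFn}) = Grad(f)
Base.:*(a::Grad, b::Grad) = Product(a, b)

# Dispatch on the *type* of the form: ∇(test) * ∇(trial) means a gradient contraction.
# `grads[I]` is assumed to hold the gradient vector of basis function I.
function assemble(::Product{Grad{TestFn},Grad{TrialFn}}, grads)
    n = length(grads)
    return [sum(grads[I][k] * grads[J][k] for k in eachindex(grads[I])) for I in 1:n, J in 1:n]
end

v, u = TestFn(), TrialFn()
K = assemble(∇(v) * ∇(u), [[-1.0], [1.0]])   # 1D linear element of unit length
```

Other forms would get their own assemble methods; a form with no matching method fails with a MethodError, which also gives a basic validity check on the input.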

(@assemble ∇(v)[k]∇(u)[k] would never work in any case because ∇(v)[k]∇(u)[k] is not a valid parseable expression.)


#5

This is key. Macros transform syntax into other syntax. You can, however, have a macro generate an expression that contains a call to an overloaded function, so that the types of the values are involved in dispatch.
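As a small illustration of that idea (all names here are hypothetical): the macro below only rearranges syntax, turning @kinds f(a, b) into (kind(a), kind(b)), and the types are consulted later, when the generated calls to the overloaded function kind are dispatched.

```julia
# An ordinary overloaded function: dispatch on the runtime type of its argument.
kind(::Integer) = :integer
kind(::AbstractFloat) = :float
kind(::Any) = :other

# The macro never sees values, only syntax: it rewrites `@kinds f(x, y, ...)`
# into the tuple expression `(kind(x), kind(y), ...)`.
macro kinds(ex)
    (ex isa Expr && ex.head == :call) || error("@kinds expects a call such as f(a, b)")
    return Expr(:tuple, (:(kind($(esc(arg)))) for arg in ex.args[2:end])...)
end

a, b = 1, 2.0
@kinds f(a, b)               # (:integer, :float); `f` itself is never called
@macroexpand @kinds f(a, b)  # shows the generated calls to `kind`
```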

You could circumvent this issue using a string macro, https://docs.julialang.org/en/v1/manual/metaprogramming/#Non-Standard-String-Literals-1. Not saying that’s necessarily a good idea, but maybe it’s an option.
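A minimal sketch of that option, assuming factors are always written as ∇(name)[index] (the macro name form_str and the fields factors and contracted are hypothetical): a string macro receives the raw text, so ∇(v)[k]∇(u)[k] never has to pass through Julia's parser.

```julia
# Hypothetical string macro: pull ∇(name)[index] factors out of raw text.
macro form_str(s)
    factors = [(Symbol(m[1]), Symbol(m[2])) for m in eachmatch(r"∇\((\w+)\)\[(\w+)\]", s)]
    isempty(factors) && error("no ∇(...)[...] factors found in \"$s\"")
    indices = last.(factors)
    # an index that appears exactly twice is contracted (Einstein convention)
    contracted = unique(i for i in indices if count(==(i), indices) == 2)
    return (factors = factors, contracted = contracted)
end

form"∇(v)[k]∇(u)[k]"   # (factors = [(:v, :k), (:u, :k)], contracted = [:k])
```

Instead of returning the parsed pieces, such a macro could go on to build an ordinary Julia expression (for example a call to an assembly function); the regex here is only a stand-in for a proper parser.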

If you’re feeling adventurous, you could also try experimenting with https://github.com/jrevels/Cassette.jl. Cassette allows you to essentially intercept and modify function calls that match specific signatures in a given user-defined context (in your case, say, the ‘assembly’ context). Again, not sure if it’s the right fit for this application though.


#6

I see. So a macro is a good idea if I am purely going to manipulate syntax. I suppose I could approach the problem from that angle too; I would just need to attach very specific meanings to the symbols I use in my expressions. However, I reckon an approach that relies on dispatch on my types would be the more direct way for me.

Thanks for your inputs!


#7

Here is an example of a recursive function I made to search for and substitute certain types of numerical values and convert them to a specified type:

It can be used like this

julia> SyntaxTree.sub(Float64,:(2x^2-1/2))
:(2.0 * x ^ 2 - 1.0 / 2.0)

You could do something similar for your requirements: I would just write a custom recursive function that walks through your expression, checks for your relevant types, and does whatever you need to do with them.
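A minimal sketch of such a walker, assuming the values of interest are leaves of the expression (literals, or objects interpolated into it with $). The name mapleaves is hypothetical, and this is not the SyntaxTree.jl implementation:

```julia
# Rebuild `ex`, applying `f` to every leaf of type `T`; all other leaves are kept.
function mapleaves(f, ::Type{T}, ex) where {T}
    if ex isa Expr
        return Expr(ex.head, (mapleaves(f, T, a) for a in ex.args)...)
    end
    return ex isa T ? f(ex) : ex
end

# Search: collect the Int literals in depth-first order
found = Int[]
mapleaves(x -> (push!(found, x); x), Int, :(2x^2 - 1/2))
found   # [2, 2, 1, 2]

# Substitute: turn every Int literal into a Float64
mapleaves(float, Int, :(2x^2 - 1/2))
```

Objects interpolated with $, as in :(∇($v)), are leaves too, so the same function can collect or replace TestFunction and TrialFunction values. Note that with T = Symbol it would also visit function names such as :- and :^.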


#8

This is really useful, especially the example recursive function. Thanks!