# Simplify an Expr

I am writing a code generator and want to add a routine that simplifies the generated code.
Input and output are both of type `Expr`.
Each entry `A[i]` and `b[i]` has one of the sparsity patterns `:null`, `:plusone`, `:minusone`, `:other`.

Example:
`expr = :(A[1] = A[1] + A[2] * b[3])`
Depending on actual values, `expr` can be simplified to

``````
:(A[1] = A[1] + A[2] * b[3])  # if all have value :other
:(A[1] = A[1] + b[3])         # if A[2] == :plusone
...
nothing::Nothing              # if A[2] * b[3] == 0
``````
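For readers less familiar with `Expr`, here is a small sketch of how the example expression is laid out, since any simplification rule has to walk this structure. The variable names (`lhs`, `rhs`, `op`, `term1`, `term2`, `factor`) are only illustrative and not part of any package.

```julia
expr = :(A[1] = A[1] + A[2] * b[3])

# An assignment has head :(=) and two args: left-hand side and right-hand side.
lhs, rhs = expr.args

# The right-hand side is a :call whose first arg is the operator.
op, term1, term2 = rhs.args

# The product A[2] * b[3] is again a :call; its factors are :ref expressions
# holding the array name and the index.
factor = term2.args[2]
```

So a rule such as "drop a factor whose pattern is `:plusone`" amounts to inspecting `term2.args` and returning the remaining factor.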

kmsquire/Match.jl (Advanced Pattern Matching for Julia, github.com) looks promising for simple, fixed-form expressions whose rules can be worked out by hand.
MasonProtter/Rewrite.jl (Term rewriting in Julia, github.com) appears more general, but it works with symbolic expressions, so it might serve as an intermediate form.
Other ideas?

I think Metatheory.jl can do what you need, and it is probably what you are looking for. Otherwise, if you prefer an approach similar to Rewrite.jl, there is the SymbolicUtils package, which works with symbolic expressions. Finally, Symbolics.jl supports symbolic arrays; I am not sure whether that includes sparsity patterns at the moment, but you could check.

So depending on how “low-level” you want to go with simplifications, it’s probably one of these three packages.

Espresso.jl might also be useful for you, especially rewrite.jl.

Before calling in the cavalry, I took a closer look and ended up with a few rules that work directly on the expression.

``````
"""
Simplification of elementary expressions in sparse linear solver loops
"""
simple(e) = error("not expected type: $(typeof(e))")
simple(e::Number) = e
function simple(e::Symbol)
    ev = eval(e)                    # look up the symbol's current global value
    if     ev ==  0 return  0
    elseif ev ==  1 return  1
    elseif ev == -1 return -1
    else            return  e
    end
end
function simple(e::Expr)
    if e.head == :(=)                           # assignment: lhs = rhs
        if eval(e.args[1]) ≈ eval(e.args[2])    # (redundant operation)
            return nothing::Nothing
        else
            e.args[2] = simple(e.args[2])
        end
    else                                        # binary call: op(arg2, arg3)
        e.args[2] = simple(e.args[2])
        e.args[3] = simple(e.args[3])
        ev2 = eval(e.args[2])
        ev3 = eval(e.args[3])
        if e.args[1] == :*
            if     ev2 == 1                 # 1 * a -> a
                e = e.args[3]
            elseif ev3 == 1                 # a * 1 -> a
                e = e.args[2]
            elseif ev2 == -1                # -1 * a -> -a
                e = Expr(:call, :-, e.args[3])
            elseif ev3 == -1                # a * -1 -> -a
                e = Expr(:call, :-, e.args[2])
            end
        elseif e.args[1] == :/
            if ev3 == 1                     # a / 1 -> a
                e = e.args[2]
            elseif ev3 == -1                # a / -1 -> -a
                e = Expr(:call, :-, e.args[2])
            end
        elseif e.args[1] == :+
            if     ev2 == 0                 # 0 + a -> a
                e = e.args[3]
            elseif ev3 == 0                 # a + 0 -> a
                e = e.args[2]
            end
        end
    end
    return e
end

# Tests
a = π
b = sqrt(2)
a0 = 0
ap1 = 1
am1 = -1
simple(:(a = a + (a * a0 * b)/a))   # nothing::Nothing
simple(:(a = ap1 * am1))            # :(a = -1)
simple(:(a = ap1 * b))              # :(a = b)
simple(:(a = a + am1 * b))          # :(a = a + -b)
simple(:(a = a + am1 / b))          # :(a = a + -1 / b)
simple(:(a = 2*a + a / b))          # (unchanged)

``````
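The rules above decide via `eval` on global values. If only the sparsity patterns from the original question are known at code-generation time, a variant driven by a lookup table might look like the sketch below. Everything here is an assumption for illustration: the names `leafvalue`, `negate` and `simplify` are hypothetical, the table is a plain `Dict` keyed by the `A[i]` / `b[i]` expressions, missing entries are treated as `:other`, and folding `0 * x` to `0` assumes `x` is finite.

```julia
# Hypothetical helper: replace a leaf by a constant if its pattern fixes the value.
function leafvalue(ex, patterns)
    p = get(patterns, ex, :other)   # entries not in the table count as :other
    p === :null     && return 0
    p === :plusone  && return 1
    p === :minusone && return -1
    return ex
end

# Hypothetical helper: negate a number directly, wrap anything else in unary minus.
negate(x) = x isa Number ? -x : Expr(:call, :-, x)

# Fallback for symbols, numbers and other leaves.
simplify(ex, patterns) = leafvalue(ex, patterns)

function simplify(ex::Expr, patterns)
    if ex.head == :ref                       # A[i] or b[i]
        return leafvalue(ex, patterns)
    elseif ex.head == :(=)                   # lhs = rhs; the lhs is left untouched
        lhs, rhs = ex.args
        rhs = simplify(rhs, patterns)
        rhs == lhs && return nothing         # A[1] = A[1] is a no-op
        return Expr(:(=), lhs, rhs)
    elseif ex.head == :call && length(ex.args) == 3
        op = ex.args[1]
        x = simplify(ex.args[2], patterns)
        y = simplify(ex.args[3], patterns)
        if op == :*
            (x == 0 || y == 0) && return 0   # assumes finite operands
            x == 1  && return y
            y == 1  && return x
            x == -1 && return negate(y)
            y == -1 && return negate(x)
        elseif op == :+
            x == 0 && return y
            y == 0 && return x
        end
        return Expr(:call, op, x, y)
    end
    return ex                                # anything else is passed through
end

# Usage with the example from the question
expr = :(A[1] = A[1] + A[2] * b[3])
r_other    = simplify(expr, Dict{Any,Symbol}())
r_plusone  = simplify(expr, Dict{Any,Symbol}(:(A[2]) => :plusone))
r_minusone = simplify(expr, Dict{Any,Symbol}(:(A[2]) => :minusone))
r_null     = simplify(expr, Dict{Any,Symbol}(:(A[2]) => :null))
```

Unlike `simple`, this sketch builds new expressions instead of mutating its input, and it never calls `eval`, so it does not depend on global variables being defined. It only covers binary `*` and `+`; division and n-ary calls such as `a * b * c` would need additional rules.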