In the context of code generation, I want to write a code simplification routine.
Input and output are of type Expr.
A[i] and b[i] can each have one of the sparsity patterns :null, :plusone, :minusone, :other.
Example: expr = :(A[1] = A[1] + A[2] * b[3])
Depending on the actual values, expr can be simplified to:
:(A[1] = A[1] + A[2] * b[3]) # if all have value :other
:(A[1] = A[1] + b[3]) # if A[2] == :plusone
...
nothing::Nothing # if A[2] * b[3] == 0
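A minimal sketch of a pattern-driven version, under a few assumptions: the patterns are stored in a dictionary keyed by the entry expressions (e.g. `:(A[2]) => :plusone`), unlisted entries count as `:other`, the left-hand side is never substituted, and only binary `*` and `+` calls are rewritten. The names `simplify_pattern`, `known`, `negate` and `PATTERN_VALUE` are hypothetical.

```julia
# Hypothetical sketch: simplify using sparsity patterns instead of runtime values.
# `patterns` maps entry expressions such as :(A[2]) to :null, :plusone,
# :minusone or :other; entries that are not listed count as :other.
const PATTERN_VALUE = Dict(:null => 0, :plusone => 1, :minusone => -1)

# Constant value implied by a number or by a pattern, or `nothing` if unknown.
known(e::Number, patterns) = e
known(e, patterns) = get(PATTERN_VALUE, get(patterns, e, :other), nothing)

negate(x::Number) = -x
negate(x) = Expr(:call, :-, x)

function simplify_pattern(e, patterns)
    v = known(e, patterns)
    v === nothing || return v                # entry with a constant pattern
    e isa Expr || return e
    if e.head == :(=)
        rhs = simplify_pattern(e.args[2], patterns)
        rhs == e.args[1] && return nothing   # `lhs = lhs` can be dropped
        return Expr(:(=), e.args[1], rhs)
    elseif e.head == :call && length(e.args) == 3   # binary calls only
        op = e.args[1]
        x = simplify_pattern(e.args[2], patterns)
        y = simplify_pattern(e.args[3], patterns)
        if op == :*
            (x == 0 || y == 0) && return 0   # product with a :null entry
            x == 1 && return y
            y == 1 && return x
            x == -1 && return negate(y)
            y == -1 && return negate(x)
        elseif op == :+
            x == 0 && return y
            y == 0 && return x
        end
        return Expr(:call, op, x, y)
    end
    return e   # e.g. an A[i] reference whose pattern is :other
end

ex = :(A[1] = A[1] + A[2] * b[3])
simplify_pattern(ex, Dict{Any,Symbol}())                     # returns an Expr equal to ex
simplify_pattern(ex, Dict{Any,Symbol}(:(A[2]) => :plusone))  # returns :(A[1] = A[1] + b[3])
simplify_pattern(ex, Dict{Any,Symbol}(:(A[2]) => :null))     # returns nothing
```

Since the rules only read the patterns, nothing is evaluated at generation time; products of three or more factors are simply left untouched in this sketch.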
I think what you need could be done with Metatheory.jl, and it's probably what you're looking for. Otherwise, if you prefer an approach similar to Rewrite.jl, there is the SymbolicUtils.jl package, which works with symbolic expressions. Finally, Symbolics.jl supports symbolic arrays; I'm not sure whether it handles sparsity patterns at the moment, but you could check.
So depending on how “low-level” you want to go with simplifications, it’s probably one of these three packages.
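To get a feel for the rule-based approach before adding a dependency, here is a rough plain-Julia sketch of the general idea: each rule is a function that returns a rewritten expression or `nothing`, and the rules are applied bottom-up until the expression stops changing. This is not the API of Metatheory.jl or SymbolicUtils.jl; `RULES`, `isbinary`, `rewrite_node`, `rewrite` and `rewrite_fixpoint` are made-up names for illustration.

```julia
# Plain-Julia sketch of rule-based rewriting (not the Metatheory.jl/SymbolicUtils.jl API).
# A rule takes an expression and returns a rewritten one, or `nothing` if it does not apply.
isbinary(e, op) = e isa Expr && e.head == :call && length(e.args) == 3 && e.args[1] == op

const RULES = [
    e -> isbinary(e, :*) && e.args[2] == 1 ? e.args[3] : nothing,  # 1 * a -> a
    e -> isbinary(e, :*) && e.args[3] == 1 ? e.args[2] : nothing,  # a * 1 -> a
    e -> isbinary(e, :*) && 0 in e.args[2:3] ? 0 : nothing,        # 0 * a -> 0 (finite a)
    e -> isbinary(e, :+) && e.args[2] == 0 ? e.args[3] : nothing,  # 0 + a -> a
    e -> isbinary(e, :+) && e.args[3] == 0 ? e.args[2] : nothing,  # a + 0 -> a
]

# Apply the first matching rule at a single node.
function rewrite_node(e)
    for rule in RULES
        r = rule(e)
        r === nothing || return r
    end
    return e
end

# Rewrite the children first, then the node itself (one bottom-up pass).
function rewrite(e)
    if e isa Expr
        e = Expr(e.head, map(rewrite, e.args)...)
    end
    return rewrite_node(e)
end

# Repeat passes until the expression no longer changes (or maxiter is reached).
function rewrite_fixpoint(e; maxiter = 100)
    for _ in 1:maxiter
        new = rewrite(e)
        new == e && return new
        e = new
    end
    return e
end

rewrite_fixpoint(:(y = 1 * (x * 1) + 0 * z))  # returns :(y = x)
```

Rules are tried in order and the first match wins; the fixed-point loop catches rewrites that only become possible after a neighbouring node has changed.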
Thanks for all your suggestions!
Before calling in the cavalry, I looked a bit closer and ended up with some rules that work directly on the expression.
"""
Simplification of elementary expressions in sparse linear solver loops
"""
simple(e) = error("not expected type: $(typeof(e))")
simple(e::Number) = e
# Replace a variable by its value if that value is 0, 1 or -1 (looked up with eval, i.e. in global scope).
function simple(e::Symbol)
    ev = eval(e)
    if ev == 0
        return 0
    elseif ev == 1
        return 1
    elseif ev == -1
        return -1
    else
        return e
    end
end
function simple(e::Expr)
    if e.head == :(=)
        if eval(e.args[1]) ≈ eval(e.args[2])  # redundant operation: the value is unchanged
            return nothing
        else
            e.args[2] = simple(e.args[2])
        end
    elseif e.head == :call
        # Simplify every operand: a * b * c parses as a single call with three operands.
        for i in 2:length(e.args)
            e.args[i] = simple(e.args[i])
        end
        # The rules below are only valid for binary calls (a * b * c or -a are left as they are).
        length(e.args) == 3 || return e
        ev2 = eval(e.args[2])
        ev3 = eval(e.args[3])
        if e.args[1] == :*
            if ev2 == 1        # 1 * a -> a
                e = e.args[3]
            elseif ev3 == 1    # a * 1 -> a
                e = e.args[2]
            elseif ev2 == -1   # -1 * a -> -a
                e = Expr(:call, :-, e.args[3])
            elseif ev3 == -1   # a * -1 -> -a
                e = Expr(:call, :-, e.args[2])
            end
        elseif e.args[1] == :/
            if ev3 == 1        # a / 1 -> a
                e = e.args[2]
            elseif ev3 == -1   # a / -1 -> -a
                e = Expr(:call, :-, e.args[2])
            end
        elseif e.args[1] == :+
            if ev2 == 0        # 0 + a -> a
                e = e.args[3]
            elseif ev3 == 0    # a + 0 -> a
                e = e.args[2]
            end
        end
    end
    return e
end
# Tests
a = π
b = sqrt(2)
a0 = 0
ap1 = 1
am1 = -1
simple(:(a = a + (a * a0 * b)/a)) # nothing (the assignment leaves a unchanged)
simple(:(a = ap1 * am1)) # :(a = -1)
simple(:(a = ap1 * b)) # :(a = b)
simple(:(a = a + am1 * b)) # :(a = a + -b)
simple(:(a = a + am1 / b)) # :(a = a + -1 / b)
simple(:(a = 2*a + a / b)) # (unchanged)