Simplify an Expr

In the context of code generation, I want to write a code simplification routine.
Both input and output are of type Expr.
Each entry A[i] and b[i] has one of the sparsity patterns :null, :plusone, :minusone, or :other (zero, +1, -1, or any other value).
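When the numeric values are known at generation time, the pattern of each entry can be derived from them. A minimal sketch, assuming entries are keyed by their index expressions such as `:(A[2])`; `pattern` and `patterns` are hypothetical helper names, not part of any package:

```julia
# Hypothetical helper: classify one numeric value into the four sparsity patterns.
function pattern(x::Number)
    x ==  0 && return :null
    x ==  1 && return :plusone
    x == -1 && return :minusone
    return :other
end

# Hypothetical helper: pattern of every entry of a vector, keyed by the
# expression that refers to it in generated code, e.g. :(A[2]).
patterns(name::Symbol, v::AbstractVector) =
    Dict{Any,Symbol}(:($name[$i]) => pattern(v[i]) for i in eachindex(v))

pA = patterns(:A, [0.0, 1.0, -1.0, 2.5])
pA[:(A[3])]   # :minusone
```

Keying by the `Expr` itself works because Julia compares and hashes `Expr` objects structurally, so a freshly quoted `:(A[3])` finds the same entry.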

Example:
expr = :(A[1] = A[1] + A[2] * b[3])
Depending on the actual patterns, expr can be simplified to:

:(A[1] = A[1] + A[2] * b[3])  # if all entries have pattern :other
:(A[1] = A[1] + b[3])         # if A[2] == :plusone
...
nothing::Nothing              # if A[2] * b[3] == 0
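Which of these forms applies can be decided from the patterns alone, before any code is emitted. A minimal sketch for the product term, using a hypothetical helper `mulpattern` (not from any package): a :null factor makes A[2] * b[3] vanish, which is the case where the statement reduces to `nothing`, and two ±1 factors give ±1 again.

```julia
# Hypothetical helper: sparsity pattern of a product, from the factors' patterns.
# :plusone/:minusone carry a known sign; :other stands for any other value.
const PATTERN_SIGN = Dict(:plusone => 1, :minusone => -1)

function mulpattern(p::Symbol, q::Symbol)
    (p == :null || q == :null) && return :null           # 0 * x == 0
    if haskey(PATTERN_SIGN, p) && haskey(PATTERN_SIGN, q)
        # (+1)(+1) and (-1)(-1) give +1, mixed signs give -1
        return PATTERN_SIGN[p] * PATTERN_SIGN[q] == 1 ? :plusone : :minusone
    end
    return :other                                         # e.g. 1 * x or x * y
end

mulpattern(:null, :other)          # :null
mulpattern(:minusone, :minusone)   # :plusone
```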

kmsquire/Match.jl (Advanced Pattern Matching for Julia, on GitHub) looks promising for simple, fixed-form expressions that can be worked out by hand.
MasonProtter/Rewrite.jl (Term rewriting in Julia, on GitHub) appears more general, but it works with symbolic expressions, so it might serve as an intermediate form.
Other ideas?

I think what you need could be done with Metatheory.jl; it's probably what you're looking for. Otherwise, if you prefer an approach similar to Rewrite.jl, there is the SymbolicUtils package, which works with symbolic expressions. Finally, Symbolics.jl supports symbolic arrays; I'm not sure whether that includes sparsity patterns at the moment, but you could check.

So depending on how “low-level” you want to go with simplifications, it’s probably one of these three packages.
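All of these packages let simplifications be written as rewrite rules applied to an expression tree. The basic mechanism can be sketched in plain Julia directly on `Expr`; `RULES`, `isbinary` and `rewrite` are hypothetical names, and this is neither the API nor the algorithm of any of the packages above:

```julia
# Hypothetical mini rewriter in plain Julia (not the API of any package).
# A rule takes a :call Expr and returns a replacement, or `nothing` if it does not apply.
isbinary(e, op) = e.args[1] == op && length(e.args) == 3

const RULES = [
    e -> (isbinary(e, :*) && e.args[2] == 1 ? e.args[3] : nothing),   # 1 * a -> a
    e -> (isbinary(e, :*) && e.args[3] == 1 ? e.args[2] : nothing),   # a * 1 -> a
    e -> (isbinary(e, :+) && e.args[2] == 0 ? e.args[3] : nothing),   # 0 + a -> a
    e -> (isbinary(e, :+) && e.args[3] == 0 ? e.args[2] : nothing),   # a + 0 -> a
]

# Bottom-up rewriting: rewrite the children first, then try the rules at this node.
function rewrite(e)
    e isa Expr || return e                        # symbols and numbers stay unchanged
    e = Expr(e.head, map(rewrite, e.args)...)
    e.head == :call || return e
    for rule in RULES
        new = rule(e)
        new === nothing || return rewrite(new)    # a rule fired: rewrite the result again
    end
    return e
end

rewrite(:(x + 1 * y))   # :(x + y)
```

These rules only see literal constants, so the known values (or the constants implied by :null, :plusone, :minusone) have to be substituted into the expression first.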


Espresso.jl might also be useful for you, especially its rewrite.jl.


Thanks for all your suggestions!
Before calling in the cavalry, I took a closer look and ended up with a few rules that work directly on the expression:


"""
Simplification of elementary expressions in sparse linear solver loops.
Symbols are resolved with `eval`, i.e. in the global scope of the current module.
"""
simple(e) = error("not expected type: $(typeof(e))")
simple(e::Number) = e
function simple(e::Symbol)
    ev = eval(e)                # current global value bound to the symbol
    ev ==  0 && return  0       # :null
    ev ==  1 && return  1       # :plusone
    ev == -1 && return -1       # :minusone
    return e                    # :other, keep the symbol
end
function simple(e::Expr)
    if e.head == :(=)
        # Drop assignments that leave the stored value unchanged (redundant operation)
        if eval(e.args[1]) ≈ eval(e.args[2])
            return nothing
        end
        e.args[2] = simple(e.args[2])
    elseif e.head == :call
        # Simplify all operands first; e.args[1] is the operator itself
        for i in 2:length(e.args)
            e.args[i] = simple(e.args[i])
        end
        # The rules below assume a binary call: unary calls such as -(b) have
        # no e.args[3], and n-ary calls such as *(a, b, c) would lose operands
        length(e.args) == 3 || return e
        ev2 = eval(e.args[2])
        ev3 = eval(e.args[3])
        if e.args[1] == :*
            if     ev2 == 1                 # 1 * a -> a
                e = e.args[3]
            elseif ev3 == 1                 # a * 1 -> a
                e = e.args[2]
            elseif ev2 == -1                # -1 * a -> -a
                e = Expr(:call, :-, e.args[3])
            elseif ev3 == -1                # a * -1 -> -a
                e = Expr(:call, :-, e.args[2])
            end
        elseif e.args[1] == :/
            if ev3 == 1                     # a / 1 -> a
                e = e.args[2]
            elseif ev3 == -1                # a / -1 -> -a
                e = Expr(:call, :-, e.args[2])
            end
        elseif e.args[1] == :+
            if     ev2 == 0                 # 0 + a -> a
                e = e.args[3]
            elseif ev3 == 0                 # a + 0 -> a
                e = e.args[2]
            end
        end
    end
    return e
end

# Tests
a = π
b = sqrt(2)
a0 = 0
ap1 = 1
am1 = -1
simple(:(a = a + (a * a0 * b)/a))   # nothing::Nothing
simple(:(a = ap1 * am1))            # :(a = -1)
simple(:(a = ap1 * b))              # :(a = b)
simple(:(a = a + am1 * b))          # :(a = a + -b)
simple(:(a = a + am1 / b))          # :(a = a + -1 / b)
simple(:(a = 2*a + a / b))          # (unchanged)
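The routine above reads actual values through `eval`, so it only works where the variables are bound as globals, whereas the question starts from sparsity patterns. A minimal sketch of a pattern-driven variant, assuming the patterns are supplied in a dictionary keyed by entry expressions such as `:(A[2])`; `simplify_pat` and `PATTERN_VALUE` are hypothetical names, and only the `*` and `+` rules are covered:

```julia
# Hypothetical pattern-driven variant: no eval, patterns come from a Dict.
# Constant that each "known" pattern stands for; :other has no entry.
const PATTERN_VALUE = Dict(:null => 0, :plusone => 1, :minusone => -1)

function simplify_pat(e, pat::AbstractDict)
    # Leaves (symbols, numbers, indexed entries like A[2]): substitute the
    # constant if the pattern is :null, :plusone or :minusone
    if !(e isa Expr) || e.head == :ref
        return haskey(pat, e) ? get(PATTERN_VALUE, pat[e], e) : e
    end
    if e.head == :(=)
        # Only the right-hand side is simplified; A[1] = A[1] is a no-op
        rhs = simplify_pat(e.args[2], pat)
        return rhs == e.args[1] ? nothing : Expr(:(=), e.args[1], rhs)
    end
    e.head == :call || return e
    op = e.args[1]
    args = Any[simplify_pat(a, pat) for a in e.args[2:end]]
    if length(args) == 2                          # binary rules only
        x, y = args
        if op == :*
            (x == 0 || y == 0) && return 0        # 0 * a, a * 0 -> 0
            x == 1  && return y                   # 1 * a -> a
            y == 1  && return x                   # a * 1 -> a
            x == -1 && return Expr(:call, :-, y)  # -1 * a -> -a
            y == -1 && return Expr(:call, :-, x)  # a * -1 -> -a
        elseif op == :+
            x == 0 && return y                    # 0 + a -> a
            y == 0 && return x                    # a + 0 -> a
        end
    end
    return Expr(:call, op, args...)
end

pat = Dict{Any,Symbol}(:(A[1]) => :other, :(A[2]) => :plusone, :(b[3]) => :other)
simplify_pat(:(A[1] = A[1] + A[2] * b[3]), pat)   # :(A[1] = A[1] + b[3])
```

With `:(A[2]) => :null` the same call returns `nothing`, matching the last case in the question.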