Flattening nested operators in `Expr`

I have an Expr that comes to me looking like this:

:((+)((+)(1,2),3))

except that my actual Exprs have hundreds of terms, each of which is a complicated expression (which may involve products with further sums that I don’t want to expand at this level).

What I want is to iterate over the terms and apply some function to each one. Is there an elegant/robust way to do so?

My instinct is to try to turn that expression into a “flattened” version making use of the associativity:

:((+)(1,2,3))

With that form, I can just iterate over expr.args[2:end]. Is this the best approach?

Ideally, I’d like to be able to handle either form without having to worry about it.

An expression is basically a tree:

julia> e = :((+)((+)(1,2),3))
:((1 + 2) + 3)

julia> Meta.dump(e)
Expr
  head: Symbol call
  args: Array{Any}((3,))
    1: Symbol +
    2: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol +
        2: Int64 1
        3: Int64 2
    3: Int64 3

To process them, you might want to look at MacroTools.jl, which has some nice tools for walking trees:

julia> using MacroTools

julia> MacroTools.postwalk(x -> if x isa Number; 2*x else x end, e)
:((2 + 4) + 6)

Thanks for your reply. I certainly use MacroTools most of the time. The problem is how to stop it from going too far.

I guess my question in your framing would be: How do I stop postwalk (or prewalk) from going further down a branch in that tree if any of the operations is something other than (+)?

So, for example, if my Expr were

:((+)((+)(1,a*(b+c)),3))

how would I ensure that my processing function receives as arguments 1, a*(b+c), and 3, but not (b+c), b, or c?

Here’s my solution:

isadd(x) = MacroTools.isexpr(x, :call) && x.args[1] ∈ ((+), :+)

function flatten_add!(expr)
    while any(isadd, expr.args[2:end])
        args = expr.args[2:end]
        i₊ = findfirst(isadd, args)
        args′ = [:+; args[1:i₊-1]; args[i₊].args[2:end]; args[i₊+1:end]]
        expr.args[:] = args′[1:length(expr.args)]
        append!(expr.args, args′[1+length(expr.args):end])
    end
    expr
end
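For comparison, here is a non-mutating sketch that collects the terms directly instead of rewriting the expression first. The names `addends` and `addends!` are made up for illustration, and this version of `isadd` uses only Base rather than MacroTools. It assumes that every `+` call should be treated as associative and that nothing else should be descended into.

```julia
# Hypothetical helpers, not part of Base or MacroTools.
# A `+` call may carry either the symbol :+ or the function + as its first arg.
isadd(x) = x isa Expr && x.head === :call && !isempty(x.args) && x.args[1] in ((+), :+)

# Collect the terms of a (possibly nested) sum, left to right, without
# descending into anything that is not itself a `+` call.
function addends!(terms, x)
    if isadd(x)
        for arg in x.args[2:end]
            addends!(terms, arg)
        end
    else
        push!(terms, x)
    end
    return terms
end
addends(x) = addends!(Any[], x)

nested = :((+)((+)(1, a*(b+c)), 3))
flat   = :((+)(1, a*(b+c), 3))

# Both forms yield the same three terms; (b+c) is never visited on its own.
for term in addends(nested)
    println(term)
end
```

Since `addends` gives the same result for the nested and the flattened form, the caller doesn't need to know which one it received.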

This is indeed a bit more involved. You have two options: use the underlying walk function, which both prewalk and postwalk are built on, or insert and later strip a marker that prevents prewalk from stepping further down into your expression:

julia> using MacroTools

julia> e = :((+)((+)(1,a*(b+c)),3))
:((1 + a * (b + c)) + 3)

julia> struct StopHere
           stuff
       end

julia> MacroTools.postwalk(x -> if x isa StopHere x.stuff else x end,
                           MacroTools.prewalk(x -> if iscall(x, :+) x else println(x); StopHere(x) end,
                                              e))
+
+
1
a * (b + c)
3
:((1 + a * (b + c)) + 3)
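If you'd rather avoid the marker type, a small hand-written recursion can do the same job. This is only a sketch: `mapterms` is a hypothetical name, it does not use MacroTools at all, and it assumes you want to rebuild the sum with the same nesting while applying a function to each term.

```julia
# Hypothetical helpers, not part of Base or MacroTools.
isadd(x) = x isa Expr && x.head === :call && !isempty(x.args) && x.args[1] in ((+), :+)

# Apply `f` to every term of a (possibly nested) sum. Recursion continues only
# through `+` calls; anything else is handed to `f` whole.
function mapterms(f, x)
    isadd(x) || return f(x)
    return Expr(:call, x.args[1], (mapterms(f, arg) for arg in x.args[2:end])...)
end

e = :((+)((+)(1, a*(b+c)), 3))

# Record which pieces the processing function actually receives.
seen = Any[]
unchanged = mapterms(t -> (push!(seen, t); t), e)

# Wrap each term in a call to some (hypothetical) function g.
wrapped = mapterms(t -> :(g($t)), e)
```

Here `seen` ends up holding 1, a * (b + c), and 3, but never b + c, b, or c.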

Greetings,
I understand the solution, but I’d like to understand why @show args[1:i₊-1] prints an empty array while args′[2] is 1 in the code below. Thanks

using MacroTools
isadd(x) = MacroTools.isexpr(x, :call) && x.args[1] ∈ ((+), :+)
function flatten_add!(expr)
    while any(isadd, expr.args[2:end])
        args = expr.args[2:end]
        i₊ = findfirst(isadd, args)
        @show args[1:i₊-1]  #-------------outputs Any[]
        args′ = [:+; args[1:i₊-1]; args[i₊].args[2:end]; args[i₊+1:end]]
        @show args′[2]   #----------- outputs 1
        expr.args[:] = args′[1:length(expr.args)]
        append!(expr.args, args′[1+length(expr.args):end])
    end
    expr
end
flatten_add!(:((+)((+)(1,a*(b+c)),3)))

First, let’s just review the reason for args[1:i₊-1]. It’s to capture any part of the expression before the first “add” within the outer “add”. In your case, the expression is :((+)((+)(1,a*(b+c)),3)). The very first thing inside the outer “add” is another “add”, so it makes sense that args[1:i₊-1] is empty. If I add an x before that first nested “add”, like :((+)(x,(+)(1,a*(b+c)),3)), you get a non-empty array there.
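To make that concrete, here is a small sketch comparing the two expressions. The `isadd` below is a Base-only stand-in for the one in the thread, and the variable names are just for illustration.

```julia
# Base-only stand-in for the thread's `isadd` (assumption: same behaviour).
isadd(x) = x isa Expr && x.head === :call && !isempty(x.args) && x.args[1] in ((+), :+)

# Original expression: the nested add is the first argument of the outer add.
args_orig = :((+)((+)(1, a*(b+c)), 3)).args[2:end]
i_orig = findfirst(isadd, args_orig)      # position of the first nested add
before_orig = args_orig[1:i_orig-1]       # everything before it

# Same expression with an `x` placed in front of the nested add.
args_x = :((+)(x, (+)(1, a*(b+c)), 3)).args[2:end]
i_x = findfirst(isadd, args_x)
before_x = args_x[1:i_x-1]

@show i_orig before_orig
@show i_x before_x
```

In the first case the slice is `1:0`, which is an empty range, so nothing is selected; in the second it is `1:1`, which picks out the `x`.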

Next, step back and remember what a semicolon does inside a vector constructor: it does vertical concatenation. But concatenating an empty array with other things is (almost) the same as concatenating those other things. Take a simple example:

julia> [1; []; [2, 3, 4]; 5]
5-element Vector{Any}:
 1
 2
 3
 4
 5

julia> [1; [2, 3, 4]; 5]
5-element Vector{Int64}:
 1
 2
 3
 4
 5

The only difference is the eltype of those vectors, but the elements themselves are the same in both cases. (In performance-critical code, of course, we would work harder for type stability, but that isn’t a concern here.)

So, when args′ is constructed, the empty vector contributes no elements to the concatenation, and args′[2] actually refers to args[i₊].args[2], which is the 1 from the nested “add”.
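Here is the same construction on a cut-down example, so each piece of the concatenation can be inspected separately. The values are made up for illustration; only the shape of the expression matches the one in `flatten_add!`.

```julia
# A stand-in for expr.args[2:end]: a nested add followed by a plain term.
args = Any[:(1 + a), 3]
i₊ = 1                         # the nested add is the first element

head   = args[1:i₊-1]          # empty: nothing precedes the nested add
middle = args[i₊].args[2:end]  # the nested add's own terms
tail   = args[i₊+1:end]        # whatever follows the nested add

# Same expression as in flatten_add!: the empty `head` adds no elements.
args′ = [:+; head; middle; tail]

@show head middle tail args′
```

Because `head` is empty, the element right after `:+` comes from `middle`, which is why `args′[2]` is 1.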