Using Metatheory.jl to simplify algebra expression

Hi,

I am trying to use Metatheory.jl to simplify my algebra expressions (the expressions have type Expr), but did not get the results I expected. For example, when I tried to simplify the Expr :(a+(a-b)), I got :(a+(a-b)) instead of :(2a-b). However, it succesfully simplified :(a+b+a) to (2a+b). I would really appreciate it if you can give me some guidances on how to improve my code to get the results I want. Here is my code:

using Metatheory, Metatheory.Library,Metatheory.Schedulers
mult_t = @commutative_monoid (*) 1
plus_t = @commutative_monoid (+) 0

minus_t = @theory a b begin
  a - a --> 0
  a - b --> a + (-1 * b)
  -a --> -1 * a
   a + (-b) --> a + (-1 * b)
end

mulplus_t = @theory a b c begin
  a + a --> 2 * a
  0 * a --> 0
  a * 0 --> 0
  a * (b + c) == ((a * b) + (a * c))
  a + (b * a) --> ((b + 1) * a)
end

fold_t = @theory a b begin
  -(a::Number) => -a
  a::Number + b::Number => a + b
  a::Number * b::Number => a * b
  a::Number^b::Number => begin
    b < 0 && a isa Int && (a = float(a))
    a^b
  end
  a::Number / b::Number => a / b
end

cas = fold_t ∪ mult_t ∪ plus_t ∪ minus_t ∪ mulplus_t

canonical_t = @theory x y n xs ys begin
  # restore n-arity
  (x * x) --> x^2
  (x^n::Number * x) --> x^(n + 1)
  (x * x^n::Number) --> x^(n + 1)
  (x + (+)(ys...)) --> +(x, ys...)
  ((+)(xs...) + y) --> +(xs..., y)
  (x * (*)(ys...)) --> *(x, ys...)
  ((*)(xs...) * y) --> *(xs..., y)

  (*)(xs...) => Expr(:call, :*, sort!(xs; lt = customlt)...)
  (+)(xs...) => Expr(:call, :+, sort!(xs; lt = customlt)...)
end

function simplify(ex; steps = 4)
  params = SaturationParams(
    scheduler = ScoredScheduler,
    eclasslimit = 5000,
    timeout = 7,
    schedulerparams = (1000, 5, Schedulers.exprsize),
    #stopwhen=stopwhen,
  )
  hist = UInt64[]
  push!(hist, hash(ex))
  for i in 1:steps
    g = EGraph(ex)
    @profview_allocs saturate!(g, cas, params)
    ex = extract!(g, simplcost)
    ex = rewrite(ex, canonical_t)
    if !TermInterface.istree(ex)
      return ex
    end
    if hash(ex) ∈ hist
      println("loop detected $ex")
      return ex
    end
    println(ex)
    push!(hist, hash(ex))
  end
end

simplify(:(a+(a-b))
simplify(:a+b+a)

Regards