Improving type inference for `x` where `Meta.isexpr(x, ...)`

Type inference struggles a bit with the following:

julia> a = :(x = y = 1);

julia> a.args[2]
:(y = 1)

julia> function f(a::Expr)
           ex = a.args[2]
           if Meta.isexpr(ex, :(=))
               return ex
           else
               return nothing
           end
       end;

julia> @code_warntype f(a)
MethodInstance for f(::Expr)
  from f(a::Expr) in Main at REPL[37]:1
Arguments
  #self#::Core.Const(f)
  a::Expr
Locals
  ex::Any
Body::Any
1 ─ %1 = Base.getproperty(a, :args)::Vector{Any}
│        (ex = Base.getindex(%1, 2))
│   %3 = Base.Meta.isexpr::Core.Const(Base.isexpr)
│   %4 = ex::Any
│   %5 = (%3)(%4, :(=))::Bool
└──      goto #3 if not %5
2 ─      return ex
3 ─      return Main.nothing

So `ex` is inferred as `Any`, and therefore so is the return type, even though `ex` must be an `Expr` in the branch where `isexpr(ex, :(=))` holds.

Is there something I can do to help the compiler in situations such as this?

Sometimes, typing out the problem is all that’s needed: adding an explicit `ex isa Expr` check lets the compiler narrow `ex` to `Expr` inside the branch.

julia> function f(a::Expr)
           ex = a.args[2]
           if ex isa Expr && Meta.isexpr(ex, :(=))
               return ex
           else
               return nothing
           end
       end;

julia> @code_warntype f(a)
MethodInstance for f(::Expr)
  from f(a::Expr) in Main at REPL[43]:1
Arguments
  #self#::Core.Const(f)
  a::Expr
Locals
  ex::Any
Body::Union{Nothing, Expr}
1 ─ %1 = Base.getproperty(a, :args)::Vector{Any}
│        (ex = Base.getindex(%1, 2))
│   %3 = (ex isa Main.Expr)::Bool
└──      goto #4 if not %3
2 ─ %5 = Base.Meta.isexpr::Core.Const(Base.isexpr)
│   %6 = ex::Expr
│   %7 = (%5)(%6, :(=))::Bool
└──      goto #4 if not %7
3 ─      return ex::Expr
4 ┄      return Main.nothing
1 Like

In Julia 1.7 and higher, this kind of type constraint should be propagated automatically, and you don’t need to add such annotations.

julia> versioninfo()
Julia Version 1.7.1
Commit ac5cc99908 (2021-12-22 19:35 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.5.0)
  CPU: Intel(R) Core(TM) i5-1038NG7 CPU @ 2.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, icelake-client)

julia> function f(a::Expr)
           ex = a.args[2]
           if ex isa Expr && Meta.isexpr(ex, :(=))
               return ex
           else
               return nothing
           end
       end;

julia> @code_warntype f(:(a = nothing))
MethodInstance for f(::Expr)
  from f(a::Expr) in Main at REPL[1]:1
Arguments
  #self#::Core.Const(f)
  a::Expr
Locals
  ex::Any
Body::Union{Nothing, Expr}
1 ─ %1 = Base.getproperty(a, :args)::Vector{Any}
│        (ex = Base.getindex(%1, 2))
│   %3 = (ex isa Main.Expr)::Bool
└──      goto #4 if not %3
2 ─ %5 = Base.getproperty(Main.Meta, :isexpr)::Core.Const(Base.isexpr)
│   %6 = ex::Expr
│   %7 = (%5)(%6, :(=))::Bool
└──      goto #4 if not %7
3 ─      return ex::Expr
4 ┄      return Main.nothing
1 Like

Oops, sorry — I just noticed this particular example doesn’t work on 1.7; it is fixed on 1.8 and higher. In particular, we need this PR to handle this kind of situation nicely (where the conditional object is represented as an SSAValue).

julia> versioninfo()
Julia Version 1.8.0-beta2.9
Commit e077335b93 (2022-03-17 01:13 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin21.3.0)
  CPU: 8 × Intel(R) Core(TM) i5-1038NG7 CPU @ 2.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, icelake-client)
  Threads: 1 on 8 virtual cores
Environment:
  JULIA_PROJECT = @.
  JULIA_EDITOR = code
  JULIA_PKG_DEVDIR = /Users/aviatesk/julia/packages

julia> function f(a::Expr)
           ex = a.args[2]
           if Meta.isexpr(ex, :(=))
               return ex
           else
               return nothing
           end
       end
f (generic function with 1 method)

julia> @code_warntype f(:(a = nothing))
MethodInstance for f(::Expr)
  from f(a::Expr) in Main at REPL[4]:1
Arguments
  #self#::Core.Const(f)
  a::Expr
Locals
  ex::Any
Body::Union{Nothing, Expr}
1 ─ %1 = Base.getproperty(a, :args)::Vector{Any}
│        (ex = Base.getindex(%1, 2))
│   %3 = Base.Meta.isexpr::Core.Const(Base.isexpr)
│   %4 = ex::Any
│   %5 = (%3)(%4, :(=))::Bool
└──      goto #3 if not %5
2 ─      return ex::Expr
3 ─      return Main.nothing

So please wait a little while for the 1.8 release :slight_smile:

2 Likes

Very nice