Type inference with two variables that must have the same type

julia> fs() = rand(Bool) ? (1,1) : (1.0, 1.0)
fs (generic function with 1 method)

julia> f(fs, i) = fs()[i]
f (generic function with 1 method)

julia> g(fs) = (f(fs, 1), f(fs, 2))
g (generic function with 1 method)

julia> @code_warntype g(fs)
MethodInstance for g(::typeof(fs))
  from g(fs) @ Main REPL[3]:1
Arguments
  #self#::Core.Const(g)
  fs::Core.Const(fs)
Body::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
1 ─ %1 = Main.f(fs, 1)::Union{Float64, Int64}
│   %2 = Main.f(fs, 2)::Union{Float64, Int64}
│   %3 = Core.tuple(%1, %2)::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
└──      return %3

Would it be possible to nudge the compiler into realizing that the type of f(fs, 1) must be the same as that of f(fs, 2), which would make the result a Union of simple Tuples instead of a Tuple of Unions? I can’t add type assertions to f and g.

But that’s not the case:

julia> g(fs) 
(1, 1)

julia> g(fs) 
(1.0, 1)

Is that what you mean? The two types are the same (see below). Also, the inferred type is correct: it cannot be narrower, because fs is called twice, as shown above.

julia> T1 = Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
Tuple{Union{Float64, Int64}, Union{Float64, Int64}}

julia> T2 = Union{Tuple{Int,Float64},Tuple{Int,Int},Tuple{Float64,Int},Tuple{Float64,Float64}}
Union{Tuple{Float64, Float64}, Tuple{Float64, Int64}, Tuple{Int64, Float64}, Tuple{Int64, Int64}}

julia> T1 == T2
true
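To make the comparison concrete, here is a small sketch that uses only == and <: on types. T3 is a name introduced here for illustration: it is the narrower type the question hopes for, and it is a strict subtype of the inferred type because it excludes the mixed tuples (such as (1.0, 1)) that calling fs twice can really produce.

```julia
# Tuple of Unions, as reported by @code_warntype for g(fs)
T1 = Tuple{Union{Float64, Int64}, Union{Float64, Int64}}

# Union of all four concrete 2-tuples built from Int64 and Float64
T2 = Union{Tuple{Int64, Float64}, Tuple{Int64, Int64},
           Tuple{Float64, Int64}, Tuple{Float64, Float64}}

# Hypothetical narrower type: both elements share one type
T3 = Union{Tuple{Int64, Int64}, Tuple{Float64, Float64}}

println(T1 == T2)   # true: two spellings of the same type
println(T3 <: T1)   # true: every "matching" tuple is also a T1
println(T1 <: T3)   # false: T1 also admits mixed tuples
println((1.0, 1) isa T1, " ", (1.0, 1) isa T3)   # true false
```

So the inferred type is not an approximation here; it is exactly the set of values g(fs) can return.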

Now if you want fs to be called once:

julia> g2(fs)=(fs()...,)
julia> @code_warntype g2(fs)
MethodInstance for g2(::typeof(fs))
  from g2(fs) in Main at REPL[15]:1
Arguments
  #self#::Core.Const(g2)
  fs::Core.Const(fs)
Body::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
1 ─ %1 = (fs)()::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
│   %2 = Core._apply_iterate(Base.iterate, Core.tuple, %1)::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
└──      return %2
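The same comparison can be made programmatically rather than by reading @code_warntype output. This sketch uses Base.return_types, an unexported reflection function in Base that returns one inferred type per matching method; its exact output may differ between Julia versions. The name narrow is introduced here for illustration.

```julia
# Same definitions as earlier in the thread
fs() = rand(Bool) ? (1, 1) : (1.0, 1.0)
f(fs, i) = fs()[i]
g(fs) = (f(fs, 1), f(fs, 2))   # calls fs twice
g2(fs) = (fs()...,)            # calls fs once and splats the result

# Ask inference for the return type given an argument of type typeof(fs)
Rg  = only(Base.return_types(g,  (typeof(fs),)))
Rg2 = only(Base.return_types(g2, (typeof(fs),)))

narrow = Union{Tuple{Int64, Int64}, Tuple{Float64, Float64}}
println(Rg)
println(Rg2)
println(Rg2 <: narrow)   # the two elements come from one call, so they cannot disagree
println(Rg <: narrow)    # two independent calls, so mixed tuples are possible
```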

My apologies, the example was not well-chosen. This is more like what I want:

julia> using LinearAlgebra

julia> fs(A::AbstractMatrix) = isdiag(A) ? (1,1) : (1.0, 1.0)
fs (generic function with 1 method)

julia> f(A, fs, i) = fs(A)[i]
f (generic function with 2 methods)

julia> g(A, fs) = (f(A, fs, 1), f(A, fs, 2))
g (generic function with 2 methods)

julia> @code_warntype g(ones(1,1), fs)
MethodInstance for g(::Matrix{Float64}, ::typeof(fs))
  from g(A, fs) @ Main REPL[15]:1
Arguments
  #self#::Core.Const(g)
  A::Matrix{Float64}
  fs::Core.Const(fs)
Body::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
1 ─ %1 = Main.f(A, fs, 1)::Union{Float64, Int64}
│   %2 = Main.f(A, fs, 2)::Union{Float64, Int64}
│   %3 = Core.tuple(%1, %2)::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
└──      return %3

In this example, the return types of fs(A) and g(A, fs) are determined by the values of the arguments, unlike in the random example above. Perhaps constant propagation is needed to establish that the same branch is taken in both indexing operations?

But how will the compiler know that isdiag isn’t somehow a bit random? Or that A hasn’t been mutated behind the scenes, somehow.

Wouldn’t the function have to be known to be ‘pure’, and the input immutable?


The mutability of the argument is a red herring, as we may check:

julia> fs(x::Int) = x > 2 ? (1,1) : (1.0, 1.0)
fs (generic function with 1 method)

julia> f(x, fs, i) = fs(x)[i]
f (generic function with 1 method)

julia> g(x, fs) = (f(x, fs, 1), f(x, fs, 2))
g (generic function with 1 method)

julia> @code_warntype g(2, fs)
MethodInstance for g(::Int64, ::typeof(fs))
  from g(x, fs) @ Main REPL[3]:1
Arguments
  #self#::Core.Const(g)
  x::Int64
  fs::Core.Const(fs)
Body::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
1 ─ %1 = Main.f(x, fs, 1)::Union{Float64, Int64}
│   %2 = Main.f(x, fs, 2)::Union{Float64, Int64}
│   %3 = Core.tuple(%1, %2)::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
└──      return %3

As for how the compiler may know that >(::Int, ::Int) isn’t a bit random: are there cases where it could be? I would have imagined that in cases like these it could assume purity. But I don’t know much about compilers, which is why I’m asking whether there’s a way to achieve this.

With constant propagation of the 2 it already works:

julia> fs(x::Int) = x > 2 ? (1,1) : (1.0, 1.0);

julia> f(x, fs, i) = fs(x)[i];

julia> g(x, fs) = (f(x, fs, 1), f(x, fs, 2));

julia> h() = g(2, fs);

julia> @code_warntype h()
MethodInstance for h()
  from h() @ Main REPL[4]:1
Arguments
  #self#::Core.Const(h)
Body::Tuple{Float64, Float64}
1 ─ %1 = Main.g(2, Main.fs)::Core.Const((1.0, 1.0))
└──      return %1

(I thought it might be necessary to force specialization on the function argument by writing g(x, fs::F) where {F}, but that doesn’t seem to be needed.)


But for a mutable input value, this fails again:

julia> fs(x::Vector) = x[1] > 2 ? (1,1) : (1.0, 1.0);

julia> h() = g([2], fs);

julia> @code_warntype h()
MethodInstance for h()
  from h() @ Main REPL[12]:1
Arguments
  #self#::Core.Const(h)
Body::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
1 ─ %1 = Base.vect(2)::Vector{Int64}
│   %2 = Main.g(%1, Main.fs)::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
└──      return %2

So it may not be a red herring after all?

Still, without a constant value for x it doesn’t seem to work. Specializing the function fs with F<:Function doesn’t help either.
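A sketch of the difference between a constant and a non-constant argument, again using the unexported reflection function Base.return_types (results may vary across Julia versions). h2 is a hypothetical wrapper introduced here; it passes an Int whose value is unknown to inference.

```julia
fs(x::Int) = x > 2 ? (1, 1) : (1.0, 1.0)
f(x, fs, i) = fs(x)[i]
g(x, fs) = (f(x, fs, 1), f(x, fs, 2))

h() = g(2, fs)          # the argument is a literal constant
h2(x::Int) = g(x, fs)   # hypothetical: the argument is only known to be an Int

Rh  = only(Base.return_types(h, ()))
Rh2 = only(Base.return_types(h2, (Int,)))
println(Rh)    # constant propagation can evaluate 2 > 2 during inference
println(Rh2)   # nothing to propagate, so each element stays a Union
```

Constant propagation narrows the result only because the branch itself gets decided at compile time; it does not teach inference that the two calls to fs agree.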

Perhaps add a function barrier to help inference?

julia> g(A, fs) = let infer(a::T,b::T) where T = (a,b)
           infer(f(A, fs, 1), f(A, fs, 2))
       end
g (generic function with 1 method)

julia> @code_warntype g(ones(1,1), fs)
MethodInstance for g(::Matrix{Float64}, ::typeof(fs))
  from g(A, fs) @ Main REPL[133]:1
Arguments
  #self#::Core.Const(g)
  A::Matrix{Float64}
  fs::Core.Const(fs)
Locals
  infer::var"#infer#164"
Body::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
1 ─      (infer = %new(Main.:(var"#infer#164")))
│   %2 = Main.f(A, fs, 1)::Union{Float64, Int64}
│   %3 = Main.f(A, fs, 2)::Union{Float64, Int64}
│   %4 = (infer)(%2, %3)::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
└──      return %4

This is similar to adding type-assertions to g, which I can’t do. I expect the types to be the same for this fs that I’m providing, but they may not be identical in general.
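The caveat that the types "may not be identical in general" can be made concrete. In this sketch (same_pair, lenient_pair, g_strict, g_lenient and narrow are hypothetical names), the strict helper narrows the result only because inference appears to split the call over the element-type combinations and the mismatched ones have no method; at run time those combinations would throw a MethodError. Adding a fallback method makes the mixed tuples reachable again, so the narrowing disappears. Base.return_types is an unexported reflection function and its output may vary by Julia version.

```julia
using LinearAlgebra

fs(A::AbstractMatrix) = isdiag(A) ? (1, 1) : (1.0, 1.0)
f(A, fs, i) = fs(A)[i]

# Hypothetical helpers, named here for illustration only
same_pair(a::T, b::T) where {T} = (a, b)      # only accepts equal types

lenient_pair(a::T, b::T) where {T} = (a, b)
lenient_pair(a, b) = (a, b)                   # fallback for mismatched types

g_strict(A, fs)  = same_pair(f(A, fs, 1), f(A, fs, 2))
g_lenient(A, fs) = lenient_pair(f(A, fs, 1), f(A, fs, 2))

narrow = Union{Tuple{Int64, Int64}, Tuple{Float64, Float64}}
argtypes = (Matrix{Float64}, typeof(fs))
Rs = only(Base.return_types(g_strict, argtypes))
Rl = only(Base.return_types(g_lenient, argtypes))
println(Rs <: narrow)   # mismatched combinations have no method, so they drop out
println(Rl <: narrow)   # the fallback must cover (Int64, Float64) etc., so no narrowing

# The price of the strict version: mismatched types fail at run time
err = try
    same_pair(1, 1.0)
    nothing
catch e
    e
end
println(err isa MethodError)
```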

Doesn’t that make this a job for multiple dispatch?

julia> g(A, any_fs) = (f(A, any_fs, 1), f(A, any_fs, 2))  # generic
g (generic function with 1 method)

julia> g(A, ::typeof(fs)) = let infer(a::T,b::T) where T = (a,b)
           infer(f(A, fs, 1), f(A, fs, 2))  
       end  # specialization to your fs
g (generic function with 2 methods)

Trial:

julia> fs(A::AbstractMatrix) = isdiag(A) ? (1,1) : (1.0, 1.0)
fs (generic function with 1 method)

julia> some_fs(A::AbstractMatrix) = isdiag(A) ? (1,1) : (1.0, 1.0)
some_fs (generic function with 1 method)

julia> @code_warntype g(ones(1,1), fs)
MethodInstance for g(::Matrix{Float64}, ::typeof(fs))
  from g(A, ::typeof(fs)) @ Main REPL[141]:1
Arguments
  #self#::Core.Const(g)
  A::Matrix{Float64}
  _::Core.Const(fs)
Locals
  infer::var"#infer#168"
Body::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
1 ─      (infer = %new(Main.:(var"#infer#168")))
│   %2 = Main.f(A, Main.fs, 1)::Union{Float64, Int64}
│   %3 = Main.f(A, Main.fs, 2)::Union{Float64, Int64}
│   %4 = (infer)(%2, %3)::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
└──      return %4


julia> @code_warntype g(ones(1,1), some_fs)
MethodInstance for g(::Matrix{Float64}, ::typeof(some_fs))
  from g(A, any_fs) @ Main REPL[140]:1
Arguments
  #self#::Core.Const(g)
  A::Matrix{Float64}
  any_fs::Core.Const(some_fs)
Body::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
1 ─ %1 = Main.f(A, any_fs, 1)::Union{Float64, Int64}
│   %2 = Main.f(A, any_fs, 2)::Union{Float64, Int64}
│   %3 = Core.tuple(%1, %2)::Tuple{Union{Float64, Int64}, Union{Float64, Int64}}
└──      return %3

Yes, specializing this way definitely works, but it’s really a last resort. 🙂 I was hoping this could be handled at a lower level.

Inference really seems to be tripped up mostly by f(..., fs, 1) and f(..., fs, 2), since 1 and 2 live in the value domain after all, not in the type domain.

This seems to work:

julia> g(x,fs)=(fs(x)...,)
julia> @code_warntype g(2, fs)
MethodInstance for g(::Int64, ::typeof(fs))
  from g(x, fs) in Main at REPL[15]:1
Arguments
  #self#::Core.Const(g)
  x::Int64
  fs::Core.Const(fs)
Body::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
1 ─ %1 = (fs)(x)::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
│   %2 = Core._apply_iterate(Base.iterate, Core.tuple, %1)::Union{Tuple{Float64, Float64}, Tuple{Int64, Int64}}
└──      return %2
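If more than a splat is needed, the same idea can be sketched as "call fs once, then put a function barrier around everything that uses both elements". pair_kernel, g_barrier, g_nobarrier and narrow are hypothetical names. The assumption is that inference union-splits the call pair_kernel(t) when t is a small Union of tuples, inferring each concrete tuple separately; merely storing fs(A) in a local variable and indexing it in the same function is not enough, because inference does not track the correlation between t[1] and t[2].

```julia
using LinearAlgebra

fs(A::AbstractMatrix) = isdiag(A) ? (1, 1) : (1.0, 1.0)

# Hypothetical kernel: all code that needs both elements lives behind this call
pair_kernel(t) = (t[1], t[2])

g_barrier(A, fs) = pair_kernel(fs(A))   # fs is called exactly once

# For contrast: fs is still called once, but the indexing happens in the same body
function g_nobarrier(A, fs)
    t = fs(A)
    return (t[1], t[2])
end

narrow = Union{Tuple{Int64, Int64}, Tuple{Float64, Float64}}
argtypes = (Matrix{Float64}, typeof(fs))
R    = only(Base.return_types(g_barrier, argtypes))
R_nb = only(Base.return_types(g_nobarrier, argtypes))
println(R)      # union of the two matching tuple types
println(R_nb)   # each element is inferred independently
println(g_barrier(ones(2, 2), fs))   # a 2×2 matrix of ones is not diagonal
```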

It would be quite difficult for the compiler to realize that those are the same expression with no mutation in between (which is what would permit invariant hoisting followed by common subexpression elimination, or CSE). You can try to assert that one isa typeof(other), but inference has no support at all for that: dependency or reverse-dataflow constraints, even between just two variables, are somewhat shockingly hard to integrate into the analysis.
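A sketch of what that attempted assertion looks like, and why it does not help: typeof(a) is not a constant during inference, so the isa check cannot be turned into a constraint relating a and b. g_checked and narrow are hypothetical names, and Base.return_types is an unexported reflection function whose output may vary by Julia version.

```julia
using LinearAlgebra

fs(A::AbstractMatrix) = isdiag(A) ? (1, 1) : (1.0, 1.0)
f(A, fs, i) = fs(A)[i]

# Hypothetical attempt: check at run time that b has the same type as a
function g_checked(A, fs)
    a = f(A, fs, 1)
    b = f(A, fs, 2)
    b isa typeof(a) || error("element types differ")
    return (a, b)
end

narrow = Union{Tuple{Int64, Int64}, Tuple{Float64, Float64}}
R = only(Base.return_types(g_checked, (Matrix{Float64}, typeof(fs))))
println(R)             # still a Tuple of Unions
println(R <: narrow)   # the check does not narrow the inferred type
```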
