Inferring return type of a method

There is probably a simpler example but this is the essence of my problem:

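abstract type Abstract end
struct Concrete1 <: Abstract end
struct Concrete2 <: Abstract end
struct Thing
    a::Abstract  # field declared with an abstract type
end

# make instances of the concrete types callable
(::Concrete1)(x) = 1 + x
(::Concrete2)(x) = 2 + x

foo(thing::Thing) = thing.a(3)

mything = Thing(Concrete1())

using InteractiveUtils  # provides @code_warntype outside the REPL
@code_warntype foo(mything)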

The code works as expected, but the compiler can't infer the return type of foo because the exact type of thing.a is not known.

In practice, all my concrete types will return an Int64 for Int64 arguments, so is there any way to tell Julia that? I.e., for any T<:Abstract, always infer that (::T)(::Int64) returns an Int64.
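One way to state that expectation directly, sketched here as an alternative rather than as the fix the thread settles on, is a type assertion at the call site. The trailing `::Int` tells the compiler the result is an `Int` (it is checked at runtime and throws a `TypeError` if violated), so `foo`'s return type is inferred as `Int` even though the field stays abstract. The call to `thing.a` is still dynamically dispatched, so this helps code downstream of `foo` but does not remove the dispatch cost. `Int` is `Int64` on 64-bit systems.

```julia
using Test  # for @inferred

abstract type Abstract end
struct Concrete1 <: Abstract end
struct Concrete2 <: Abstract end

# Field left abstract, exactly as in the question
struct Thing
    a::Abstract
end

(::Concrete1)(x) = 1 + x
(::Concrete2)(x) = 2 + x

# The `::Int` assertion fixes the inferred return type of foo;
# it is verified at runtime, so a wrong result type raises a TypeError.
foo(thing::Thing) = thing.a(3)::Int

mything = Thing(Concrete1())
result = @inferred foo(mything)  # inferred and actual types are both Int
```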

Welcome to the Julia Discourse!

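I think the issue here is that you have a struct (Thing, in this case) in which one of the fields does not have a concrete type. Specifically, a is declared with the abstract type Abstract, so from the type Thing alone the compiler cannot tell which concrete type thing.a holds, and therefore which method thing.a(3) will call.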

You can fix this by making Thing a parametric struct:

struct Thing{T<:Abstract}
    a::T 
end
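A quick way to see the difference between the two designs is to ask whether the declared field type is concrete. The minimal sketch below puts both versions side by side; ThingAbstract and ThingParam are hypothetical names chosen only so the two definitions can coexist in one session.

```julia
abstract type Abstract end
struct Concrete1 <: Abstract end

# Hypothetical name: the original design, field declared with an abstract type
struct ThingAbstract
    a::Abstract
end

# Hypothetical name: the parametric design, the field's type is part of the struct's type
struct ThingParam{T<:Abstract}
    a::T
end

t1 = ThingAbstract(Concrete1())
t2 = ThingParam(Concrete1())  # T is deduced from the argument

# The compiler works from the declared field type, not the value stored in it
abstract_field = fieldtype(typeof(t1), :a)  # Abstract
param_field = fieldtype(typeof(t2), :a)     # Concrete1

println(typeof(t2))                      # ThingParam{Concrete1}
println(isconcretetype(abstract_field))  # false
println(isconcretetype(param_field))     # true
```

Because the concrete field type is carried in `ThingParam{Concrete1}`, every method specialized on that type knows exactly which callable it is dealing with.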

Here’s a full working example. As you can see, Julia is able to infer the return type of foo:

julia> abstract type Abstract end

julia> struct Concrete1 <: Abstract end

julia> struct Concrete2 <: Abstract end

julia> struct Thing{T<:Abstract}
       a::T
       end

julia> (::Concrete1)(x::T) where T = T(1) + x

julia> (::Concrete2)(x::T) where T = T(2) + x

julia> foo(thing::Thing) = thing.a(3)
foo (generic function with 1 method)

julia> mything = Thing(Concrete1())
Thing{Concrete1}(Concrete1())

julia> @code_warntype foo(mything)
Variables
  #self#::Core.Compiler.Const(foo, false)
  thing::Core.Compiler.Const(Thing{Concrete1}(Concrete1()), false)

Body::Int64
1 ─ %1 = Base.getproperty(thing, :a)::Core.Compiler.Const(Concrete1(), false)
│   %2 = (%1)(3)::Core.Compiler.Const(4, false)
└──      return %2

julia> using Test

julia> @inferred foo(mything)
4
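`@inferred` comes from the Test standard library: it evaluates the call and throws an error if the actual return type differs from the type the compiler inferred. A minimal sketch with a deliberately type-unstable function (`unstable` and `stable` are made-up names for illustration):

```julia
using Test

# Made-up example: returns an Int for positive input and a Float64 otherwise,
# so the inferred return type is Union{Float64, Int}
unstable(x::Int) = x > 0 ? 1 : 1.0

# Made-up example: always returns an Int
stable(x::Int) = x + 1

ok = @inferred stable(3)  # passes and returns 4

# The actual type (Int) differs from the inferred Union, so @inferred throws
@test_throws ErrorException @inferred unstable(3)
```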

Also, instead of using literals such as 1 and 2, you might consider using T(1) and T(2) so that those constants have the same type as x and the result is not promoted to a different type. For example:

(::Concrete1)(x::T) where T = T(1) + x
(::Concrete2)(x::T) where T = T(2) + x
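The effect of T(1) shows up with argument types other than Int64. In this minimal sketch, Literal and Typed are hypothetical callable structs standing in for the literal version and the T(1) version: with an Int32 argument, the literal 1 (an Int) promotes the sum to Int, while T(1) keeps it as Int32.

```julia
# Hypothetical names for illustration
struct Literal end
struct Typed end

(::Literal)(x) = 1 + x              # 1 is an Int, so the sum is promoted
(::Typed)(x::T) where T = T(1) + x  # converts 1 to the argument's type first

x = Int32(3)
println(typeof(Literal()(x)))        # Int64 on a 64-bit system
println(typeof(Typed()(x)))          # Int32
println(typeof(Typed()(UInt8(3))))   # UInt8
```

T(1) assumes T can be constructed from an Int; `one(T)` is a generic alternative for the value 1.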

Thanks! Seems obvious now you point it out.