Help: can anyone tell me why Julia is not correctly inferring the return type?


#1

Hi guys!

I have this code here:

```julia
import Base: getindex

mutable struct A{T<:Real}
    a::T
    b::T
    c::T
    d::T
end

function getindex(a::A{T}, ::Colon) where T<:Real
    [a.a;a.b;a.c;a.d]
end

function test(a::A{T1}, w::Vector{T2}) where T1<:Real where T2<:Real
    # Check the dimensions.
    if length(w) != 3
        throw(ArgumentError)
    end

    Ω = Array{T2}([  0   -w[1] -w[2] -w[3] ;
                   +w[1]   0   +w[3] -w[2] ;
                   +w[2] -w[3]   0   +w[1] ;
                   +w[3] +w[2] -w[1]   0   ])

    # Return the time-derivative.
    (Ω/2)*a[:]
end
```

However, I noticed that Julia cannot infer the correct return type of the function `test`:

```julia
julia> using Base.Test

julia> @inferred test(A(1.,2.,3.,4.),[0.0;1.0;0.0])
ERROR: return type Array{Float64,1} does not match inferred return type Any
Stacktrace:
 [1] error(::String) at ./error.jl:21
```

I know this can significantly decrease performance. Can anyone tell me why this happens and whether there is something I can do to fix it?
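For reference, `@inferred` evaluates the call and compares the type of the actual result with the return type the compiler inferred, raising an error when they differ. Below is a minimal sketch assuming Julia 1.x, where the macro lives in the `Test` standard library (on Julia 0.6 it was `Base.Test`); `unstable` and `stable` are hypothetical functions made up for illustration:

```julia
using Test

# Hypothetical type-unstable function: returns an Int for positive input and
# a Float64 otherwise, so the inferred return type is a Union of the two.
unstable(x::Int) = x > 0 ? x : 0.0

# Hypothetical type-stable variant: both branches return a value of typeof(x).
stable(x::Int) = x > 0 ? x : zero(x)

# When inference matches the runtime type, @inferred simply returns the value.
@assert @inferred(stable(3)) == 3

# When it does not, @inferred throws an ErrorException.
threw = try
    @inferred unstable(3)
    false
catch err
    err isa ErrorException
end
@assert threw
```

Here the inferred return type of `unstable(::Int)` is a `Union` of `Int` and `Float64`, which is not equal to the concrete `Int` actually returned, so `@inferred` throws.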


#2

It’s because you’re using the `Array{T2}` constructor, and `Array{T2}` is an incomplete type: its number of dimensions is left unspecified.

```julia
julia> Array{Float64}
Array{Float64,N} where N
```

If you change it to

```julia
function test(a::A{T1}, w::Vector{T2}) where T1<:Real where T2<:Real
    # Check the dimensions.
    if length(w) != 3
        throw(ArgumentError)
    end

    Ω = Matrix{T2}([  0   -w[1] -w[2] -w[3] ;
                    +w[1]   0   +w[3] -w[2] ;
                    +w[2] -w[3]   0   +w[1] ;
                    +w[3] +w[2] -w[1]   0   ])

    # Return the time-derivative.
    (Ω/2)*a[:]
end
```

instead, things get correctly inferred.
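The type distinction drawn above can be checked directly. A small sketch, assuming Julia 1.x, where `isconcretetype` is available; the matrix `m` is made-up sample input:

```julia
# Array{Float64} fixes the element type but leaves the dimension N free,
# so it is a UnionAll rather than a concrete type.
@assert Array{Float64} isa UnionAll
@assert !isconcretetype(Array{Float64})

# Matrix{Float64} is an alias for Array{Float64,2}, which is concrete.
@assert Matrix{Float64} === Array{Float64,2}
@assert isconcretetype(Matrix{Float64})

# When the argument is already a matrix, both constructors return the same
# concrete type at runtime.
m = [1 2; 3 4]
@assert typeof(Array{Float64}(m)) === Matrix{Float64}
@assert typeof(Matrix{Float64}(m)) === Matrix{Float64}
```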


#3

Or just `T2[ ... ; ... ; ... ]`, which I think is the canonical way to construct a matrix with a given element type.

Also, shouldn’t you be throwing an instance, `ArgumentError("description")`, instead of the type `ArgumentError`?
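Both suggestions can be seen in isolation. A minimal sketch assuming Julia 1.x; the matrix values and the error message are made up for illustration:

```julia
# A typed literal T[...] fixes the element type up front and converts every
# entry to it, so the Int literal 0 is stored as 0.0 here.
M = Float64[0 -1.5; 1.5 0]
@assert M isa Matrix{Float64}
@assert M[1, 1] === 0.0

# Throwing the *type* ArgumentError: the caught value is a DataType with no
# message, and it is not an instance of ArgumentError.
err_type = try
    throw(ArgumentError)
catch e
    e
end
@assert err_type === ArgumentError
@assert !(err_type isa ArgumentError)

# Throwing an *instance* carries a description and is a proper exception value.
err_inst = try
    throw(ArgumentError("w must have 3 components"))
catch e
    e
end
@assert err_inst isa ArgumentError
@assert err_inst.msg == "w must have 3 components"
```

Throwing the bare type still aborts the function, but a handler that checks `e isa ArgumentError` will not recognize it, and there is no message to show the user.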


#4

Also, this probably won’t matter for your use case, since I imagine the promotions work fine, but the 0s inside Ω should be `zero(T2)` rather than `Int` 0s. It’s just good general practice; you can occasionally run into trouble otherwise.
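To illustrate the suggestion: `zero(T)` produces the additive identity already of type `T`, so a literal built from it and from values of type `T` does not rely on promotion from `Int`. A small sketch assuming Julia 1.x; `skew2` is a hypothetical 2×2 helper, not code from the thread:

```julia
# zero(T) returns the additive identity of type T.
@assert zero(Float64) === 0.0
@assert zero(Float32) === 0.0f0
@assert zero(Rational{Int}) === 0//1

# Hypothetical helper (not from the thread): a 2x2 skew-symmetric matrix whose
# diagonal uses zero(T), so every entry already has type T.
skew2(x::T) where {T<:Real} = T[zero(T) -x; x zero(T)]

@assert skew2(2.5) == [0.0 -2.5; 2.5 0.0]
@assert eltype(skew2(1.0f0)) === Float32
@assert eltype(skew2(1//2)) === Rational{Int}

# For common number types an untyped literal with Int zeros promotes to the
# same element type, which is why it usually works anyway.
@assert eltype([0 1.0f0]) === Float32
```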


#5

Excellent! Thanks :slight_smile:


#6

Yes, indeed! This is part of a bigger project I am working on, related to 3D rotations and kinematics for satellite simulations (I will post something about it here very soon). In the full code, the error message says that the input vector must have 3 components.

And thanks for the info about `T2[...]`. The code becomes much cleaner :slight_smile:


#7

Thanks! Do you mean something like this?

```julia
Ω = T2[  zero(T2)  -w[1]    -w[2]    -w[3] ;
          +w[1]   zero(T2)  +w[3]    -w[2] ;
          +w[2]    -w[3]   zero(T2)  +w[1] ;
          +w[3]    +w[2]    -w[1]   zero(T2)   ]
```
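Putting the suggestions from this thread together (typed literal, `zero(T2)`, and an `ArgumentError` instance), the whole function could look roughly like the sketch below. It assumes Julia 1.x (`using Test` instead of `Base.Test`) and writes the type parameters as `where {T1<:Real, T2<:Real}`, which should behave the same here as chaining two `where` clauses; the error message text is made up:

```julia
using Test

mutable struct A{T<:Real}
    a::T
    b::T
    c::T
    d::T
end

# Same Colon overload as in the original post.
Base.getindex(a::A, ::Colon) = [a.a; a.b; a.c; a.d]

function test(a::A{T1}, w::Vector{T2}) where {T1<:Real, T2<:Real}
    # Check the dimensions; throw an exception instance with a message.
    length(w) == 3 || throw(ArgumentError("w must have 3 components"))

    # Typed literal with zero(T2) on the diagonal: the element type is fixed to T2.
    Ω = T2[ zero(T2)  -w[1]     -w[2]     -w[3]   ;
             +w[1]    zero(T2)  +w[3]     -w[2]   ;
             +w[2]    -w[3]     zero(T2)  +w[1]   ;
             +w[3]    +w[2]     -w[1]     zero(T2) ]

    # Return the time-derivative.
    (Ω / 2) * a[:]
end

q = A(1.0, 2.0, 3.0, 4.0)
ω = [0.0, 1.0, 0.0]
@assert (@inferred test(q, ω)) isa Vector{Float64}
```

With these inputs the call returns `[-1.5, -2.0, 0.5, 1.0]`.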

P.S.: Sorry for the number of replies. Being used to Bugzilla, I did not realize how easy it is to reply to everyone at once :slight_smile:


#8

Yes, exactly.


#9

No. Using `Array{T2}` is completely fine; the issue is the array literal. If you look at the output of `@code_warntype`:

```julia
julia> @code_warntype test(A(1., 2., 3., 4.), [0.0; 1.0; 0.0])
Variables:
  a::A{Float64}
  w::Array{Float64,1}
  Ω::Any

Body:
  begin
      NewvarNode(:(Ω::Any))
      unless (Base.not_int)(((Base.arraylen)(w::Array{Float64,1})::Int64 === 3)::Bool)::Bool goto 5
      #= line 4 =#
      (Main.throw)(Main.ArgumentError)::Union{}
      5:
      #= line 7 =#
      $(Expr(:static_parameter, 1))
      SSAValue(0) = (Core.tuple)(0, (Base.neg_float)((Base.arrayref)(true, w::Array{Float64,1}, 1)::Float64)::Float64, (Base.neg_float)((Base.arrayref)(true, w::Array{Float64,1}, 2)::Float64)::Float64, (Base.neg_float)((Base.arrayref)(true, w::Array{Float64,1}, 3)::Float64)::Float64, (Base.arrayref)(true, w::Array{Float64,1}, 1)::Float64, 0, (Base.arrayref)(true, w::Array{Float64,1}, 3)::Float64, (Base.neg_float)((Base.arrayref)(true, w::Array{Float64,1}, 2)::Float64)::Float64, (Base.arrayref)(true, w::Array{Float64,1}, 2)::Float64, (Base.neg_float)((Base.arrayref)(true, w::Array{Float64,1}, 3)::Float64)::Float64, 0, (Base.arrayref)(true, w::Array{Float64,1}, 1)::Float64, (Base.arrayref)(true, w::Array{Float64,1}, 3)::Float64, (Base.arrayref)(true, w::Array{Float64,1}, 2)::Float64, (Base.neg_float)((Base.arrayref)(true, w::Array{Float64,1}, 1)::Float64)::Float64, 0)::Tuple{Int64,Float64,Float64,Float64,Float64,Int64,Float64,Float64,Float64,Float64,Int64,Float64,Float64,Float64,Float64,Int64}
      SSAValue(1) = (Core._apply)(Base.promote_typeof, SSAValue(0))::Any
      Ω::Any = (Array{Float64,N} where N)((Base.typed_hvcat)(SSAValue(1), (4, 4, 4, 4), (Core.getfield)(SSAValue(0), 1)::Int64, (Core.getfield)(SSAValue(0), 2)::Float64, (Core.getfield)(SSAValue(0), 3)::Float64, (Core.getfield)(SSAValue(0), 4)::Float64, (Core.getfield)(SSAValue(0), 5)::Float64, (Core.getfield)(SSAValue(0), 6)::Int64, (Core.getfield)(SSAValue(0), 7)::Float64, (Core.getfield)(SSAValue(0), 8)::Float64, (Core.getfield)(SSAValue(0), 9)::Float64, (Core.getfield)(SSAValue(0), 10)::Float64, (Core.getfield)(SSAValue(0), 11)::Int64, (Core.getfield)(SSAValue(0), 12)::Float64, (Core.getfield)(SSAValue(0), 13)::Float64, (Core.getfield)(SSAValue(0), 14)::Float64, (Core.getfield)(SSAValue(0), 15)::Float64, (Core.getfield)(SSAValue(0), 16)::Int64)::Any)::Any
      #= line 13 =#
      return ((Ω::Any / 2)::Any * $(Expr(:invoke, MethodInstance for vcat(::Float64, ::Float64, ::Float64, ::Vararg{Float64,N} where N), :(Main.vcat), :((Core.getfield)(a, :a)::Float64), :((Core.getfield)(a, :b)::Float64), :((Core.getfield)(a, :c)::Float64), :((Core.getfield)(a, :d)::Float64)))::Array{Float64,1})::Any
  end::Any
```

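You’ll notice that `Any` appears *before* the call to the array constructor: the `promote_typeof` call that computes the element type of the literal is already inferred as `Any`, because inference hits the tuple size limit on the 16-element tuple of entries.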
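To see what that promotion step computes at runtime, one can call the same function that appears in the `@code_warntype` output. A minimal sketch assuming Julia 1.x; `Base.promote_typeof` is internal and unexported, so treat it as an implementation detail, and the sample entries are made up. The snippet only shows runtime values; whether inference gives up on a 16-entry literal depends on the Julia version and its inference limits:

```julia
# The untyped literal [0 -w[1] ...] must first find a common element type for
# all of its entries; Base.promote_typeof computes it from the values' types.
entries = (0, -1.0, -2.0, -3.0, 1.0, 0, 3.0, -2.0)  # made-up sample entries
@assert Base.promote_typeof(entries...) === Float64
@assert Base.promote_typeof(0, 1.5f0) === Float32

# A typed literal states the element type directly, so no search over the
# entries is needed to decide it.
@assert eltype(Float64[0 -1.0; 1.0 0]) === Float64
```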


#10

Sorry for my ignorance, but if for some reason I want to use `Array{T2}`, how can I fix this?


#11

As I said, Array{T2} isn’t the issue here. You can simply call it on,


#12

Or add a type assertion on either the argument or the return value.
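One way to read the return-value option, sketched under the assumption of Julia 1.x: keep the untyped literal and `Array{T2}` from the original post, and add a `::Matrix{T2}` type assertion on the constructor's result so that `Ω` has a known type for the rest of the function. The struct and `getindex` method are repeated so the block runs on its own; the error message is made up:

```julia
using Test

mutable struct A{T<:Real}
    a::T
    b::T
    c::T
    d::T
end

Base.getindex(a::A, ::Colon) = [a.a; a.b; a.c; a.d]

function test(a::A{T1}, w::Vector{T2}) where {T1<:Real, T2<:Real}
    length(w) == 3 || throw(ArgumentError("w must have 3 components"))

    # Untyped literal and Array{T2} as in the original post; the trailing
    # ::Matrix{T2} asserts the constructor's result type, so the compiler
    # can treat Ω as a Matrix{T2} from here on.
    Ω = Array{T2}([  0   -w[1] -w[2] -w[3] ;
                   +w[1]   0   +w[3] -w[2] ;
                   +w[2] -w[3]   0   +w[1] ;
                   +w[3] +w[2] -w[1]   0   ])::Matrix{T2}

    # Return the time-derivative.
    (Ω/2)*a[:]
end

q = A(1.0, 2.0, 3.0, 4.0)
@assert (@inferred test(q, [0.0, 1.0, 0.0])) isa Vector{Float64}
```

The assertion is checked at runtime and would throw a `TypeError` if it failed; here it cannot fail, since `Array{T2}` applied to a matrix returns a `Matrix{T2}`.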