Type inference in recursion


#1

I am transitioning quantum chemistry code from Matlab to Julia 0.5.0. Great experience so far!

I am having trouble with type instability in one of the performance-critical functions. The function is recursive and involves six recurrence relations with six terms each. At the base case of the recursion, it returns values that are precomputed outside the recursive function.

Here is a minimal example (boiled down to one recursion relation with a single term, and no math) that shows the same type inference behavior:

function outerfun0()

  val = 1.0  # expensive to calculate

  function recur(a)
    if a>0
      return recur(a-1)
    else
      return val
    end
  end

  @code_warntype recur(2)

end

The output is

Variables:
  #self#::#recur#25{Float64}
  a::Int64

Body:
  begin 
      unless (Base.slt_int)(0,a::Int64)::Bool goto 4 # line 8:
      return ((Core.getfield)((Core.getfield)(#self#::#recur#25{Float64},:recur)::CORE.BOX,:contents)::ANY)((Base.box)(Int64,(Base.sub_int)(a::Int64,1)))::ANY
      4:  # line 10:
      return (Core.getfield)(#self#::#recur#25{Float64},:val)::Float64
  end::ANY

I do not understand why ::ANY and ::CORE.BOX appear; I think they indicate a type instability in the return type of recur. Julia seems to infer the types of a (Int64) and val (Float64) correctly, and every code path through recur returns a Float64, so everything appears to be type-stable.

Interestingly, if I stop referencing val in the inner function, the ::CORE.BOX disappears, but ::ANY stays:

function outerfun1()

  val = 1.0  # expensive to calculate

  function recur(a)
    if a>0
      return recur(a-1)
    else
      return 1.0
    end
  end

  @code_warntype recur(2)

end

Output:

Variables:
  #self#::#recur#30
  a::Int64

Body:
  begin 
      unless (Base.slt_int)(0,a::Int64)::Bool goto 4 # line 24:
      return ((Core.getfield)((Core.getfield)(#self#::#recur#30,:recur)::ANY,:contents)::ANY)((Base.box)(Int64,(Base.sub_int)(a::Int64,1)))::ANY
      4:  # line 26:
      return 1.0
  end::ANY

How can I get rid of this type instability? I can't find a way. Neither annotating the return type of ‘recur’ nor wrapping both return expressions of recur in Float64() seems to help.
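For context on the ::CORE.BOX above: a variable captured by a closure is stored in a Core.Box when it may be reassigned after the closure is created, and the box's contents field is typed Any, which matches the `:contents)::ANY` in the output. In the first example, the captured variable that ends up boxed is recur itself. A minimal illustration of the general mechanism in Julia 1.x, with hypothetical function names:

```julia
# Hypothetical names; Julia 1.x. Each function returns a closure capturing x.
function make_unboxed()
    x = 1
    g = () -> x        # x is never reassigned after capture
    return g
end

function make_boxed()
    x = 1
    g = () -> x
    x = 2              # reassigning a captured variable forces a Core.Box
    return g
end

# Captured variables become fields of the closure object.
println(fieldtype(typeof(make_unboxed()), :x))  # a concrete integer type
println(fieldtype(typeof(make_boxed()), :x))    # Core.Box
```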


#2

This is caused by the closure. Don't use outerfun1; instead, define recur at top level and make val a global constant.
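A minimal sketch of that suggestion in Julia 1.x syntax, with the hypothetical names VAL and recur_global. It only fits when the value can be computed once at load time, not on every call of the outer function:

```julia
const VAL = 1.0   # hypothetical global constant standing in for val

function recur_global(a::Int)
    if a > 0
        return recur_global(a - 1)
    else
        return VAL   # const global: the compiler knows it is a Float64
    end
end

println(recur_global(2))
println(Base.return_types(recur_global, (Int,)))  # inferred return type
```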


#3

Is this https://github.com/JuliaLang/julia/issues/15276?


#4

Yes.


#5

Got it, thanks!

Unfortunately, I cannot get rid of outerfun0(), as several other functions need it and call it from inside nested for loops. Is there any other strategy to improve its performance?


#6

Any approach that removes the closure will work. You can just move recur out of the outer function.


#7

Why don't you just pass val into recur?

julia> function recur(a, val)
           if a>0
               return recur(a-1, val)
           else
               return val
           end
       end

julia> function outerfun0()
           val = 1.0  # expensive to calculate
           @code_warntype recur(2, val)
       end
outerfun0 (generic function with 2 methods)

julia> outerfun0()
Variables:
  #self#::#recur
  a::Int64
  val::Float64

Body:
  begin 
      unless (Base.slt_int)(0,a::Int64)::Bool goto 4 # line 3:
      return $(Expr(:invoke, LambdaInfo for recur(::Int64, ::Float64), :(Main.recur), :((Base.box)(Int64,(Base.sub_int)(a,1))), :(val)))
      4:  # line 5:
      return val::Float64
  end::Float64



#8

This is not an option in more complex situations.


#9

Thanks for the suggestion; that is what I was trying to avoid.

In my full code, val is a set of 5 variables that are constant during the recursion, and a is a set of 7 variables that change during the recursion. Conceptually, it seems cleanest to me to pass the a set explicitly (I have to, anyway) but to use the closure mechanism to avoid passing the val set explicitly. In the full code, recur has six 6-term recursions, for a total of 36 recur calls. Passing the val set changes each of these calls from

recur(ax,ay,az-1,bx,by,bz-1,m+1)

to

recur(ax,ay,az-1,bx,by,bz-1,m+1,PA,PB,CP,gab,ss_m)

reducing readability substantially.

As a simple analogy, it is much cleaner to write f(n) = c*f(n-1) + d*f(n-2) than f(n,c,d) = c*f(n-1,c,d) + d*f(n-2,c,d).

I was hoping that the performance hit from accessing val through a closure rather than passing it explicitly would be small. Unfortunately, the closure version is 2x slower and allocates 2x the memory.
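One way to keep the short call sites from the analogy without a closure is a callable struct (a "functor"), a technique not mentioned in the thread: the constants become fields of an object, and the method is defined on the object itself. A minimal sketch in Julia 1.x syntax; the name Recurrence and the base cases f(0) = 0, f(1) = 1 are assumptions for illustration:

```julia
# Hypothetical callable struct holding the constants c and d.
struct Recurrence
    c::Float64
    d::Float64
end

# f(n) = c*f(n-1) + d*f(n-2); the recursive calls stay as short as f(n-1).
function (f::Recurrence)(n::Int)
    n == 0 && return 0.0   # assumed base case
    n == 1 && return 1.0   # assumed base case
    return f.c * f(n - 1) + f.d * f(n - 2)
end

fib = Recurrence(1.0, 1.0)   # c = d = 1 gives the Fibonacci numbers
println(fib(10))             # 55.0
```

In the full code, the struct would presumably hold the val set (PA, PB, CP, gab, ss_m) and the method would take the seven a variables, so the 36 call sites keep their original length.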


#10

Why can't you define a type to hold your five variables?

type RecurType
  PA
  PB
  ...
end
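A caveat on the type above: fields declared without a type annotation have type Any, so the compiler cannot infer what PA, PB, ... contain. That may account for part of any remaining gap to the fully explicit version. A minimal comparison in Julia 1.x syntax (where type is spelled mutable struct and immutable is spelled struct); all names and the Float64 field types are assumptions:

```julia
# Hypothetical containers: the same two fields, untyped vs. concretely typed.
struct UntypedParams
    PA            # no annotation: the field type is Any
    PB
end

struct TypedParams
    PA::Float64   # assumed type; use the real field types in practice
    PB::Float64
end

# Same toy recursion for both; only the parameter container differs.
recur_params(a::Int, p) = a > 0 ? recur_params(a - 1, p) : p.PA * p.PB

println(fieldtypes(UntypedParams))
println(Base.return_types(recur_params, (Int, UntypedParams)))
println(Base.return_types(recur_params, (Int, TypedParams)))
```

If the field types vary between uses, a parametric definition such as `struct TypedParams{T}` with fields `PA::T` keeps them concrete without hard-coding Float64.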

#11

That works. It is not quite as fast as the version that explicitly passes everything individually, but cleans up the code quite a bit. It might be the best compromise with the current version of Julia.


#12

EDIT: I read that val is constant, so maybe you should make your type immutable:

immutable RecurType
  PA
  PB
  ...
end

#13

Why do you say that? If you mean because there might be lots of info to pass around, you can wrap it in a State type.


#14

Because it defeats the purpose of closures. One should not need to define a new type whenever one wants a fast closure that captures some variables from its environment.
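For what it's worth, a closure can avoid the boxed self-reference without a new type by receiving itself as an argument instead of referring to its own name. A minimal sketch in Julia 1.x with hypothetical names; it has not been benchmarked against the versions in this thread, and the extra self argument gives back some of the readability discussed above:

```julia
function outerfun_selfpass(n::Int)
    val = 1.0   # stands in for the expensive precomputed value
    # The closure captures only `val` (never reassigned, so no Core.Box) and
    # calls itself through the `self` argument rather than its own name.
    recur = (self, a) -> a > 0 ? self(self, a - 1) : val
    return recur(recur, n)
end

println(outerfun_selfpass(2))
```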


#15

That makes sense: it seems to improve performance by 5% in the full code, without affecting memory allocations. I'm not sure what causes the gain, though, since the @code_warntype outputs for the type and immutable versions are identical.

It would be nice if closures could be used for this without performance penalty, instead of having to pass a state variable down to the inner nested function.
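On the 5% question: @code_warntype reports inferred types, not memory layout or mutability. One plausible factor, not verified for this code, is that an immutable lets the compiler assume its fields never change, and an immutable whose fields are all concretely typed plain data is an isbits type that can be stored inline rather than as a separate heap object. A quick check in Julia 1.x with hypothetical names; note that with untyped fields, as in the thread, neither version would be isbits:

```julia
# Hypothetical parameter containers; only mutability differs.
mutable struct ParamsMutable
    PA::Float64
    PB::Float64
end

struct ParamsImmutable
    PA::Float64
    PB::Float64
end

println(isbitstype(ParamsMutable))    # mutable structs are never isbits
println(isbitstype(ParamsImmutable))  # concrete plain-data fields: isbits
```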