Make Julia complete the inference of some recursive code

Background

I have a function with a real number input coming from a finite interval. To implement the function, I break down that interval into several subintervals and implement the function on each of them. Suppose, for example, the top-level function is f and depends on four other functions, each for a different subinterval: f0, f1, f2 and f3. Suppose we have the interval [0.1, 0.9], and the subintervals [0.1, 0.3], [0.3, 0.5], [0.5, 0.7], [0.7, 0.9]. To implement the top-level function efficiently, it’s common to use a hard-coded binary search like so:

function f(x)
  if x < 0.5
    if x < 0.3
      f0(x)
    else
      f1(x)
    end
  else
    if x < 0.7
      f2(x)
    else
      f3(x)
    end
  end
end

While this is OK with a small subinterval count, it becomes ugly and difficult to manage with a larger number of subintervals. An alternative to hardcoding the binary search is constructing a binary search tree variant explicitly. Taking inspiration from range trees, here’s a nice recursive construction:

  • each leaf stores one subinterval

  • each inner node stores two subtrees, each representing an interval, and the boundary between the two intervals

I want performance that’s as good as with the hardcoded binary search, so my idea was to put the search tree into the type domain and then recur on it to perform the search. This is my code:

module IntervalRangeTrees

struct Leaf{T}
  l::T
  r::T
end

struct NonLeaf{S,T}
  l::S
  r::S
  b::T
end

is_leaf(::NonLeaf) = false
is_leaf(::Leaf   ) = true

function recur(
  f::F, p::P, tree::Val{Tree},
) where {F, P, Tree}
  if is_leaf(Tree)
    f(tree)
  else
    if p(Tree.b)
      recur(f, p, Val(Tree.l))
    else
      recur(f, p, Val(Tree.r))
    end
  end
end

function construct_impl(
  t::(Tuple{T,T,Vararg{T}} where {T}),
  ::Val{I},
  ::Val{I},
) where {I}
  I::Int
  l = t[begin + I]
  r = t[begin + I + 1]
  Leaf(l, r)
end

struct Exc <: Exception end

Base.@assume_effects :foldable function construct_impl(
  t::(Tuple{T,T,Vararg{T}} where {T}),
  ::Val{I},
  ::Val{J},
) where {I, J}
  I::Int
  J::Int
  (false ≤ I < J < (length(t) - 1)) || throw(Exc())
  n = J - I + 1
  ispow2(n) || throw(Exc())
  m = I + (n >>> true)
  l = construct_impl(t, Val(I), Val(m - 1))
  r = construct_impl(t, Val(m), Val(I + n - 1))
  b = t[begin + m]
  NonLeaf(l, r, b)
end

function construct(t::(Tuple{T,T,Vararg{T}} where {T}))
  construct_impl(t, Val(0), Val(length(t) - 2))
end

end

The construct function constructs a search tree and recur performs the search. As a usage example, here’s a reimplementation of the function f above:

g(x, ::Val{(0.1, 0.3)}) = f0(x)
g(x, ::Val{(0.3, 0.5)}) = f1(x)
g(x, ::Val{(0.5, 0.7)}) = f2(x)
g(x, ::Val{(0.7, 0.9)}) = f3(x)

function f(x)
  h = let x = x
    function(::Val{T}) where {T}
      g(x, Val{(T.l, T.r)}())
    end
  end

  p = let x = x
    y -> x < y
  end

  tree = IntervalRangeTrees.construct((0.1, 0.3, 0.5, 0.7, 0.9))
  tree_v = Val{tree}()
  IntervalRangeTrees.recur(h, p, tree_v)
end
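
For reference, here is a minimal sketch of the tree that construct should return for the five boundary points above, written out by hand. The structs are stand-ins that mirror Leaf and NonLeaf from the module, and locate is a hypothetical helper (not part of my code) that does the same descent as recur, just at run time:

```julia
# Stand-ins mirroring the structs in IntervalRangeTrees above.
struct Leaf{T}
  l::T
  r::T
end

struct NonLeaf{S,T}
  l::S
  r::S
  b::T
end

# Hand-built equivalent of construct((0.1, 0.3, 0.5, 0.7, 0.9)):
# the root splits at 0.5, its children split at 0.3 and 0.7.
tree = NonLeaf(
  NonLeaf(Leaf(0.1, 0.3), Leaf(0.3, 0.5), 0.3),
  NonLeaf(Leaf(0.5, 0.7), Leaf(0.7, 0.9), 0.7),
  0.5,
)

# Hypothetical run-time descent: go left when x is below the boundary.
locate(x, t::Leaf) = (t.l, t.r)
locate(x, t::NonLeaf) = x < t.b ? locate(x, t.l) : locate(x, t.r)

# The whole tree is a plain bits value, which is what allows Val{tree}().
println(isbits(tree))
println(locate(0.2, tree))
```

The point of moving this into the type domain is that every branch of that descent is known at compile time, so ideally nothing of the tree survives at run time.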

Problem: Julia gives up on inference

julia> f0(x) = 1*x
f0 (generic function with 1 method)

julia> f1(x) = 10*x
f1 (generic function with 1 method)

julia> f2(x) = 100*x
f2 (generic function with 1 method)

julia> f3(x) = 1000*x
f3 (generic function with 1 method)

julia> using Test

julia> @inferred f(0.2)
ERROR: return type Float64 does not match inferred return type Any

JET.jl reports that Julia “failed to optimize due to recursion”:

julia> using JET

julia> @report_opt f(0.2)
═════ 2 possible errors found ═════
┌ f(x::Float64) @ Main ./REPL[6]:14
│┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:24
││┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:17
│││ failed to optimize due to recursion: Main.IntervalRangeTrees.recur([...])
││└────────────────────
│┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:26
││┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:17
│││ failed to optimize due to recursion: Main.IntervalRangeTrees.recur([...])
││└────────────────────

JET’s docs say that the Julia compiler gives up on optimization when (ref):

there are (mutually) recursive calls and Julia compiler decided not to do inference in order to make sure the inference’s termination. In such a case, optimization won’t happen and method dispatches aren’t resolved statically

Relevant Julia discussion (@aplavin), but no responses as of yet: Allow more aggressive inference for some functions · JuliaLang/julia · Discussion #52242 · GitHub

IMO it’s quite silly that Julia fails in recur here: the code seems simple and the recursion is shallow (its depth should be known at compile time, I guess). I’ll open an issue on GitHub later, but is it possible to adjust my code so that it is performant and infers perfectly on current versions of Julia?


I tried some obvious quick fixes (@inline and Base.@constprop :aggressive) but these did not work.

The next thing that comes to my mind is a generated function, which works well and is also just a smallish change in the code:

@generated function recur(f, p, tree::Val{Tree}) where Tree
    function _recur_expr(tree)
        if is_leaf(tree)
            return :(f($(Val(tree))))
        else
            return quote
                if p($(tree.b))
                    $(_recur_expr(tree.l))
                else
                    $(_recur_expr(tree.r))
                end
            end
        end
    end
    _recur_expr(Tree)
end

I think it is reasonable to use a generated function for this. After all, the purpose of all this type-domain computing is in the end to (reliably!) generate an efficient function.

I defined _recur_expr within the generated function because otherwise it didn’t work. Per the manual, it should work to define it before recur, which is how my source file was laid out, but Revise.jl likely messed that up. That’s why I put it inside recur.
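
To see what the generated function boils down to, the expression builder can be run at top level on a hand-built tree. This is only a sketch: the structs are stand-ins for the ones in the module, and recur_expr and expanded are hypothetical names used here for illustration.

```julia
# Stand-ins mirroring the structs in IntervalRangeTrees.
struct Leaf{T}
  l::T
  r::T
end

struct NonLeaf{S,T}
  l::S
  r::S
  b::T
end

is_leaf(::NonLeaf) = false
is_leaf(::Leaf) = true

# Same logic as _recur_expr, but as an ordinary top-level function.
function recur_expr(tree)
  if is_leaf(tree)
    return :(f($(Val(tree))))
  else
    return quote
      if p($(tree.b))
        $(recur_expr(tree.l))
      else
        $(recur_expr(tree.r))
      end
    end
  end
end

tree = NonLeaf(
  NonLeaf(Leaf(0.1, 0.3), Leaf(0.3, 0.5), 0.3),
  NonLeaf(Leaf(0.5, 0.7), Leaf(0.7, 0.9), 0.7),
  0.5,
)

# Print the nested if/else, with line-number nodes stripped for readability.
ex = Base.remove_linenums!(recur_expr(tree))
println(ex)

# Turn the expression into a regular function of f and p.
@eval expanded(f, p) = $ex
println(expanded(identity, y -> 0.2 < y))
```

The printed expression has the same nested if/else shape as the hardcoded binary search from the first post, with the boundaries spliced in as literals.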

Side remark: You do a lot of wrapping/unwrapping in Val types which personally I find a bit unnecessary and messy. You could use the Intervals, Leafs and NonLeafs directly as types without wrapping into Val. Maybe this could shorten compile times in more complicated examples as well?
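
A rough sketch of one way to read that remark, passing the tree as an ordinary value and letting dispatch do the descent. Everything here (FnLeaf, Split, search, s0 to s3, f_dispatch) is a hypothetical name, and storing the function in the leaf is my own simplification, not something from your code. I'd expect this to infer because the argument type gets smaller with every recursive call, but treat that as an assumption and check it with @inferred on your Julia version.

```julia
# Hypothetical node types: each leaf carries the function for its subinterval.
struct FnLeaf{F}
  f::F
end

struct Split{L,R,T}
  l::L
  r::R
  b::T  # boundary between the left and right subtree
end

# Dispatch-based descent: one method per node kind, no Val wrapping.
search(x, t::FnLeaf) = t.f(x)
search(x, t::Split) = x < t.b ? search(x, t.l) : search(x, t.r)

# Stand-ins for f0, f1, f2 and f3 from the question.
s0(x) = 1 * x
s1(x) = 10 * x
s2(x) = 100 * x
s3(x) = 1000 * x

# Same shape as the tree for (0.1, 0.3, 0.5, 0.7, 0.9).
const TREE = Split(
  Split(FnLeaf(s0), FnLeaf(s1), 0.3),
  Split(FnLeaf(s2), FnLeaf(s3), 0.7),
  0.5,
)

f_dispatch(x) = search(x, TREE)

println(f_dispatch(0.2))
```

Unlike the Val version, the boundaries are ordinary field values here, so whether the compiler folds them into constants is up to constant propagation rather than guaranteed by the types.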


Might be related to RFC: Less aggressive recursion limiting by Keno · Pull Request #48059 · JuliaLang/julia · GitHub. I’m not sure, because the graph stuff in there still goes over my head. That might mean I should’ve studied Inference Convergence Algorithm in Julia (juliahub.com) a bit harder a while back, when I wondered how self-recursive and mutually recursive methods with seemingly any call graph cycle size were inferred at all. Even if I haven’t misconnected the dots entirely, I don’t know whether the linked pull request would alter the heuristic enough to make this particular example of recursively multiplying call signatures work.
