Background
I have a function whose real-valued input comes from a finite interval. To implement the function, I split that interval into several subintervals and implement the function on each of them. Suppose, for example, that the top-level function is f and that it depends on four other functions, one per subinterval: f0, f1, f2 and f3. Suppose the interval is $[0.1, 0.9]$, with the subintervals $[0.1, 0.3]$, $[0.3, 0.5]$, $[0.5, 0.7]$ and $[0.7, 0.9]$. To implement the top-level function efficiently, it’s common to use a hard-coded binary search like so:
function f(x)
    if x < 0.5
        if x < 0.3
            f0(x)
        else
            f1(x)
        end
    else
        if x < 0.7
            f2(x)
        else
            f3(x)
        end
    end
end
While this is OK with a small number of subintervals, it becomes ugly and hard to maintain as the number grows. An alternative to hard-coding the binary search is to construct a binary search tree variant explicitly. Taking inspiration from range trees, here’s a nice recursive construction:
- each leaf stores one subinterval
- each inner node stores two subtrees, each representing an interval, and the boundary between the two intervals
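For comparison, the construction just described can be sketched with ordinary runtime values. This is a minimal version assuming Float64 endpoints; IntervalLeaf, IntervalNode, build_tree and find_leaf are hypothetical names. Leaves are numbered from 0 and each inner node splits its range of leaves in half, so it accepts any number of subintervals, not only powers of two.

```julia
# Hypothetical runtime (value-level) version of the tree described above.
struct IntervalLeaf
    l::Float64   # left endpoint of the subinterval
    r::Float64   # right endpoint of the subinterval
end

struct IntervalNode
    left::Union{IntervalLeaf,IntervalNode}
    right::Union{IntervalLeaf,IntervalNode}
    b::Float64   # boundary: `left` covers x < b, `right` covers x >= b
end

# Build the subtree for leaves i..j (0-based) over the sorted endpoints t.
function build_tree(t, i, j)
    i == j && return IntervalLeaf(t[i + 1], t[i + 2])
    m = i + (j - i + 1) ÷ 2   # first leaf of the right half
    return IntervalNode(build_tree(t, i, m - 1), build_tree(t, m, j), t[m + 1])
end
build_tree(t) = build_tree(t, 0, length(t) - 2)

# Walk from the root down to the leaf whose subinterval contains x.
function find_leaf(tree, x)
    node = tree
    while node isa IntervalNode
        node = x < node.b ? node.left : node.right
    end
    return (node.l, node.r)
end

tree = build_tree((0.1, 0.3, 0.5, 0.7, 0.9))
find_leaf(tree, 0.2)   # returns (0.1, 0.3)
find_leaf(tree, 0.5)   # returns (0.5, 0.7)
```

Because the child fields are Union-typed and the search is a runtime loop, this version is not expected to match the hard-coded if/else; it only shows the shape of the tree.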
I want performance that’s as good as with the hard-coded binary search, so my idea was to put the search tree into the type domain and then recurse on it to perform the search. This is my code:
module IntervalRangeTrees

struct Leaf{T}
    l::T
    r::T
end

struct NonLeaf{S,T}
    l::S
    r::S
    b::T
end

is_leaf(::NonLeaf) = false
is_leaf(::Leaf ) = true

function recur(
    f::F, p::P, tree::Val{Tree},
) where {F, P, Tree}
    if is_leaf(Tree)
        f(tree)
    else
        if p(Tree.b)
            recur(f, p, Val(Tree.l))
        else
            recur(f, p, Val(Tree.r))
        end
    end
end

function construct_impl(
    t::(Tuple{T,T,Vararg{T}} where {T}),
    ::Val{I},
    ::Val{I},
) where {I}
    I::Int
    l = t[begin + I]
    r = t[begin + I + 1]
    Leaf(l, r)
end

struct Exc <: Exception end

Base.@assume_effects :foldable function construct_impl(
    t::(Tuple{T,T,Vararg{T}} where {T}),
    ::Val{I},
    ::Val{J},
) where {I, J}
    I::Int
    J::Int
    (false ≤ I < J < (length(t) - 1)) || throw(Exc())
    n = J - I + 1
    ispow2(n) || throw(Exc())
    m = I + (n >>> true)
    l = construct_impl(t, Val(I), Val(m - 1))
    r = construct_impl(t, Val(m), Val(I + n - 1))
    b = t[begin + m]
    NonLeaf(l, r, b)
end

function construct(t::(Tuple{T,T,Vararg{T}} where {T}))
    construct_impl(t, Val(0), Val(length(t) - 2))
end

end
The construct function constructs a search tree and recur performs the search. As a usage example, here’s a reimplementation of the function f above:
g(x, ::Val{(0.1, 0.3)}) = f0(x)
g(x, ::Val{(0.3, 0.5)}) = f1(x)
g(x, ::Val{(0.5, 0.7)}) = f2(x)
g(x, ::Val{(0.7, 0.9)}) = f3(x)

function f(x)
    h = let x = x
        function(::Val{T}) where {T}
            g(x, Val{(T.l, T.r)}())
        end
    end

    p = let x = x
        y -> x < y
    end

    tree = IntervalRangeTrees.construct((0.1, 0.3, 0.5, 0.7, 0.9))
    tree_v = Val{tree}()
    IntervalRangeTrees.recur(h, p, tree_v)
end
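For the endpoints used in f, construct calls construct_impl(t, Val(0), Val(3)); with n = 4 leaves it splits at m = 2, so the root boundary is t[begin + 2] = 0.5, and each half splits once more, at 0.3 and 0.7. The block below writes that result out by hand (the two struct definitions are copied from the module so the block runs on its own) and checks the property the approach relies on: the tree is an isbits value, so it can be a type parameter of Val and its fields can be read back from the type, as recur does with Tree.b, Tree.l and Tree.r. The helper unwrap is hypothetical.

```julia
# Copies of the two structs from IntervalRangeTrees, so this block is standalone.
struct Leaf{T}
    l::T
    r::T
end

struct NonLeaf{S,T}
    l::S
    r::S
    b::T
end

# Hand-written value that construct((0.1, 0.3, 0.5, 0.7, 0.9)) should produce:
# root boundary 0.5, inner boundaries 0.3 and 0.7.
tree = NonLeaf(
    NonLeaf(Leaf(0.1, 0.3), Leaf(0.3, 0.5), 0.3),
    NonLeaf(Leaf(0.5, 0.7), Leaf(0.7, 0.9), 0.7),
    0.5,
)

# Hypothetical helper: recover the value stored in a Val type parameter.
unwrap(::Val{T}) where {T} = T

typeof(tree)            # == NonLeaf{NonLeaf{Leaf{Float64},Float64},Float64}
isbits(tree)            # true, so Val(tree) is allowed
unwrap(Val(tree)).r.l   # == Leaf(0.5, 0.7)
```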
Problem: Julia gives up on inference
julia> f0(x) = 1*x
f0 (generic function with 1 method)
julia> f1(x) = 10*x
f1 (generic function with 1 method)
julia> f2(x) = 100*x
f2 (generic function with 1 method)
julia> f3(x) = 1000*x
f3 (generic function with 1 method)
julia> using Test
julia> @inferred f(0.2)
ERROR: return type Float64 does not match inferred return type Any
JET.jl reports that Julia “failed to optimize due to recursion”:
julia> using JET
julia> @report_opt f(0.2)
═════ 2 possible errors found ═════
┌ f(x::Float64) @ Main ./REPL[6]:14
│┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:24
││┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:17
│││ failed to optimize due to recursion: Main.IntervalRangeTrees.recur([...])
││└────────────────────
│┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:26
││┌ recur([...]) @ Main.IntervalRangeTrees ./REPL[1]:17
│││ failed to optimize due to recursion: Main.IntervalRangeTrees.recur([...])
││└────────────────────
JET’s docs say that the Julia compiler gives up on optimization when (ref):
there are (mutually) recursive calls and Julia compiler decided not to do inference in order to make sure the inference’s termination. In such a case, optimization won’t happen and method dispatches aren’t resolved statically
Relevant Julia discussion (@aplavin), but no responses as of yet: Allow more aggressive inference for some functions · JuliaLang/julia · Discussion #52242 · GitHub
IMO it’s quite silly that Julia fails on recur here: the code seems simple, and the recursion is shallow (its depth should be known at compile time, I guess). I’ll open an issue on GitHub later, but is it possible to adjust my code so that it is performant and infers perfectly on current versions of Julia?
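For comparison, here is a sketch of a different technique (not the type-level tree above): a @generated function that writes out the nested if/else of the hand-coded f at compile time, so the specialized method body contains no recursive call for inference to give up on. The names bsearch_expr, piecewise and fgen are hypothetical, and the sketch assumes the endpoints are passed as a tuple of numbers wrapped in Val, with exactly one function per subinterval.

```julia
# Sketch only: `bsearch_expr`, `piecewise` and `fgen` are hypothetical names.

# Return an expression that binary-searches subintervals lo..hi (1-based)
# and calls the matching function, nesting conditionals like the hand-written f.
function bsearch_expr(bounds, lo, hi)
    lo == hi && return :(fs[$lo](x))
    mid = (lo + hi) ÷ 2   # left half lo..mid, right half mid+1..hi
    b = bounds[mid + 1]   # boundary between subintervals mid and mid + 1
    return :(x < $b ? $(bsearch_expr(bounds, lo, mid)) : $(bsearch_expr(bounds, mid + 1, hi)))
end

# `bounds` holds the n + 1 sorted endpoints; `fs` holds one function per subinterval.
# Inside the generator, `fs` is the tuple's type and `bounds` is the tuple value;
# the returned expression is plain nested conditionals with no recursive call.
@generated function piecewise(x, fs::Tuple, ::Val{bounds}) where {bounds}
    n = length(bounds) - 1
    if n < 1 || fieldcount(fs) != n
        return :(throw(ArgumentError("expected one function per subinterval")))
    end
    return bsearch_expr(bounds, 1, n)
end

# Usage mirroring the example in the question.
f0(x) = 1 * x
f1(x) = 10 * x
f2(x) = 100 * x
f3(x) = 1000 * x
fgen(x) = piecewise(x, (f0, f1, f2, f3), Val((0.1, 0.3, 0.5, 0.7, 0.9)))
```

The usual generated-function caveats apply: the generator must be pure and may only call functions defined before the generated function, which is why bsearch_expr comes first.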