I am using generated functions recursively, and I have put together a contrived example that shows how unnecessary allocations can occur. First, here is the efficient case:
@generated function Factorial(::Val{N}) where N
    N == 1 ? :(1) : :(N * Factorial(Val($(N-1))))
end
julia> @btime Factorial(Val(10))
0.001 ns (0 allocations: 0 bytes)
3628800
The function above calculates the factorial in the classical recursive manner. It works, and the compiler is able to optimise the calculation away entirely and return a constant.
If my understanding is correct, the compiler is smart enough to ‘unroll’ each level of the generated functions, leaving an expression which is just the factorial spelled out literally. The compiler can then additionally optimise the products of literals to return a constant. Great!
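To check the "unrolled to a constant" picture, one option is to inspect the optimised code directly. The following is a minimal sketch: `fact_gen` is a hypothetical stand-in for the `Factorial` above, renamed so that it can be run in a fresh session. What the printed IR looks like depends on the Julia version, so no output is shown here.

```julia
using InteractiveUtils  # provides @code_typed outside the REPL

# Hypothetical copy of the recursive generated function from the post.
@generated function fact_gen(::Val{N}) where N
    # This body runs when the method is generated; it returns an expression.
    N == 1 ? :(1) : :(N * fact_gen(Val($(N - 1))))
end

# Typed, optimised IR for one concrete call. If the recursion was folded
# away, the body should be little more than a `return` of a literal.
ir = @code_typed fact_gen(Val(5))
println(ir)
```

If the body still contains calls to `fact_gen`, the compiler did not fold the recursion for that call.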
Now here is the problem:
forwardingfunction(a::Val{N}) where N = Factorial(a)
@generated function Factorial(::Val{N}) where N
    N == 1 ? :(1) : :(N * forwardingfunction(Val($(N-1))))
end
julia> @btime Factorial(Val(10))
268.189 ns (5 allocations: 80 bytes)
3628800
Note that we have replaced the recursive call to Factorial with a call to forwardingfunction. However, the latter just forwards its argument to Factorial, so the functionality is no different from before.
With forwardingfunction inserted into Factorial, the function is much slower (and allocates!). This suggests that the compiler is resorting to runtime dispatch. However, forwardingfunction does not require any runtime information (I think) and should therefore be compiled away.
Why is the compiler not eager enough to perform the same sort of optimisations as before?
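One way to test the runtime-dispatch suspicion is to look at what type inference concludes for the call. This is only a sketch: `fact_fwd` and `forward_to_fact` are hypothetical renamings of `Factorial` and `forwardingfunction`, and the result can differ between Julia versions.

```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

# Hypothetical copies of the forwarding variant from the post.
forward_to_fact(a::Val{N}) where N = fact_fwd(a)

@generated function fact_fwd(::Val{N}) where N
    N == 1 ? :(1) : :(N * forward_to_fact(Val($(N - 1))))
end

# Prints the inferred types for each statement. A non-concrete type
# such as `Any` on the call to `forward_to_fact` would point to
# inference giving up and the call being dispatched at runtime.
@code_warntype fact_fwd(Val(10))

# The inferred return type(s) for the same signature.
println(Base.return_types(fact_fwd, (Val{10},)))
```

Running the same two checks on the non-forwarding version gives a direct comparison.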
As far as I understand it (which is incompletely), you are mixing compile time and runtime here. I started with what is maybe a more natural example of mutual recursion at compile time:
function Odd(::Val{N}) where N end
@generated function Even(::Val{N}) where N
    N == 0 || Odd(Val(N - 1))
end
@generated function Odd(::Val{N}) where N
    N == 1 || Even(Val(N - 1))
end
@btime Even(Val(10))
@btime Odd(Val(10))
which, adapted to your forwarding requirement, would look like:
function Factorial(::Val{N}) where N end
@generated function Forward(::Val{N}) where N
    :(Factorial(Val(N)))
end
@generated function Factorial(::Val{N}) where N
    N == 1 ? 1 : (N * Forward(Val(N-1)))
end
@btime Factorial(Val(10))
Flatten.jl consists entirely of nested generated functions, and Accessors.jl has some too. We run into similar problems, which have stalled PRs that rely heavily on generated recursion, such as this one: https://github.com/JuliaObjects/Accessors.jl/pull/23.
We were hoping 1.7 might fix some of the type stability problems, but it has actually made them worse, as you suggest.
What is your reasoning behind turning Forward into a @generated function? I can’t tell the difference between your implementation and mine. If my understanding is correct, you are instantiating a Val each time you run that function, whereas in my implementation I am just forwarding the argument.
The same as your reasoning for making Factorial a @generated function in the first place: shifting computation from runtime to compile time, which it obviously does:
0.001 ns (0 allocations: 0 bytes)
3628800
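To make "shifting computation from runtime to compile time" concrete, here is a stripped-down sketch that is not the code from either post. The hypothetical `fact_at_gen` does all of its arithmetic inside the generator, so the generated method body is only a literal. It avoids recursion altogether by using `prod`, which is a simplification swapped in for illustration.

```julia
# Hypothetical example: the arithmetic happens while the method body
# is being generated, not when the method is later called.
@generated function fact_at_gen(::Val{N}) where N
    value = prod(1:N)   # ordinary Julia code, run at generation time
    return :($value)    # the generated body is just this literal
end

println(fact_at_gen(Val(10)))
```

By contrast, anything placed inside the returned quoted expression (such as a call written as `:(f(Val(N)))`) is code that runs when the method is called, unless the compiler manages to fold it afterwards.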
Edit: there seems to be an alternative route, which could be more appropriate for you:
using BenchmarkTools
Base.@pure function Forward(a::Val{N}) where N
    Factorial(a)
end
@generated function Factorial(::Val{N}) where N
    N == 1 ? 1 : (N * Forward(Val(N-1)))
end
@btime Factorial(Val(10))
The difference between Factorial and Forward is that Factorial performs computations. I am not performing any computations in the Forward function; I am just forwarding the argument, and the compiler does not require any runtime information to compile that.
What computation exactly are you performing in your version of Forward? If I understand correctly, everything within the expression you return is evaluated at runtime.