Getting the type of `f(x)` at compile time (without evaluating `f(x)`), for type-stable functions

I have a type-stable function `f(x)` that is expensive to evaluate. Given an argument `x`, how do I obtain the type of the result `f(x)` without actually evaluating `f(x)`?

Ideally the whole query would be resolved during compilation, with no runtime computation at all, since the type of `f(x)` is already known at compile time.

```julia
julia> Core.Compiler.return_type(gcd, Tuple{Int,Int})
Int64
```
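A minimal sketch of the same call in the question's setting, assuming a hypothetical expensive but type-stable function `expensive` (not from the thread): pass the function and a `Tuple` of argument types built with `typeof(x)`, so `expensive` itself is never run.

```julia
# Hypothetical expensive but type-stable function (illustration only).
function expensive(x::Real)
    acc = zero(float(x))   # accumulator has the floating-point type of x
    for k in 1:10^8
        acc += x / k       # for Int and the standard float types, x / k matches acc's type
    end
    return acc
end

# Ask inference for the result type; `expensive` is not called here.
result_type(f, x) = Core.Compiler.return_type(f, Tuple{typeof(x)})

result_type(expensive, 1)      # Float64 (inferred, not computed)
result_type(expensive, 1.0f0)  # Float32
```

The loop bound is only there to make a real call noticeably slow; both queries above return immediately because only inference runs.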

This should generally behave the way that you want, but be aware that inference is free to return `Any`, to return a wider type than the actual one, or to give different results in different contexts or Julia versions. If there’s any other way to do what you want, it’s probably a better idea to use it.
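One hedged way to respect that caveat, assuming you only need the type as an optimization hint: trust the inferred type only when it is concrete, and otherwise fall back to evaluating `f(x)`. The names `result_type_or_eval` and `maybe_int` are hypothetical.

```julia
# Use inference as a shortcut, but never rely on it for correctness.
function result_type_or_eval(f, x)
    T = Core.Compiler.return_type(f, Tuple{typeof(x)})
    # If inference reports a concrete type, any normally returned f(x) has exactly that type.
    isconcretetype(T) && return T
    # Inference gave Any, an abstract type, a Union, or Union{}: evaluate to find out.
    return typeof(f(x))
end

# Hypothetical type-unstable function, used to exercise the fallback path.
maybe_int(x) = x > 0 ? 1 : "negative"

result_type_or_eval(sqrt, 4.0)      # Float64, from inference alone
result_type_or_eval(maybe_int, -1)  # String, found by actually calling maybe_int(-1)
```

The trade-off is that the fallback path pays for one evaluation, so this only saves work when inference succeeds.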

```julia
julia> f(x) = x
f (generic function with 1 method)

julia> @code_typed(f(1)).second
Int64

julia> @code_typed(f(1.)).second
Float64
```

Perhaps?

I second this. These “other ways” are mostly documented only in the source of Julia and its standard libraries, though. A more specific question would make it possible to suggest the right one.

E.g., if you are collecting things, I would recommend looking at the functions that `collect` calls, or at a functional approach I experimented with.
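The idea behind the `collect`/`map` machinery can be sketched roughly as follows: instead of predicting the element type, evaluate the first element, allocate a container of its type, and widen the container only if a later result does not fit. This is a simplified illustration, not Base's actual code: `collect_widening` is a hypothetical name, it uses plain `typejoin`, and it assumes a non-empty input (Base handles the empty case separately).

```julia
function collect_widening(f, xs)
    isempty(xs) && throw(ArgumentError("this sketch needs a non-empty collection"))
    y1 = f(first(xs))
    out = [y1]                        # Vector{typeof(y1)}, no inference needed
    for x in Iterators.drop(xs, 1)
        y = f(x)
        if !(y isa eltype(out))
            # Widen: copy into a vector whose element type covers both kinds of value.
            T = typejoin(eltype(out), typeof(y))
            out = Vector{T}(out)
        end
        push!(out, y)
    end
    return out
end

collect_widening(abs, [-1, 2, -3])       # Vector{Int64} with elements 1, 2, 3
collect_widening(identity, Any[1, 2.5])  # Vector{Real}, widened at the second element
```

When `f` is type-stable, the widening branch never runs, so the cost of this approach is essentially one `isa` check per element.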

@Tamas_Papp Specifically I am looking at this line of code:

https://github.com/JuliaStats/StatsFuns.jl/blob/ae40bc1c6ee63a5624ba6cd168da5bf43f473e59/src/basicfuns.jl#L66

In the last branch, `oftype(exp(-x), x)`, I would like to avoid calling `exp` but still get the correct return type.

Generally, in a case like this I would just write `x/1`.
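Applied to that branch, the suggestion might look like this sketch. `log1pexp_div` is a hypothetical name and the thresholds are copied from the variant posted later in this thread; the point is only that `x / 1` yields the floating-point type of `x` for the usual input types without evaluating `exp` in the last branch.

```julia
# Last branch uses x / 1 instead of oftype(exp(-x), x), so exp is not evaluated there.
log1pexp_div(x::Real) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : x / 1

log1pexp_div(40)      # 40.0 (Float64), the same type as exp(-40)
log1pexp_div(40.0f0)  # 40.0f0 (Float32), the same type as exp(-40.0f0)
```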

How can you be sure this always has the same type as `exp(x)`? Keep in mind that anyone can define new methods of `exp` for arbitrary types `<: Real`.
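The concern is not limited to user-defined types; Base's own `Rational` already shows a mismatch. In this small check (no packages assumed), `exp` of a `Rational` goes through floating point, while `x / 1` stays rational, so `oftype(exp(-x), x)` and `x / 1` disagree.

```julia
x = 40 // 1                 # Rational{Int64}

typeof(exp(-x))             # Float64: exp converts the rational to floating point
typeof(oftype(exp(-x), x))  # Float64: what the original last branch returns
typeof(x / 1)               # Rational{Int64}: dividing a Rational by an integer stays rational
typeof(float(x))            # Float64
```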

This concept does not exist in Julia.

`Core.Compiler.return_type(gcd, Tuple{Int,Int})` is the only function that is useful for this, in the very few cases where you are allowed to use it. Don’t use `code_typed`.


An alternative solution to the original problem (https://github.com/JuliaStats/StatsFuns.jl/blob/ae40bc1c6ee63a5624ba6cd168da5bf43f473e59/src/basicfuns.jl#L66) is to define:

```julia
log1pexp(x::AbstractFloat) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : x
log1pexp(x::Real) = log1pexp(float(x))
```

But this overflows the stack when passed dual numbers:

```julia
julia> ForwardDiff.derivative(log1pexp, 1.)
ERROR: StackOverflowError:
Stacktrace:
 [1] log1pexp(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(log1pexp),Float64},Float64,1}) at ./REPL[1]:1 (repeats 80000 times)
```
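The recursion happens because a dual number is a `Real` but not an `AbstractFloat`, and, as the repeated signature in the stack trace suggests, `float` hands back the same dual type, so `log1pexp(::Real)` keeps calling itself. One possible restructuring separates the conversion from the computation, so the implementation is reached exactly once whatever `float` returns. To stay package-free, `ToyDual` below is a hypothetical stand-in that mimics only that property; it is not ForwardDiff's type, and `log1pexp_fixed`/`_log1pexp` are hypothetical names.

```julia
# Hypothetical stand-in for a dual number: a Real that is not an AbstractFloat,
# and for which float(x) returns x itself.
struct ToyDual <: Real
    v::Float64
end
Base.float(x::ToyDual) = x
Base.exp(x::ToyDual) = ToyDual(exp(x.v))
Base.log1p(x::ToyDual) = ToyDual(log1p(x.v))
Base.:-(x::ToyDual) = ToyDual(-x.v)
Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.v + b.v)
Base.:<(a::ToyDual, b::Float64) = a.v < b

# Convert once, then run an unrestricted implementation, so no method re-dispatches to itself.
log1pexp_fixed(x::Real) = _log1pexp(float(x))
_log1pexp(x) = x < 18.0 ? log1p(exp(x)) : x < 33.3 ? x + exp(-x) : x

log1pexp_fixed(40)             # 40.0
log1pexp_fixed(ToyDual(40.0))  # a ToyDual, where the two-method version would recurse forever
```

Whether this yields correct derivatives for real `ForwardDiff.Dual` values depends on ForwardDiff's own rules for `exp`, `log1p`, and comparisons; the sketch only shows that the dispatch no longer loops.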

In general, you can’t be sure. But it’s reasonable to expect that `exp(x)`, `x/1` and `float(x)` will have the same type, and if someone defines `exp` to return a string or something, then I’d say that’s on them.

(In any case, the worst that can happen here is a small `Union` return type, which is no longer that big of a deal for performance anyway.)
