# Type stability for piecewise defined functions with expensive parts

The function

```julia
function f(x)
    if x < 0
        return 0
    end
    return expensive_function(x)
end
```

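is not type-stable (unless `expensive_function` happens to return an `Int`): it returns the `Int` `0` when `x < 0` and whatever `expensive_function` returns otherwise. Is there a clean way to make it type-stable and generic while staying fast for `x < 0`? The most straightforward thing I can see would be to replace `return 0` with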

```julia
return zero(expensive_function(x))
```

but as far as I can tell this still evaluates `expensive_function(x)`, which defeats the purpose.
I saw in another question that it is not easy to get the return type of a function without evaluating it. I don't actually need the type in my code; I just want `f` to be type-stable and fast for `x < 0`. Is this possible, assuming `expensive_function(x)` itself is type-stable?
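To see the instability concretely, here is a minimal sketch using a hypothetical cheap stand-in, `expensive_function(x) = sqrt(x)` (an assumption; the real function is not shown in the thread). `Base.return_types` is a reflection helper that reports what type inference concludes for given argument types, so its exact output may differ between Julia versions:

```julia
# Hypothetical stand-in for the expensive part; for Float64 input it returns Float64.
expensive_function(x) = sqrt(x)

function f(x)
    if x < 0
        return 0                      # an Int64 literal
    end
    return expensive_function(x)      # a Float64 for Float64 input
end

# Inference has to merge both branches, so for Float64 input the
# inferred return type is a small Union of Int64 and Float64.
inferred = Base.return_types(f, (Float64,))
println(inferred)
```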

What happens if you return `nothing` instead of zero?


+1: return `nothing`, and the union-splitting optimization should kick in (see "Union-splitting: what it is, and why you should care").
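A minimal sketch of how a caller can consume such a `nothing`-returning function; `g` and `total` are hypothetical names, and `sqrt` again stands in for the expensive part. The `y === nothing` check lets inference narrow `y` to `Float64` on the remaining path:

```julia
# Hypothetical variant that signals "no value" with `nothing`.
# For Float64 input, inference gives Union{Nothing, Float64}.
g(x) = x < 0 ? nothing : sqrt(x)

function total(xs)
    s = 0.0
    for x in xs
        y = g(x)
        # After this check, inference knows `y` is a Float64,
        # so the addition below is on concrete types only.
        y === nothing && continue
        s += y
    end
    return s
end

total([-1.0, 4.0, 9.0])
```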


Interesting, thanks for the answers! What I had in mind, though, was `f` being used in a numerical context, e.g.

```julia
sum(f(x) for x in large_array)
```

where returning `nothing` makes `sum` throw a `MethodError`, since there is no `+` method for `Nothing`. Does the union-splitting optimization also help if the return type is, e.g., `Union{Int64, Float64}`?

That helps, but isn't the result type of `expensive_function(x)` in that case deducible from the type of `x`? If so, you could use:

```julia
function f(x)
    if x < 0
        return zero(typeof(x))
    end
    return expensive_function(x)
end
```

Example:

```julia
julia> expensive_function(x) = sum(x .+ [1,2,3,4]) # may return Int or Float
       function f(x::T) where T
           if x < 0
               return zero(T)
           else
               return expensive_function(x)
           end
       end
f (generic function with 1 method)

julia> @code_warntype f(1)
Variables
  #self#::Core.Const(f)
  x::Int64

Body::Int64
1 ─ %1 = (x < 0)::Bool
└──      goto #3 if not %1
2 ─ %3 = Main.zero($(Expr(:static_parameter, 1)))::Core.Const(0)
└──      return %3
3 ─ %5 = Main.expensive_function(x)::Int64
└──      return %5

julia> @code_warntype f(1.0)
Variables
  #self#::Core.Const(f)
  x::Float64

Body::Float64
1 ─ %1 = (x < 0)::Bool
└──      goto #3 if not %1
2 ─ %3 = Main.zero($(Expr(:static_parameter, 1)))::Core.Const(0.0)
└──      return %3
3 ─ %5 = Main.expensive_function(x)::Float64
└──      return %5
```
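One caveat with `zero(T)`: it only yields a stable result when `expensive_function` returns the same type it receives. A minimal sketch with a hypothetical stand-in, `sqrt`, which maps an `Int` input to a `Float64` output (again, `Base.return_types` reports what inference concludes and may vary between Julia versions):

```julia
# Hypothetical stand-in: sqrt of an Int returns a Float64.
expensive_function(x) = sqrt(x)

function f(x::T) where T
    if x < 0
        return zero(T)                # zero of the *input* type
    else
        return expensive_function(x)
    end
end

# Float64 in, Float64 out: both branches agree, so this is stable.
Base.return_types(f, (Float64,))
# Int in, Float64 out: one branch gives Int64, the other Float64.
Base.return_types(f, (Int,))
```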

Here are some benchmarks:

```julia
julia> using BenchmarkTools

julia> u_union = Union{Int, Float64}[rand(Bool) ? rand() : rand(1:10) for i in 1:10^6];

julia> u_real = identity.(u_union);

julia> u_float = float.(u_union);

julia> @show typeof(u_union) typeof(u_real) typeof(u_float);
typeof(u_union) = Vector{Union{Float64, Int64}}
typeof(u_real) = Vector{Real}
typeof(u_float) = Vector{Float64}

julia> @btime sum(u_union); @btime sum(u_real); @btime sum(u_float);
3.771 ms (1 allocation: 16 bytes)
18.310 ms (999503 allocations: 15.25 MiB)
125.891 μs (1 allocation: 16 bytes)
```

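So a small `Union` element type helps a lot compared with an abstract one like `Real` (3.8 ms vs. 18.3 ms here, without the per-element allocations), but for an operation that can be optimized as aggressively as `sum`, it still falls far behind a concrete element type (about 126 μs).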
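If the caller knows which concrete type it wants, one option (a sketch, not something proposed in the thread) is to convert each value at the call site, so that `sum` again reduces over plain `Float64` values even though `f` itself stays unstable. `expensive_function` is a hypothetical stand-in:

```julia
expensive_function(x) = sqrt(x)           # hypothetical stand-in for the expensive part
f(x) = x < 0 ? 0 : expensive_function(x)  # unstable: returns Int or Float64

xs = [-1.0, 4.0, 9.0]
# `float` turns the Int 0 into 0.0 and leaves Float64 values unchanged,
# so the reduction itself only ever adds Float64 values.
result = sum(x -> float(f(x)), xs)
```

This does not make `f` type-stable; it only confines the `Union` to each call of `f`, which is cheap for a two-element union.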

If you do return `nothing`, you can skip those values with a filter clause in the generator:

```julia
julia> x = [1, 2, nothing, 3]
4-element Vector{Union{Nothing, Int64}}:
1
2
nothing
3

julia> sum(y for y in x if y !== nothing)
6
```

But if you really need to get a zero of the same type as `expensive_function` then you need some way to calculate the resulting element type yourself.
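One way to do that without evaluating anything is to write the type rule down by hand next to `expensive_function`. A minimal sketch, assuming (hypothetically) that the expensive part behaves like `sqrt`, whose result type for a real input of type `T` is `float(T)`; `result_type` is a made-up helper name:

```julia
# Hypothetical expensive part; for a Real x it returns a value of type float(typeof(x)).
expensive_function(x::Real) = sqrt(x)

# Hand-written result-type rule. Assumption: it must be kept in sync
# with expensive_function manually; nothing checks it automatically.
result_type(::Type{T}) where {T<:Real} = float(T)

function f(x::T) where {T<:Real}
    x < 0 && return zero(result_type(T))  # expensive_function is not called here
    return expensive_function(x)
end

f(-3), f(4), f(-1.0f0)
```

The design trades automation for predictability: `f` is stable for `Int`, `Float64`, and `Float32` inputs, but a change to `expensive_function` that alters its output type would silently reintroduce the instability.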


Thanks for the benchmark!

> But if you really need to get a zero of the same type as `expensive_function` then you need some way to calculate the resulting element type yourself.

Could I do this with a generated function? Something like:

```julia
@generated function f(x)
    expensive_type = typeof(expensive_function(zero(x)))
    quote
        if x < 0
            return zero($expensive_type)
        end
        return expensive_function(x)
    end
end
```

Or is this bound to break in some way?
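For reference, inside the body of a `@generated` function the argument names are bound to the argument *types*, not their values, which is why `zero(x)` in the generator in the question produces a zero of the input type. A tiny sketch (the name `argtype` is hypothetical):

```julia
# Inside the generator, `x` is the argument's type (e.g. Int64), not its value.
@generated function argtype(x)
    # The generator returns an expression; QuoteNode makes it evaluate to the type object.
    return QuoteNode(x)
end

argtype(1), argtype(1.0)
```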

That is not a good idea. See the docs (Metaprogramming · The Julia Language) for some things you shouldn't do.

> See the docs (Metaprogramming · The Julia Language) for some things you shouldn't do.

I did read them before posting, because I was suspicious as well and I know generated functions should be used sparingly and with caution, but I couldn't pin down any concrete problem with this code.

I checked whether it observes any global mutable state, but it only calls other functions, and regarding function calls the docs only list “Calling any function that is defined after the body of the generated function” as forbidden, which suggests to me that calling other, previously defined functions is OK.

So as far as I can see, the code does nothing that the docs forbid. Or am I missing something?
