Type stability for piecewise defined functions with expensive parts

The function

function f(x)
    if x < 0
        return 0
    end
    return expensive_function(x)
end

is not type-stable (in most cases). Is there a clean way to make it stable and generic while staying performant for x < 0? The most straightforward thing I see would be to replace return 0 with

return zero(expensive_function(x))

but this evaluates expensive_function(x) even when x < 0, as far as I can tell.
I saw in this question that it is not easy to get the return type of a function without evaluating it. I don’t need to use the type in the code, though; I just want f to be type-stable and performant for x < 0. Is this possible, assuming expensive_function(x) is type-stable?

What happens if you return nothing instead of zero?

+1: return nothing, and the union-splitting optimization should kick in (see “Union-splitting: what it is, and why you should care”).

Interesting, thanks for the answers! What I had in mind though was f being used in a numerical context, e.g. like

sum( f(x) for x in large_array )

where returning nothing raises an error. Does the union-splitting optimization also help if the return type is e.g. Union{Int64, Float64}?
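For reference, a minimal sketch of the mixed-return case being asked about. It assumes sqrt as a hypothetical stand-in for expensive_function (not from the original post), so one branch returns an Int and the other a Float64. It only shows that the sum itself goes through; it says nothing about speed.

```julia
# Hypothetical stand-in for the expensive part; returns a Float64 for Float64 input.
expensive_function(x) = sqrt(x)

# One branch returns the Int literal 0, the other a Float64,
# so the return type for Float64 input is a small Union.
function f(x)
    if x < 0
        return 0
    end
    return expensive_function(x)
end

large_array = [-1.0, 4.0, 9.0]

# Int and Float64 promote under +, so the reduction does not error
# the way it would with `nothing`.
total = sum(f(x) for x in large_array)
```

Unlike the nothing case, no filtering is needed here because + is defined between the two branch types.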

That helps, but isn’t the result type of expensive_function(x) deducible from the type of x in that case? If so, you could use:

function f(x)
    if x < 0
        return zero(typeof(x))
    end
    return expensive_function(x)
end

Example:

julia> expensive_function(x) = sum(x .+ [1,2,3,4]) # may return Int or Float
       function f(x::T) where T
           if x < 0
               return zero(T)
           else
               return expensive_function(x)
           end
       end
f (generic function with 1 method)

julia> @code_warntype f(1)
Variables
  #self#::Core.Const(f)
  x::Int64

Body::Int64
1 ─ %1 = (x < 0)::Bool
└──      goto #3 if not %1
2 ─ %3 = Main.zero($(Expr(:static_parameter, 1)))::Core.Const(0)
└──      return %3
3 ─ %5 = Main.expensive_function(x)::Int64
└──      return %5

julia> @code_warntype f(1.0)
Variables
  #self#::Core.Const(f)
  x::Float64

Body::Float64
1 ─ %1 = (x < 0)::Bool
└──      goto #3 if not %1
2 ─ %3 = Main.zero($(Expr(:static_parameter, 1)))::Core.Const(0.0)
└──      return %3
3 ─ %5 = Main.expensive_function(x)::Float64
└──      return %5


Here are some benchmarks:

julia> using BenchmarkTools

julia> u_union = Union{Int, Float64}[rand(Bool) ? rand() : rand(1:10) for i in 1:10^6];

julia> u_real = identity.(u_union);

julia> u_float = float.(u_union);

julia> @show typeof(u_union) typeof(u_real) typeof(u_float);
typeof(u_union) = Vector{Union{Float64, Int64}}
typeof(u_real) = Vector{Real}
typeof(u_float) = Vector{Float64}

julia> @btime sum(u_union); @btime sum(u_real); @btime sum(u_float);
  3.771 ms (1 allocation: 16 bytes)
  18.310 ms (999503 allocations: 15.25 MiB)
  125.891 μs (1 allocation: 16 bytes)

So a small union element type helps compared to an abstract one, but for an operation that can be optimized as heavily as sum, it still falls far behind a concrete element type.

You can use e.g. filter syntax:

julia> x = [1, 2, nothing, 3]
4-element Vector{Union{Nothing, Int64}}:
 1
 2
  nothing
 3

julia> sum(v for v in x if v !== nothing)
6

But if you really need to get a zero of the same type as expensive_function then you need some way to calculate the resulting element type yourself.
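One way to "calculate the type yourself" is a small hand-written rule that maps the input type to the result type. The sketch below assumes the expensive_function from the example above; result_type is a hypothetical helper name, and the rule promote_type(T, Int) is an assumption that holds for this particular function because it adds x to an Int array. A different expensive_function needs its own rule.

```julia
expensive_function(x) = sum(x .+ [1, 2, 3, 4])

# Hypothetical helper: hand-written result-type rule for this specific
# expensive_function (adding a T to Ints promotes to promote_type(T, Int)).
result_type(::Type{T}) where {T<:Number} = promote_type(T, Int)

function f(x::T) where {T<:Number}
    if x < 0
        # Zero of the type the expensive branch would return,
        # obtained without evaluating the expensive branch.
        return zero(result_type(T))
    end
    return expensive_function(x)
end
```

With this, f(Int32(-1)) returns an Int64 zero, which matches what expensive_function(Int32(1)) returns, whereas zero(T) would give an Int32.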

Thanks for the benchmark!

But if you really need to get a zero of the same type as expensive_function then you need some way to calculate the resulting element type yourself.

Could I do this with a generated function? Something like:

@generated function f(x)
    expensive_type = typeof(expensive_function(zero(x)))
    quote 
        if x < 0 
            return zero($expensive_type) 
        end 
        return expensive_function(x)
    end 
end

Or is this bound to break in some way?

That is not a good idea. See the docs (Metaprogramming · The Julia Language) for some things you shouldn’t do.

See the docs (Metaprogramming · The Julia Language) for some things you shouldn’t do.

I did that before posting because I was indeed suspicious and I know generated functions should be used sparingly and with caution, but I couldn’t put my finger on any concrete problem with the code.

I considered whether it observes any global mutable state, but it just calls some functions, and the only related restriction the docs mention is “Calling any function that is defined after the body of the generated function”, which to me sounds like calling other functions is OK.

So as far as I can see, the code does nothing that the docs forbid. Or am I missing something?
