How to work with type-unstable functions?

The question is: suppose I have a function that, by the nature of the problem, has to return heterogeneous results. Is there an approach or a set of design ideas for building a workflow around such functions that is still performant?

As a very simplified example of the problem, consider finding the roots of a quadratic equation. A second-degree polynomial is defined by three coefficients, a, b and c:

f(x) = a x^2 + b x + c

The equation f(x) = 0 can have 0, 1 or 2 roots. There is also a special degenerate case, a = 0, which can have 0 or 1 root; since the function behaves differently there than in the quadratic case (it changes sign around the root), we treat it as a separate case. So, all in all, there are 4 different outputs.
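Written out explicitly, with D = b^2 - 4ac the discriminant, the cases are as follows. `solve` below maps the D < 0 row and both a = b = 0 rows to `NoRoots`, and the remaining three rows to `SingleLinearRoot`, `SingleRoot` and `TwoRoots` respectively.

```latex
D = b^2 - 4ac, \qquad
\{\, x \in \mathbb{R} : a x^2 + b x + c = 0 \,\} =
\begin{cases}
\varnothing & a = 0,\ b = 0,\ c \neq 0 \\
\mathbb{R} & a = 0,\ b = 0,\ c = 0 \\
\{-c/b\} & a = 0,\ b \neq 0 \\
\varnothing & a \neq 0,\ D < 0 \\
\{-b/(2a)\} & a \neq 0,\ D = 0 \\
\{(-b + \sqrt{D})/(2a),\ (-b - \sqrt{D})/(2a)\} & a \neq 0,\ D > 0
\end{cases}
```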

To make the example slightly more complete, suppose we want to count the roots over a whole set of polynomials:

```julia
using StableRNGs

rng = StableRNG(2021)
polynomials = map(_ -> Tuple(rand(rng, 3)), 1:1000);
```

One approach I have seen uses Julia's dispatch system, along these lines:

```julia
abstract type AbstractRoots end
struct NoRoots <: AbstractRoots end
struct SingleRoot <: AbstractRoots
    x1::Float64
end
struct SingleLinearRoot <: AbstractRoots
    x1::Float64
end
struct TwoRoots <: AbstractRoots
    x1::Float64
    x2::Float64
end

function solve(poly)
    a, b, c = poly
    if a == 0
        if b == 0
            return NoRoots() # Infinitely many roots if c == 0, but 4 types are more than enough already
        else
            return SingleLinearRoot(-c/b)
        end
    end

    D = b^2 - 4*a*c
    if D < 0
        return NoRoots()
    elseif D == 0
        return SingleRoot(-b/(2*a))
    else
        sD = sqrt(D)
        return TwoRoots((-b + sD)/(2*a), (-b - sD)/(2*a))
    end
end
```

In this approach, Julia's type system acts as a kind of event dispatch: one function (`solve`) emits values of various types, and some other function on the receiving end catches the result and processes it accordingly. Multiple dispatch is the internal tool that does the actual dispatching.

```julia
cnt(::NoRoots) = 0
cnt(::TwoRoots) = 2
cnt(::AbstractRoots) = 1

sum(cnt, solve.(polynomials)) # 546
```
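A self-contained sketch of the emit-and-catch pattern described above. It repeats the post's types and `solve` so that it runs on its own, adds a second handler to show that a new consumer is just a new set of methods (`rootlist` is a hypothetical name, not from the post), and uses hand-picked coefficients: `rand` draws Float64 values from [0, 1), so the random `polynomials` practically never reach the a == 0 or D == 0 branches.

```julia
# Result types from the post: one concrete struct per kind of outcome
abstract type AbstractRoots end
struct NoRoots <: AbstractRoots end
struct SingleRoot <: AbstractRoots
    x1::Float64
end
struct SingleLinearRoot <: AbstractRoots
    x1::Float64
end
struct TwoRoots <: AbstractRoots
    x1::Float64
    x2::Float64
end

# The emitter: returns a different concrete type depending on the branch
function solve(poly)
    a, b, c = poly
    if a == 0
        # Infinitely many roots if also c == 0; folded into NoRoots as in the post
        return b == 0 ? NoRoots() : SingleLinearRoot(-c / b)
    end
    D = b^2 - 4 * a * c
    if D < 0
        return NoRoots()
    elseif D == 0
        return SingleRoot(-b / (2 * a))
    else
        sD = sqrt(D)
        return TwoRoots((-b + sD) / (2 * a), (-b - sD) / (2 * a))
    end
end

# First handler, as in the post: number of roots
cnt(::NoRoots) = 0
cnt(::TwoRoots) = 2
cnt(::AbstractRoots) = 1

# Second, hypothetical handler: the roots themselves as a vector
rootlist(::NoRoots) = Float64[]
rootlist(r::Union{SingleRoot, SingleLinearRoot}) = [r.x1]
rootlist(r::TwoRoots) = [r.x1, r.x2]

# Hand-picked coefficients (a, b, c) that exercise every branch
testpolys = [(0.0, 0.0, 1.0),   # a = b = 0: no roots
             (0.0, 2.0, -4.0),  # linear: x = 2
             (1.0, 0.0, 1.0),   # D < 0: no real roots
             (1.0, -2.0, 1.0),  # D = 0: x = 1
             (1.0, 0.0, -4.0)]  # D > 0: x = 2 and x = -2

results = solve.(testpolys)                      # element type is AbstractRoots
println(sum(cnt, results))                       # 4
println(reduce(vcat, rootlist.(results)))        # [2.0, 1.0, 2.0, -2.0]
```

The convenience comes from the handlers being independent of `solve`; note, though, that `results` ends up with the abstract element type `AbstractRoots`.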

It looks nice at first glance, but to be honest, I consider it an antipattern, since it inherently introduces type instability:

```julia
julia> using BenchmarkTools

julia> @btime sum(cnt, solve.($polynomials))
  49.333 μs (1051 allocations: 33.25 KiB)

julia> @code_warntype solve(polynomials[1])
MethodInstance for solve(::Tuple{Float64, Float64, Float64})
  from solve(poly) in Main at REPL[68]:1
Arguments
  #self#::Core.Const(solve)
  poly::Tuple{Float64, Float64, Float64}
Locals
  @_3::Int64
  sD::Float64
  D::Float64
  c::Float64
  b::Float64
  a::Float64
Body::Any
```
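The same diagnosis can be made programmatically, for example in a test suite, with `Base.return_types`, which returns the inferred return types of a function for the given argument types. A minimal sketch on a hypothetical two-outcome function `sqrtroots` (the real roots of x^2 = c), chosen only to keep the block short; it is not the post's `solve`.

```julia
abstract type AbstractRoots end
struct NoRoots <: AbstractRoots end
struct TwoRoots <: AbstractRoots
    x1::Float64
    x2::Float64
end

# Hypothetical toy emitter: real roots of x^2 = c as one of two struct types
# (c == 0 reports the double root twice, which is fine for a toy)
sqrtroots(c::Float64) = c < 0 ? NoRoots() : TwoRoots(sqrt(c), -sqrt(c))

# Inferred return type for a Float64 argument; `only` unwraps the one-element vector
rt = only(Base.return_types(sqrtroots, (Float64,)))
println(rt)                  # a Union of the two structs
println(isconcretetype(rt))  # false: callers cannot know the result type statically
```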

But what alternative do we have? The only one I can think of is introducing some sort of flag system, for example:

```julia
@enum Status NO_ROOTS SINGLE_ROOT SINGLE_LINE_ROOT TWO_ROOTS

function solve2(poly)
    a, b, c = poly
    if a == 0
        if b == 0
            return NO_ROOTS, (0., 0.)
        else
            return SINGLE_LINE_ROOT, (-c/b, 0.)
        end
    end

    D = b^2 - 4*a*c
    if D < 0
        return NO_ROOTS, (0., 0.)
    elseif D == 0
        return SINGLE_ROOT, (-b/(2*a), 0.)
    else
        sD = sqrt(D)
        return TWO_ROOTS, ((-b + sD)/(2*a), (-b - sD)/(2*a))
    end
end

function cntroots(poly)
    status, points = solve2(poly)
    return status == NO_ROOTS ? 0 : status == TWO_ROOTS ? 2 : 1
end
```

This is much better performance-wise:

```julia
julia> @btime sum(cntroots, $polynomials)
  4.438 μs (0 allocations: 0 bytes)
```
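A quick way to see why the flag version allocates nothing: every branch of `solve2` returns the same tuple type, so inference finds a single concrete return type. This sketch repeats `solve2` (with the a == 0 branch compacted into a ternary, same behavior) and checks it with `Base.return_types`.

```julia
@enum Status NO_ROOTS SINGLE_ROOT SINGLE_LINE_ROOT TWO_ROOTS

# Flag-style solver: always returns (Status, (Float64, Float64)), padding unused slots
function solve2(poly)
    a, b, c = poly
    if a == 0
        return b == 0 ? (NO_ROOTS, (0.0, 0.0)) : (SINGLE_LINE_ROOT, (-c / b, 0.0))
    end
    D = b^2 - 4 * a * c
    if D < 0
        return NO_ROOTS, (0.0, 0.0)
    elseif D == 0
        return SINGLE_ROOT, (-b / (2 * a), 0.0)
    else
        sD = sqrt(D)
        return TWO_ROOTS, ((-b + sD) / (2 * a), (-b - sD) / (2 * a))
    end
end

# Inferred return type for a tuple of three Float64 coefficients
rt = only(Base.return_types(solve2, (NTuple{3, Float64},)))
println(rt)                  # a concrete Tuple type
println(isconcretetype(rt))  # true
```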

But this approach has its drawbacks:

  1. The flag system looks rather inflexible. There is no type checking; users have to keep track of all the names manually and handle every option themselves.
  2. Building the return value is awkward, because one has to produce a useless (0., 0.) in the NO_ROOTS branch. It doesn't look like much here, but in the general case generating such "empty" results can have its own overhead. What if we are solving a cubic equation? A quartic? Producing (0., 0., 0.) just to match the general case looks superfluous and can introduce subtle bugs of its own.

So, what options do we have? Abusing the dispatch system or using a flag system: is there anything beyond these two choices? Keep in mind that polynomial roots are only a toy example; real-life cases can be much more complicated, so I'm more interested in general guidance and approaches to this kind of problem.


Well, after some consideration, it looks like there is a third option. If we know beforehand that the intermediate values are of little interest in themselves and will only be processed further to get some result, we can do that processing in place by passing the handler into the solver:

```julia
function solve3(f, poly)
    a, b, c = poly
    if a == 0
        if b == 0
            return f(NoRoots()) # Infinitely many roots if c == 0, but 4 types are more than enough already
        else
            return f(SingleLinearRoot(-c/b))
        end
    end

    D = b^2 - 4*a*c
    if D < 0
        return f(NoRoots())
    elseif D == 0
        return f(SingleRoot(-b/(2*a)))
    else
        sD = sqrt(D)
        return f(TwoRoots((-b + sD)/(2*a), (-b - sD)/(2*a)))
    end
end
```

This small change uses the multiple dispatch system without abusing it (I guess?). It is even faster than the flag system, probably because Julia can inline the handler calls and optimize the resulting code:

```julia
julia> @btime sum(x -> solve3(cnt, x), $polynomials)
  1.554 μs (0 allocations: 0 bytes)
```
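A self-contained sketch of this third approach, repeating the post's types and `solve3`. It checks that the inferred return type of `solve3(cnt, poly)` is concrete, which is consistent with the zero-allocation benchmark, and shows that the handler does not have to be a set of named methods: `largest` is a hypothetical handler (not from the post) passed as a do-block.

```julia
abstract type AbstractRoots end
struct NoRoots <: AbstractRoots end
struct SingleRoot <: AbstractRoots
    x1::Float64
end
struct SingleLinearRoot <: AbstractRoots
    x1::Float64
end
struct TwoRoots <: AbstractRoots
    x1::Float64
    x2::Float64
end

# Continuation-passing version: the handler f is applied inside each branch,
# so every call f(...) sees one concrete argument type
function solve3(f, poly)
    a, b, c = poly
    if a == 0
        return b == 0 ? f(NoRoots()) : f(SingleLinearRoot(-c / b))
    end
    D = b^2 - 4 * a * c
    if D < 0
        return f(NoRoots())
    elseif D == 0
        return f(SingleRoot(-b / (2 * a)))
    else
        sD = sqrt(D)
        return f(TwoRoots((-b + sD) / (2 * a), (-b - sD) / (2 * a)))
    end
end

cnt(::NoRoots) = 0
cnt(::TwoRoots) = 2
cnt(::AbstractRoots) = 1

# Every branch returns an Int, so the inferred return type is concrete
rt = only(Base.return_types(solve3, (typeof(cnt), NTuple{3, Float64})))
println(rt)  # Int64 on a 64-bit machine

# Hypothetical handler as a do-block: the largest real root, or NaN if none
largest = solve3((1.0, 0.0, -4.0)) do r
    r isa NoRoots ? NaN :
    r isa TwoRoots ? max(r.x1, r.x2) : r.x1
end
println(largest)  # 2.0
```

The design trade-off is that the result never escapes `solve3` as a heterogeneous value; any caller that wants to store the roots has to do so inside the handler.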

I can't see any problems with this approach. Am I missing something, or is this an idiomatic way to handle the issue?


There is another way that is not as fast as the 2nd and 3rd ones, but is still faster than the first version: branch on the type explicitly with `isa`.

```julia
function numroots(sol)
    if sol isa NoRoots
        0
    elseif sol isa SingleRoot
        1
    elseif sol isa SingleLinearRoot
        1
    elseif sol isa TwoRoots
        2
    else
        0
    end
end
```

```julia
@btime sum(numroots.(solve.($polynomials)))  # 6.675 μs (643 allocations: 28.00 KiB)
@btime sum(x -> numroots(solve(x)), $polynomials)  # 11.400 μs (638 allocations: 19.94 KiB)

@btime sum(cnt, solve.($polynomials))  # 44.900 μs (1787 allocations: 56.20 KiB)
@btime sum(x -> solve3(cnt, x), $polynomials)  # 1.360 μs (0 allocations: 0 bytes)
```
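The `isa` chain in `numroots` works like manual union splitting: inside each branch the compiler knows the concrete type of `sol`. A hedged sketch of the same idea that keeps the `cnt` methods as the single source of truth; `cnt_split` is a hypothetical name, and its speed relative to `numroots` has not been measured here.

```julia
abstract type AbstractRoots end
struct NoRoots <: AbstractRoots end
struct SingleRoot <: AbstractRoots
    x1::Float64
end
struct SingleLinearRoot <: AbstractRoots
    x1::Float64
end
struct TwoRoots <: AbstractRoots
    x1::Float64
    x2::Float64
end

cnt(::NoRoots) = 0
cnt(::TwoRoots) = 2
cnt(::AbstractRoots) = 1

# Manual union splitting: after each `isa` check, `sol` has a known concrete
# type, so the cnt(sol) call in that branch can be resolved at compile time
function cnt_split(sol::AbstractRoots)
    if sol isa NoRoots
        return cnt(sol)
    elseif sol isa SingleRoot
        return cnt(sol)
    elseif sol isa SingleLinearRoot
        return cnt(sol)
    elseif sol isa TwoRoots
        return cnt(sol)
    else
        return cnt(sol)  # fallback: dynamic dispatch for any other subtype
    end
end

# A container with an abstract element type, like the result of solve.(polynomials)
sols = AbstractRoots[NoRoots(), SingleRoot(1.0), SingleLinearRoot(2.0), TwoRoots(2.0, -2.0)]
println(sum(cnt_split, sols))  # 4
```

The final `else` keeps the function correct for subtypes added later, at the price of falling back to dynamic dispatch for them.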