Simple and fast bisection

Hi, as the title says, I’m trying to build a simple and fast bisection algorithm. I know the Roots.jl package already implements it, but I’m doing it both for pedagogical purposes and because I plan to extend it with other algorithms later (I like having full control over what my code can and cannot do).

The approach is very simple. A Point object stores both coordinates of a point (x and f(x)). At each iteration, one of the two endpoints is replaced by the midpoint, depending on the sign of f there. The function interval!() does most of the work.

I tried to strip the code down to its bare minimum:

mutable struct Point
    x::Real
    y::Real
end

function interval!(f::Function, pl::Point, ph::Point)

    xm = (pl.x + ph.x)/2
    ym = f(xm)

    # Replace the endpoint whose f-value has the same sign as f(xm)
    if ym < 0
        (pl.x, pl.y) = (xm, ym)
    else
        (ph.x, ph.y) = (xm, ym)
    end
    return (xm, ym)
end

function root_1d(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    pl = Point(x1, f(x1))
    ph = Point(x2, f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Make pl the endpoint where f is negative (it may then lie to the right of ph)
    pl.y > 0 && ((pl, ph) = (ph, pl))

    xm = 0
    ym = 0
    iter = 0
    stop = false
    while !stop
        iter += 1

        (xm, ym) = interval!(f, pl, ph)

        # Convergence check (abs because pl is not necessarily the left endpoint)
        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(ym) <= f_tol || iter >= max_iter
            stop = true
            iter >= max_iter && @warn("root_1d did not converge within max_iter")
        end
    end
    return xm, ym
end

Now, one thing I noticed is that the interval!() function is not type stable. Running

f(x) = x^3
pl = Point(-1, -1)
ph = Point(2, 8)

@code_warntype interval!(f, pl, ph)

tells me the result is inferred as Tuple{Any,Any}. As far as I understand, this is because the compiler cannot infer the type of the sum of two Reals (???), which seems a bit weird.
I’m not sure how to fix this. I tried making Point a parametric type, but then I cannot update it in place, because the type of its fields might change from one iteration to the next. I also thought that a mutable Point would help reduce memory allocation.
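Here is a quick check of why the field types can change. With the midpoint formula used in interval!(), two Int endpoints produce a Float64, while two Rational endpoints stay Rational. The helper name midpoint is hypothetical and used only for illustration:

```julia
# Hypothetical helper: the same midpoint formula as in interval!()
midpoint(a, b) = (a + b) / 2

# `/` on two Ints always returns a floating-point number,
# so an Int-typed field would have to hold a Float64 after one step
println(typeof(midpoint(-1, 2)))        # Float64

# Rational arithmetic stays Rational
println(typeof(midpoint(-1//1, 2//1)))  # Rational{Int64} on a 64-bit machine
println(midpoint(-1//1, 2//1))          # 1//2
```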

Any help is much appreciated.

Your main problem is

mutable struct Point
    x::Real
    y::Real
end

Here, any Real can be stored in a Point, so the compiler cannot make any assumption about the concrete type of its fields. Note that Real is an abstract type, not a concrete one like Float64. Instead, you should parameterize the type so that the compiler knows the concrete field types when it compiles code for a given call:

struct Point{T}
    x::T
    y::T
end
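You can check the difference directly with fieldtype and isconcretetype from Base. AbstractFieldPoint and ParamPoint are hypothetical names, used only so that both definitions can coexist in one session:

```julia
# Fields declared with an abstract type, as in the original mutable Point
mutable struct AbstractFieldPoint
    x::Real
    y::Real
end

# Parametric version, as suggested above
struct ParamPoint{T}
    x::T
    y::T
end

# Real is abstract: the compiler cannot know what is stored in x or what `+` returns
println(isconcretetype(fieldtype(AbstractFieldPoint, :x)))   # false

# Once T is fixed, the field type is concrete
println(fieldtype(ParamPoint{Float64}, :x))                  # Float64
println(isconcretetype(fieldtype(ParamPoint{Float64}, :x)))  # true

# The type parameter is inferred from the constructor arguments
println(typeof(ParamPoint(1.0, 2.0)))                        # ParamPoint{Float64}
```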

There are also some other tweaks to make sure things are type stable and to use immutable values instead of mutable ones (which are typically easier for the compiler to optimize).

The Performance Tips section of the Julia manual would be a good read.
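One concrete consequence of preferring immutable values: an immutable struct with concrete fields is a plain bits type, which typically lets the compiler avoid heap-allocating it. A mutable struct is never a bits type. Here is a quick check with isbitstype. ImmPoint and MutPoint are hypothetical names:

```julia
struct ImmPoint{T}
    x::T
    y::T
end

mutable struct MutPoint{T}
    x::T
    y::T
end

println(isbitstype(ImmPoint{Float64}))  # true: immutable with concrete bits fields
println(isbitstype(MutPoint{Float64}))  # false: mutable structs are never bits types
```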

I took some liberties and rewrote the code slightly:

struct Point{T}
    x::T
    y::T
end

function interval(f::Function, pl::Point, ph::Point)
    xm = (pl.x + ph.x)/2
    ym = f(xm)
    m = Point(xm, ym)

    # Return the updated (pl, ph) pair plus the midpoint itself
    if ym < 0
        return m, ph, m
    else
        return pl, m, m
    end
end

function root_1d(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    x1, x2 = float(x1), float(x2)
    pl = Point(x1, f(x1))
    ph = Point(x2, f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Make pl the endpoint where f is negative (it may then lie to the right of ph)
    pl.y > 0 && ((pl, ph) = (ph, pl))

    iter = 0
    while true
        iter += 1
        iter >= max_iter && error("root_1d did not converge within max_iter")

        pl, ph, m = interval(f, pl, ph)

        # abs because pl is not necessarily the left endpoint
        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol || iter >= max_iter
            return m
        end
    end
end

and this should be fast (@btime comes from the BenchmarkTools.jl package):

julia> @btime root_1d(x-> x^3 - 5, -4, 2)
  185.212 ns (1 allocation: 32 bytes)
Point{Float64}(1.709975242614746, -6.176066356999854e-6)

Thanks a lot.
I suspected that was the problem, but I thought a mutable struct would be better for memory allocation than creating new immutables at each iteration. I guess I was wrong.
I made a few more changes. In particular, I changed Point to:

struct Point{T1<:Real,T2<:Real}
    x::T1
    y::T2
end

which allows x and y to have different types. I also removed the x1, x2 = float(x1), float(x2) line. It now seems to be slightly faster.
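With two type parameters, each field gets its own concrete type. The sketch below mirrors the definition above, renamed to the hypothetical Point2 so it does not clash with other Point definitions:

```julia
# Same definition as above, renamed (hypothetical) to avoid clashing with Point
struct Point2{T1<:Real,T2<:Real}
    x::T1
    y::T2
end

p = Point2(1, 2.5)                      # x is an Int, y is a Float64
println(typeof(p.x), " ", typeof(p.y))  # Int64 Float64 on a 64-bit machine

# The <:Real bounds still reject non-numeric fields
try
    Point2("one", 2.5)
catch err
    println(err isa MethodError)        # true
end
```

Every individual Point2 has concrete fields. However, which Point2 you get still depends on the inputs. An Int endpoint gives an Int x, while every midpoint has a Float64 x, and that difference is what surfaces as a Union in the return type.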

My only concern is that @code_warntype now reports a Union return type. In particular, with these changes I get

julia> @code_warntype root_1d(x -> x^3 - 5, -4,2)
Body::Union{Point{Float64,Float64}, Point{Int64,Int64}}
5 1 ─ %1 = invoke Main.:(#root_1d#3)(1.0e-6::Float64, 1.0e-6::Float64, 3000::Int64, _1::Function, _2::getfield(Main, Symbol("##4#5")), _3::Int64, _4::Int64)::Union{Point{Float64,Float64}, Point{Int64,Int64}}
  └──      return %1

Is this a problem? Did I understand correctly that Union types such as this one are no longer a problem in 0.7?

Finally, I have one question about my earlier code, which can be condensed as follows. I understand the compiler cannot make any assumption about the concrete type of Point’s field. What I don’t understand is why I get Any when summing such Reals. In particular:

mutable struct Point
    x::Real
end

f(p1,p2) = p1.x + p2.x

pl = Point(-1)
ph = Point(10)
julia> @code_warntype f(pl,ph)
1 1 ─ %1 = (Base.getfield)(p1, :x)::Real                                               │╻ getproperty
  │   %2 = (Base.getfield)(p2, :x)::Real                                               ││
  │   %3 = (%1 + %2)::Any                                                              │ 
  └──      return %3                                                                   │ 

At %3, shouldn’t I be getting a Real anyway?

Thanks again!

Nothing in the language stops anyone from defining a type which is <: Real but for which addition produces something which is not <: Real (even if that would be mathematically weird). So, in general, there is no way for Julia to infer what type the addition of two completely unknown types, both <: Real, will produce.

julia> struct Iamabadperson <: Real
           message::String
       end

julia> Base.:+(x::Iamabadperson, y::Real) = x.message

julia> mutable struct Point
           x::Real
       end

julia> f(p1,p2) = p1.x + p2.x
f (generic function with 1 method)

julia> pl = Point(Iamabadperson("You can do anything if you want."))
Point(Iamabadperson("You can do anything if you want."))

julia> ph = Point(10)

julia> f(pl, ph)
"You can do anything if you want."

Well, getting a type-unstable result from your root_1d can be a problem for the people using it. Julia handles small Unions much better in 0.7+, but there are limits. Where possible, type stability is definitely something to strive for.
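One way to guard against this in tests is Test.@inferred from the standard library. It throws an error when the inferred return type does not match the type of the actual result. Below is a minimal sketch with two hypothetical functions, one returning a Union{Float64, Int64} for Int input and one that is type stable:

```julia
using Test

# Hypothetical examples
unstable(x) = x > 0 ? x : 0.0      # Int input: returns an Int or a Float64
stable(x)   = x > 0 ? x : zero(x)  # always returns the type of x

@inferred stable(3)                                # passes: inferred Int64 matches the result
@test_throws ErrorException @inferred unstable(3)  # inferred Union{Float64, Int64} is not Int64
```

The same check applied to root_1d with Int endpoints would be expected to fail, given the Union return type reported by @code_warntype.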

I guess I was reasoning only in a mathematical sense. That does sound weird, but I understand that nothing stops me from doing something like that.

With respect to type stability, I understand that float() would solve the problem, but I try to avoid conversions whenever I can. In particular, I like to write functions that are as general as possible, and here the conversion would cause problems if f were defined only for some specific type.

For instance, if I define f(x::Rational) = x^2 - 1, with the line x1, x2 = float(x1), float(x2) I get

julia> root_1d(f,1//30,10//3)

ERROR: MethodError: no method matching f(::Float64)

whereas commenting it out gives the right result (or close to it).

What I’d like to do here is get a type-stable version of the above without restricting it to floats.

P.S. How do I quote previous posts?

You don’t need any restrictions.

julia> foo(a,b,c) = @fastmath a * b + c
foo (generic function with 1 method)

julia> @code_warntype foo(2.0, 3.0, 4.0)
1 1 ─ %1 = (Base.FastMath.mul_float_fast)(a, b)::Float64                                                            │╻ mul_fast
  │   %2 = (Base.FastMath.add_float_fast)(%1, c)::Float64                                                           │╻ add_fast
  └──      return %2                                                                                                │ 

julia> @code_warntype foo(2, 3, 4)
1 1 ─ %1 = (Base.mul_int)(a, b)::Int64                                                                             │╻╷ mul_fast
  │   %2 = (Base.add_int)(%1, c)::Int64                                                                            ││╻  +
  └──      return %2                                                                                               │  

julia> @code_warntype foo(2//1, 3//1, 4//1)
1 1 ─ %1 = invoke Base.FastMath.:*(_2::Rational{Int64}, _3::Rational{Int64})::Rational{Int64}                       │╻ mul_fast
  │   %2 = invoke Base.FastMath.:+(%1::Rational{Int64}, _4::Rational{Int64})::Rational{Int64}                       │╻ add_fast
  └──      return %2                                                                    

Type declarations are restrictions, not hints or information.
A separate version of foo is compiled for every combination of input argument types. Each of the above therefore produces optimal assembly and is type stable (as seen above):

julia> @code_native foo(2.0, 3.0, 4.0)
; Function foo {
; Location: REPL[1]:1
; Function add_fast; {
; Location: REPL[1]:1
	vfmadd213sd	%xmm2, %xmm1, %xmm0
	nopw	%cs:(%rax,%rax)

julia> @code_native foo(2, 3, 4)
; Function foo {
; Location: REPL[1]:1
; Function mul_fast; {
; Location: fastmath.jl:257
; Function *; {
; Location: REPL[1]:1
	imulq	%rsi, %rdi
; Function add_fast; {
; Location: fastmath.jl:257
; Function +; {
; Location: int.jl:53
	leaq	(%rdi,%rdx), %rax
	nopl	(%rax)
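You can also check that annotations restrict rather than inform by looking at return types. The unannotated foo returns a concrete type matching its inputs. A Float64-annotated copy computes the same thing but rejects everything else. foo_f64 is a hypothetical name:

```julia
foo(a, b, c) = @fastmath a * b + c

# Hypothetical restricted variant: same computation, Float64 arguments only
foo_f64(a::Float64, b::Float64, c::Float64) = @fastmath a * b + c

# One specialization of foo per argument-type combination, each concretely typed
println(typeof(foo(2.0, 3.0, 4.0)))     # Float64
println(typeof(foo(2, 3, 4)))           # Int64 on a 64-bit machine
println(typeof(foo(2//1, 3//1, 4//1)))  # Rational{Int64} on a 64-bit machine

# The annotation only narrows which inputs are accepted
try
    foo_f64(2, 3, 4)
catch err
    println(err isa MethodError)        # true
end
```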

Yes, that I understand.

My problem is that commenting out

 x1, x2 = float(x1), float(x2)

makes my function type unstable, but including it restricts its usage to functions that accept floats as arguments.

For example:

struct Point{T1<:Real,T2<:Real}
    x::T1
    y::T2
end

function interval(f::Function, pl::Point, ph::Point)
    xm = (pl.x + ph.x)/2
    ym = f(xm)
    m = Point(xm, ym)

    if ym < 0
        return m, ph, m
    else
        return pl, m, m
    end
end

function root_1d_float(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    x1, x2 = float(x1), float(x2)
    pl = Point(x1, f(x1))
    ph = Point(x2, f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Make pl the endpoint where f is negative (it may then lie to the right of ph)
    pl.y > 0 && ((pl, ph) = (ph, pl))

    iter = 0
    while true
        iter += 1
        iter >= max_iter && error("root_1d did not converge within max_iter")

        pl, ph, m = interval(f, pl, ph)

        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol || iter >= max_iter
            return (m.x, m.y)
        end
    end
end

function root_1d_all(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    # x1, x2 = float(x1), float(x2)
    pl = Point(x1, f(x1))
    ph = Point(x2, f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Make pl the endpoint where f is negative (it may then lie to the right of ph)
    pl.y > 0 && ((pl, ph) = (ph, pl))

    iter = 0
    while true
        iter += 1
        iter >= max_iter && error("root_1d did not converge within max_iter")

        pl, ph, m = interval(f, pl, ph)

        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol || iter >= max_iter
            return (m.x, m.y)
        end
    end
end

gives the following results:

julia> @code_warntype root_1d_float(x -> x^2 - 1, 0, 2)
4 1 ─ %1 = invoke Main.:(#root_1d_float#3)(1.0e-6::Float64, 1.0e-6::Float64, 3000::Int64, _1::Function, _2::getfield(Main, Symbol("##4#5")), _3::Int64, _4::Int64)::Tuple{Float64,Float64}
  └──      return %1                                                                                │

julia> root_1d_float(x::Rational -> x^2 - 1, 0//1, 2//1)
ERROR: MethodError: no method matching (::getfield(Main, Symbol("##6#7")))(::Float64)
Closest candidates are:
  #6(::Rational) at REPL[9]:1

i.e. type stable but only works if f accepts Floats.

julia> root_1d_all(x::Rational -> x^2 - 1, 0//1, 2//1)
(1//1, 0//1)

julia> @code_warntype root_1d_all(x -> x^2 - 1, 0, 2)
Body::Tuple{Union{Float64, Int64},Union{Float64, Int64}}

i.e. type unstable but accepts more general functions.

So, I understand your point, but I’m not sure how to apply it to my problem.
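One possible approach, sketched here under stated assumptions (it is not something proposed earlier in the thread), is to avoid forcing floats. Instead, convert both endpoints once to the type the midpoint formula itself produces, T = typeof((x1 + x2)/2). For Int inputs this is Float64, and for Rational{Int64} inputs it stays Rational{Int64], so the x field never changes type between iterations. The sketch assumes two things: that f is itself type stable, and that (a + b)/2 on two values of type T returns T again. Both hold for Float64, Float32 and Rational{Int64}, but are not guaranteed for arbitrary Real types. root_1d_generic is a hypothetical name. The convergence test uses abs(ph.x - pl.x) because after the swap pl can lie to the right of ph.

```julia
# Same two-parameter Point as in the thread
struct Point{T1<:Real,T2<:Real}
    x::T1
    y::T2
end

# Evaluate f at the midpoint; return the updated (pl, ph) pair plus the midpoint
function interval(f, pl::Point, ph::Point)
    xm = (pl.x + ph.x) / 2
    m = Point(xm, f(xm))
    return m.y < 0 ? (m, ph, m) : (pl, m, m)
end

# Hypothetical name; a sketch, not a drop-in replacement for root_1d
function root_1d_generic(f, x1::Real, x2::Real;
                         x_tol = 1e-6, f_tol = 1e-6, max_iter::Int = 3000)
    # Convert both endpoints once to the type the midpoint formula produces:
    # Float64 for Int inputs, Rational{Int64} for Rational{Int64} inputs, etc.
    T = typeof((x1 + x2) / 2)
    a, b = convert(T, x1), convert(T, x2)

    pl = Point(a, f(a))
    ph = Point(b, f(b))
    pl.y * ph.y >= 0 && error("the interval provided does not bracket a root")

    # Make pl the endpoint where f is negative (it may then lie to the right of ph)
    pl.y > 0 && ((pl, ph) = (ph, pl))

    for _ in 1:max_iter
        pl, ph, m = interval(f, pl, ph)
        # abs because pl is not necessarily the left endpoint
        if abs(ph.x - pl.x) <= x_tol * (1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol
            return (m.x, m.y)
        end
    end
    error("root_1d_generic did not converge within max_iter")
end
```

With these definitions, root_1d_generic(x -> x^3 - 5, -4, 2) returns a Tuple{Float64,Float64}. root_1d_generic(x::Rational -> x^2 - 1, 0//1, 2//1) returns (1//1, 0//1) without ever calling f on a Float64. You can use Test.@inferred on such calls to check that the return type is inferred concretely. Note that with Rational{Int64}, very tight tolerances can eventually overflow the denominators.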