Simple and fast bisection

Hi, as the title says I’m trying to build a simple and fast bisection algorithm. I know the Roots.jl package already implements it, yet I’m doing it both for pedagogical purposes and because I then plan to expand it with different algorithms (I like having full control over what my code can or cannot do).

The way I thought of doing it is very simple: I have a Point object that stores both coordinates of a point and at each iteration I update one of the two endpoints based on the sign at the midpoint, where the function interval!() does most of the work.

I tried to strip down the code to its bare minimum:

mutable struct Point
    x::Real
    y::Real
end

function interval!(f::Function, pl::Point, ph::Point)

    xm = (pl.x + ph.x)/2
    ym = f(xm)

    if ym < 0 
        (pl.x, pl.y) = (xm, ym)
    else
        (ph.x, ph.y) = (xm, ym)
    end
    return (xm, ym)
end

function root_1d(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    pl = Point(x1,f(x1))
    ph = Point(x2,f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Ensures left endpoint evaluates to negative
    pl.y > 0 && ((pl, ph) = (ph, pl))

    xm = 0
    ym = 0
    iter = 0
    stop = false
    while ~stop
        iter += 1

        (xm, ym) = interval!(f,pl,ph)

        # Convergence check
        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(ym) <= f_tol || iter >= max_iter
            stop = true
            iter >= max_iter && @warn("root_1d did not converge within max_iter")
        end
    end
    return xm, ym
end

Now, one thing I noticed is that the interval!() function is not type stable. Running

f(x) = x^3
pl = Point(-1, 1)
ph = Point(2, 8)

@code_warntype interval!(f, pl, ph)

tells me the result is of type Tuple{Any,Any}. From what I understand this comes from the fact that it cannot infer the sum of two reals (???) but this seems a bit weird.
I’m not sure how I should fix this. I tried making Point a parametric type but then I cannot change it in place because the type of its argument might change from one iteration to the next. At the same time I thought having the Point object might be nice to reduce memory allocation.

Any help is much appreciated.

Your main problem is

mutable struct Point
    x::Real
    y::Real
end

Here, any Real can be stored in a Point, so the compiler cannot make any assumptions about what the contents will be. Note that Real is an abstract type (it is not the same as Float64). Instead, you should parameterize the struct so that the compiler knows the concrete field type when the code is compiled:

struct Point{T}
    x::T
    y::T
end
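To make the difference concrete, here is a minimal sketch that asks Julia for the declared field types of both variants. The names AbstractFieldPoint and ParametricPoint are hypothetical, chosen only so they do not clash with the Point definitions in this thread.

```julia
# Hypothetical names, used only for this illustration.
mutable struct AbstractFieldPoint
    x::Real
    y::Real
end

struct ParametricPoint{T}
    x::T
    y::T
end

# The declared field type of the first struct is the abstract type Real.
abstract_field = fieldtype(AbstractFieldPoint, :x)
# Once T is fixed, the field type of the second struct is concrete.
concrete_field = fieldtype(ParametricPoint{Float64}, :x)

println(abstract_field, " concrete? ", isconcretetype(abstract_field))
println(concrete_field, " concrete? ", isconcretetype(concrete_field))

# The default constructor picks T from the arguments.
p = ParametricPoint(1.0, 2.0)
println(typeof(p))
```

With an abstract field the compiler only knows "some Real" when it reads p.x, whereas with the parametric version the field type is part of the type of p itself.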

There are also some other tweaks to make sure that things are type stable and to use immutable values instead of mutable ones (immutables are typically easier for the compiler to optimize).

The Performance Tips section of the Julia manual would be a good read.

I took some liberty and rewrote the code slightly:

struct Point{T}
    x::T
    y::T
end

function interval(f::Function, pl::Point, ph::Point)
    xm = (pl.x + ph.x)/2
    ym = f(xm)
    m = Point(xm, ym)

    if ym < 0 
        return m, ph, m
    else
        return pl, m, m
    end
end

function root_1d(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    x1, x2 = float(x1), float(x2)
    pl = Point(x1, f(x1))
    ph = Point(x2, f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Ensures left endpoint evaluates to negative
    pl.y > 0 && ((pl, ph) = (ph, pl))

    iter = 0
    while true
        iter += 1
        iter >= max_iter && error("root_1d did not converge within max_iter")

        pl, ph, m = interval(f, pl, ph)

        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol || iter >= max_iter
            return m
        end
    end
end

and this should be fast:

julia> @btime root_1d(x-> x^3 - 5, -4, 2)
  185.212 ns (1 allocation: 32 bytes)
Point{Float64}(1.709975242614746, -6.176066356999854e-6)
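One plausible reason creating a new immutable Point on every iteration costs so little: an immutable struct whose fields are plain bits types is itself a bits type, so the compiler is free to keep it in registers or on the stack rather than on the heap. A small sketch to check this, using the hypothetical names ImmutablePoint and MutablePoint (not the thread's Point):

```julia
# Hypothetical names, used only for this illustration.
struct ImmutablePoint{T}
    x::T
    y::T
end

mutable struct MutablePoint{T}
    x::T
    y::T
end

# isbitstype is true only for immutable types made entirely of plain data.
immutable_is_bits = isbitstype(ImmutablePoint{Float64})
mutable_is_bits = isbitstype(MutablePoint{Float64})

println("immutable isbits: ", immutable_is_bits)
println("mutable isbits:   ", mutable_is_bits)
# Two Float64 fields stored inline.
println("size in bytes:    ", sizeof(ImmutablePoint{Float64}))
```

This does not prove anything about a particular benchmark, but it is consistent with the low allocation count reported above.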

Thanks a lot.
I figured that was the problem but I thought using a mutable struct would have been better for memory allocation purposes instead of creating new immutables at each iteration. I guess I was wrong :smile: .
I changed something more. In particular I changed Point to:

struct Point{T1<:Real,T2<:Real}
    x::T1
    y::T2
end

which allows x and y to be of different types and I also removed the x1, x2 = float(x1), float(x2) line. It now seems to be slightly faster.

My only concern is that @code_warntype now returns a Union type. In particular, with the changes I get

julia> @code_warntype root_1d(x -> x^3 - 5, -4,2)
Body::Union{Point{Float64,Float64}, Point{Int64,Int64}}
5 1 ─ %1 = invoke Main.:(#root_1d#3)(1.0e-6::Float64, 1.0e-6::Float64, 3000::Int64, _1::Function, _2::getfield(Main, Symbol("##4#5")), _3::Int64, _4::Int64)::Union{Point{Float64,Float64}, Point{Int64,Int64}}
  └──      return %1

Is this a problem? Did I understand correctly that Union types such as this one are not a problem in 0.7 anymore?

Finally, I have one question regarding the code I had before, which can be condensed as follows: I understand the compiler cannot make any assumption about what the concrete type of Point is. What I don’t understand is why I get Any when summing such Reals. In particular:

mutable struct Point
    x::Real
end

f(p1,p2) = p1.x + p2.x

pl = Point(-1)
ph = Point(10)
julia> @code_warntype f(pl,ph)
Body::Any
1 1 ─ %1 = (Base.getfield)(p1, :x)::Real                                               │╻ getproperty
  │   %2 = (Base.getfield)(p2, :x)::Real                                               ││
  │   %3 = (%1 + %2)::Any                                                              │ 
  └──      return %3                                                                   │ 


At %3, shouldn’t I be getting a Real anyway?

Thanks again :slight_smile:

Nothing in the language stops anyone from defining a type which is <: Real but for which addition produces something which is not <: Real (even if that would be mathematically weird). So there is no way in general for Julia to infer what type the addition of two totally unknown types which are both <: Real would produce.

julia> struct Iamabadperson <: Real
           message::String
       end

julia> Base.:+(x::Iamabadperson, y::Real) = x.message

julia> mutable struct Point
           x::Real
       end

julia> f(p1,p2) = p1.x + p2.x
f (generic function with 1 method)

julia> pl = Point(Iamabadperson("You can do anything if you want."))
Point(Iamabadperson("You can do anything if you want."))

julia> ph = Point(10)
Point(10)

julia> f(pl, ph)
"You can do anything if you want."

Well, a type-unstable result can be a problem for people using your root_1d. While Julia is better at handling Unions in 0.7+, there are limits, and type stability is definitely something to strive for where possible.
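If you want a check that fails loudly rather than reading @code_warntype output by eye, the Test standard library has @inferred, which throws when the actual return type differs from the inferred one. A minimal sketch with two hypothetical toy functions (not the thread's root_1d):

```julia
using Test

# Hypothetical toy functions, used only for this illustration.
# For an Int argument this returns an Int on one branch and a Float64 on the other.
unstable(x) = x > 0 ? x : 0.0
# zero(x) has the same type as x, so both branches agree.
stable(x) = x > 0 ? x : zero(x)

# Passes and returns the value of the call.
stable_result = @inferred stable(1)

# Throws, because the inferred return type is a Union rather than Int64.
unstable_threw = try
    @inferred unstable(1)
    false
catch
    true
end

println(stable_result, " ", unstable_threw)
```

Putting a call like this in a test suite is one way to notice if a later change makes the return type unstable again.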

I guess I was reasoning just in a mathematical sense. That does sound weird, but I understand nothing stops me from doing something like that.

With respect to type stability, I understand float() would solve the problem, but I always try to avoid conversion when I can. In particular, I like to write functions that are as general as possible, and in this example this would create problems if the function f were defined only for some specific type.

For instance, if I define f(x::Rational) = x^2 - 1, with the line x1, x2 = float(x1), float(x2) I get

julia> root_1d(f,1//30,10//3)

ERROR: MethodError: no method matching f(::Float64)

while commenting it out generates the right result (or close to it).

What I’d like to do here is get a type stable version of the above but without having to restrict it to use floats.

P.S. how do I add quotes of previous posts?

You don’t need any restrictions.

julia> foo(a,b,c) = @fastmath a * b + c
foo (generic function with 1 method)

julia> @code_warntype foo(2.0, 3.0, 4.0)
Body::Float64
1 1 ─ %1 = (Base.FastMath.mul_float_fast)(a, b)::Float64                                                            │╻ mul_fast
  │   %2 = (Base.FastMath.add_float_fast)(%1, c)::Float64                                                           │╻ add_fast
  └──      return %2                                                                                                │ 

julia> @code_warntype foo(2, 3, 4)
Body::Int64
1 1 ─ %1 = (Base.mul_int)(a, b)::Int64                                                                             │╻╷ mul_fast
  │   %2 = (Base.add_int)(%1, c)::Int64                                                                            ││╻  +
  └──      return %2                                                                                               │  

julia> @code_warntype foo(2//1, 3//1, 4//1)
Body::Rational{Int64}
1 1 ─ %1 = invoke Base.FastMath.:*(_2::Rational{Int64}, _3::Rational{Int64})::Rational{Int64}                       │╻ mul_fast
  │   %2 = invoke Base.FastMath.:+(%1::Rational{Int64}, _4::Rational{Int64})::Rational{Int64}                       │╻ add_fast
  └──      return %2                                                                    

Type declarations on function arguments are restrictions on which methods apply, not hints or extra information for the compiler.
A separate version of foo is compiled for every set of input argument types, therefore each of the above will produce the optimal assembly and is type stable (as seen above):

julia> @code_native foo(2.0, 3.0, 4.0)
	.text
; Function foo {
; Location: REPL[1]:1
; Function add_fast; {
; Location: REPL[1]:1
	vfmadd213sd	%xmm2, %xmm1, %xmm0
;}
	retq
	nopw	%cs:(%rax,%rax)
;}

julia> @code_native foo(2, 3, 4)
	.text
; Function foo {
; Location: REPL[1]:1
; Function mul_fast; {
; Location: fastmath.jl:257
; Function *; {
; Location: REPL[1]:1
	imulq	%rsi, %rdi
;}}
; Function add_fast; {
; Location: fastmath.jl:257
; Function +; {
; Location: int.jl:53
	leaq	(%rdi,%rdx), %rax
;}}
	retq
	nopl	(%rax)
;}

Yes, that I understand.

My problem is that commenting out

 x1, x2 = float(x1), float(x2)

makes my function type unstable, but including it restricts its usage to functions that accept floats as arguments.

In fact:


struct Point{T1<:Real,T2<:Real}
    x::T1
    y::T2
end

function interval(f::Function, pl::Point, ph::Point)
    xm = (pl.x + ph.x)/2
    ym = f(xm)
    m = Point(xm, ym)

    if ym < 0 
        return m, ph, m
    else
        return pl, m, m
    end
end

function root_1d_float(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    x1, x2 = float(x1), float(x2)
    pl = Point(x1, f(x1))
    ph = Point(x2, f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Ensures left endpoint evaluates to negative
    pl.y > 0 && ((pl, ph) = (ph, pl))

    iter = 0
    while true
        iter += 1
        iter >= max_iter && error("root_1d did not converge within max_iter")

        pl, ph, m = interval(f, pl, ph)

        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol || iter >= max_iter
            return (m.x, m.y)
        end
    end
end

function root_1d_all(f::Function, x1::Real, x2::Real;
                 x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)

    # x1, x2 = float(x1), float(x2)
    pl = Point(x1, f(x1))
    ph = Point(x2, f(x2))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Ensures left endpoint evaluates to negative
    pl.y > 0 && ((pl, ph) = (ph, pl))

    iter = 0
    while true
        iter += 1
        iter >= max_iter && error("root_1d did not converge within max_iter")

        pl, ph, m = interval(f, pl, ph)

        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol || iter >= max_iter
            return (m.x, m.y)
        end
    end
end

gives the following results:

julia> @code_warntype root_1d_float(x -> x^2 - 1, 0, 2)
Body::Tuple{Float64,Float64}
4 1 ─ %1 = invoke Main.:(#root_1d_float#3)(1.0e-6::Float64, 1.0e-6::Float64, 3000::Int64, _1::Function, _2::getfield(Main, Symbol("##4#5")), _3::Int64, _4::Int64)::Tuple{Float64,Float64}
  └──      return %1                                                                                │

julia> root_1d_float(x::Rational -> x^2 - 1, 0//1, 2//1)
ERROR: MethodError: no method matching (::getfield(Main, Symbol("##6#7")))(::Float64)
Closest candidates are:
  #6(::Rational) at REPL[9]:1

i.e. type stable but only works if f accepts Floats.
And:

julia> root_1d_all(x::Rational -> x^2 - 1, 0//1, 2//1)
(1//1, 0//1)

julia> @code_warntype root_1d_all(x -> x^2 - 1, 0, 2)
Body::Tuple{Union{Float64, Int64},Union{Float64, Int64}}

i.e. type unstable but accepts more general functions.

So, I understood your point but I’m not sure I understood how to apply it to my problem.
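One possible way to get both properties, offered as a sketch rather than a definitive answer: instead of forcing floats, convert the endpoints once to whatever type a midpoint will have. For Int endpoints that is Float64, for Rational endpoints it stays Rational, so every point built inside the loop should have the same concrete type as the initial ones. The names BPoint, bisect_step and root_1d_promoted are hypothetical, and the approach assumes f returns the same type for every input of the midpoint type.

```julia
# A sketch only: BPoint, bisect_step and root_1d_promoted are hypothetical names.
struct BPoint{T1<:Real,T2<:Real}
    x::T1
    y::T2
end

function bisect_step(f::Function, pl::BPoint, ph::BPoint)
    xm = (pl.x + ph.x)/2
    ym = f(xm)
    m = BPoint(xm, ym)
    return ym < 0 ? (m, ph, m) : (pl, m, m)
end

function root_1d_promoted(f::Function, x1::Real, x2::Real;
                          x_tol::Float64 = 1e-6, f_tol::Float64 = 1e-6, max_iter::Int = 3000)
    # Convert the endpoints to the type a midpoint will have:
    # Int -> Float64, Rational{Int} -> Rational{Int}, Float64 -> Float64.
    T = typeof((x1 + x2)/2)
    a, b = convert(T, x1), convert(T, x2)

    pl = BPoint(a, f(a))
    ph = BPoint(b, f(b))
    pl.y*ph.y >= 0 && error("the interval provided does not bracket a root")

    # Ensures pl is the endpoint where f is negative.
    pl.y > 0 && ((pl, ph) = (ph, pl))

    for iter in 1:max_iter
        pl, ph, m = bisect_step(f, pl, ph)
        # abs(...) on the width because pl and ph may have been swapped above.
        if abs(ph.x - pl.x) <= x_tol*(1 + abs(pl.x) + abs(ph.x)) || abs(m.y) <= f_tol
            return (m.x, m.y)
        end
    end
    error("root_1d_promoted did not converge within max_iter")
end

float_root = root_1d_promoted(x -> x^2 - 1, 0, 2)
rational_root = root_1d_promoted(x -> x^2 - 1, 0//1, 2//1)
println(float_root, " ", rational_root)
```

It would be worth running @code_warntype on this yourself to confirm that the Union is gone for the argument types you care about.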