Is there a faster bisection root solver (that uses atol)?

question

#1

I have code that uses bisection root solvers, but with function calls that take a while.

Does anyone have an alternative to the Roots.jl version that accepts atol and is as performant as the original (besides the obvious speed-up from stopping at a looser tolerance)?


The Roots.jl function I’m talking about is:

find_zero(f, [cur_a, cur_b], Bisection())

#2

Bisection should be straightforward to implement yourself to your liking, shouldn’t it?

using Statistics
# Note, the benchmark that follows uses atol = sqrt(eps())
# but I decided to update it later to reflect the value of the upper bracket.
@inline bisect(a,b) = a/2 + b/2
function bisection(f, a_, b_, atol = 2eps(promote_type(typeof(b_),Float64)(b_)); increasing = sign(f(max(a_, b_))))
    a_, b_ = minmax(a_, b_)
    c = middle(a_,b_)
    z = f(c) * increasing
    if z > 0 # f(c) has the same sign as f at the upper end: root is in [a_, c]
        b = c
        a = typeof(b)(a_)
    else
        a = c
        b = typeof(a)(b_)
    end
    while abs(a - b) > atol
        c = middle(a,b)
        if f(c) * increasing > 0 # same sign as at the upper end: root is in [a, c]
            b = c
        else
            a = c
        end
    end
    a, b
end

Simple example, taken from Roots.jl’s documentation:

julia> f(x) = exp(x) - x^4

julia> bisection(f,8,9)
(8.613169446587563, 8.613169461488724)

julia> using BenchmarkTools

julia> @benchmark bisection(f,8,9)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.391 μs (0.00% GC)
  median time:      1.472 μs (0.00% GC)
  mean time:        1.468 μs (0.00% GC)
  maximum time:     4.717 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> using Roots

julia> find_zero(f, [8, 9], Bisection())
8.613169456441398

julia> bracket = [8,9];

julia> @benchmark find_zero(f, $bracket, Bisection())
BenchmarkTools.Trial: 
  memory estimate:  10.15 KiB
  allocs estimate:  595
  --------------
  minimum time:     24.553 μs (0.00% GC)
  median time:      27.303 μs (0.00% GC)
  mean time:        34.235 μs (19.91% GC)
  maximum time:     52.482 ms (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     1

Roots.jl is poorly optimized. The code is not type stable, and in the case of Bisection there are also lots of try-catches.
Anything you implement yourself is likely to be many times faster.

EDIT:
Fairer benchmark:

julia> a, b = bisection(f,8,9,sqrt(eps())/400000)
(8.613169456441398, 8.613169456441426)

julia> (a, b) .- find_zero(f, [8, 9], Bisection()) # want our answer to be at least as accurate
(0.0, 2.842170943040401e-14)

julia> @benchmark bisection(f,8,9,sqrt(eps())/400000)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.309 μs (0.00% GC)
  median time:      2.316 μs (0.00% GC)
  mean time:        2.394 μs (0.00% GC)
  maximum time:     6.093 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

#3

Try Roots.bisection(f,a,b,xatol=...); it is very similar to @elrod’s code. His point about type stability is being addressed, hopefully in a PR soon.


#4

Roots.jl accepts absolute tolerance parameters. But Bisection isn’t the fastest. I would suggest using FalsePosition() if you want to minimize the number of function calls.


#5

Similar story with FalsePosition.

I have modified code taken from their implementation here:

The biggest change is that I made FalsePosition parametric on the reduction factor, and modified the reduction function. They instead looked up the reduction function to call in a non-constant global dictionary.

I’m not saying that’s the best solution, but @j_verzani, I’d suggest considering an approach like that to inline the reduction call, rather than performing a dictionary lookup and a type-unstable dynamic dispatch every time.
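
A minimal sketch of the difference being described. Every name here (`halve`, `scale`, `reductions`, `reduce_lookup`, `FalsePos`, `reduce_dispatch`) and both formulas are made up for illustration; this is not the Roots.jl source. The first variant looks the function up in a non-constant global Dict on every call, the second puts the choice in a type parameter so the call can be resolved by dispatch.

```julia
# Two illustrative "reduction" rules (hypothetical formulas, not taken from Roots.jl).
halve(fa, fb, fx) = fa / 2
scale(fa, fb, fx) = fa * fb / (fb + fx)

# Variant 1: non-constant global Dict, looked up on every call.
# The compiler cannot know which function comes back, so the call is dynamic.
reductions = Dict{Symbol,Function}(:halve => halve, :scale => scale)
reduce_lookup(key, fa, fb, fx) = reductions[key](fa, fb, fx)

# Variant 2: the choice lives in a type parameter and is resolved by dispatch.
struct FalsePos{R} end
reduce_dispatch(::FalsePos{:halve}, fa, fb, fx) = halve(fa, fb, fx)
reduce_dispatch(::FalsePos{:scale}, fa, fb, fx) = scale(fa, fb, fx)

method = FalsePos{:scale}()
same = reduce_lookup(:scale, 4.0, 2.0, 1.0) == reduce_dispatch(method, 4.0, 2.0, 1.0)
println(same)
```

Both variants compute the same numbers; the difference is only in how much the compiler knows at the call site. Constructing `FalsePos{:scale}()` happens once, outside the hot loop.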


#6

Sounds like a good PR waiting to happen.


#7

The reason I hadn’t is that my changes seemed to come from a different vision of the code, i.e., run-time vs compile-time performance in particular. All I wanted was the former (plus type stability in its own right, because my calls to root finding were internal, meaning instabilities would risk propagating).

This was reflected in choices like abstract types and unions in struct fields, and deliberate @nospecialize put on lots of arguments.

As an aside, if I understand FunctionWrappers.jl correctly, it should provide type stable wrappers to functions that bypass specialization? If so, taking advantage of that would make a great PR. Haven’t experimented with that yet.

But if those at Roots are happy with a focus on run-time performance (and I realistically wouldn’t expect recompiling for different objectives to be much of an issue), I’d be happy to make a PR in a couple weeks.

Still preparing a (5 minute) talk for JSM, with a title whose corresponding method isn’t working yet.


#8

I’m all for that atol algorithm, but it seems like it doesn’t find roots near NaN and Inf like the one from Roots.jl.

The problem with using Roots.jl, though, is that it will stall if it thinks there is a root but can’t find one (while in bisection mode).


#9

Sounds good to me. I also find that Roots.jl has too much setup overhead before each call. I often need to run it a million times a second, but the overhead in Roots.jl can be higher than the cost of running the small functions I need the root of, so it’s a significant portion of my whole model run.


#10

What is a root near NaN?

Replacing the bisect function with this should help with the Inf issue:

@inline bisect(a,b) = a/2 + b/2

Now:

julia> prevfloat(typemax(Float64))/2 + prevfloat(typemax(Float64),3)/2
1.7976931348623155e308

julia> ans == prevfloat(typemax(Float64),2)
true

It’s a super simple algorithm, so you should be able to diagnose why something’s going wrong and make it robust to handle that.
(Of course, unit tests would help to make sure you’re not suddenly failing on some other edge case.)
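
A quick self-contained check of why the halve-first form matters near the top of the Float64 range. `naive_mid` and `safe_mid` are names made up for this sketch:

```julia
# Two ways to compute the midpoint of a bracket [a, b] in Float64.
naive_mid(a, b) = (a + b) / 2    # a + b can overflow to Inf before the division
safe_mid(a, b)  = a / 2 + b / 2  # halve first, so finite inputs stay finite

xmax = prevfloat(typemax(Float64))  # largest finite Float64

println(isinf(naive_mid(xmax, xmax)))  # true: the sum overflowed
println(safe_mid(xmax, xmax) == xmax)  # true: no overflow
```

The trade-off is at the other end of the range: halving a subnormal number can round, so the halve-first form may lose a bit there.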

@Raf
Okay, yeah, simulations with millions of calls are my use case too.
I’m guessing for a lot of other people too, so I’d support a simulation-focused version as at least being an option.

For the “low compile time” version (perhaps for when someone is debugging the function they’re root finding, and thus constantly redefining the method and recompiling), something like this:

# 0.7, cleaned up dep warnings / started with --depwarn=no
julia> using Test, BenchmarkTools

julia> import FunctionWrappers: FunctionWrapper

julia> const F64F64Func = FunctionWrapper{Float64,Tuple{Float64}}
FunctionWrapper{Float64,Tuple{Float64}}

julia> f(x) = exp(x) - x^4
f (generic function with 1 method)

julia> FWf = F64F64Func(f)
FunctionWrapper{Float64,Tuple{Float64}}(Ptr{Nothing} @0x00007f2fa4060ca0, Ptr{Nothing} @0x00007f2fa71a47f0, Base.RefValue{typeof(f)}(f), typeof(f))

julia> @inferred f(8.6)
-38.422008637020554

julia> @inferred FWf(8.6)
-38.422008637020554

julia> typeof(FWf)
FunctionWrapper{Float64,Tuple{Float64}}

julia> struct UnstableWrap
           f::Function
       end

julia> (UW::UnstableWrap)(x...) = UW.f(x...)

julia> uf = UnstableWrap(f);

julia> uf(8.6)
-38.422008637020554

julia> @inferred uf(8.6)
ERROR: return type Float64 does not match inferred return type Any
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at none:0

The FunctionWrapper version (FWf), unlike uf, is type stable and non-allocating.

julia> using BenchmarkTools

julia> @benchmark f(8.6)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.992 ns (0.00% GC)
  median time:      2.997 ns (0.00% GC)
  mean time:        3.009 ns (0.00% GC)
  maximum time:     14.001 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark $FWf(8.6)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     129.376 ns (0.00% GC)
  median time:      129.572 ns (0.00% GC)
  mean time:        131.515 ns (0.00% GC)
  maximum time:     215.137 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     886

julia> @benchmark $uf(8.6)
BenchmarkTools.Trial: 
  memory estimate:  32 bytes
  allocs estimate:  2
  --------------
  minimum time:     174.538 ns (0.00% GC)
  median time:      176.032 ns (0.00% GC)
  mean time:        194.097 ns (7.44% GC)
  maximum time:     98.510 μs (99.77% GC)
  --------------
  samples:          10000
  evals/sample:     714

Versus full specialization, it’s only about an eighth of a second overhead after a million iterations on this computer, but I’d still prefer a dedicated simulation version that just specializes on the function.

I do think using this looks like a good upgrade for everything currently avoiding specializing on functions by leaving them abstract and doing dynamic dispatches every time. That includes Roots.jl and Optim (https://github.com/JuliaNLSolvers/NLSolversBase.jl/blob/master/src/objective_types/oncedifferentiable.jl) at least.

It would need to recompile more (e.g., for SArray{Tuple{3},Float64,1,3} vs SArray{Tuple{4},Float64,1,4}), but I think it’s a good compromise. Changing the input type is going to force a lot of recompilation anyway.


#11

An eighth of a second is too high! :wink: I’m looking at 1 second for my whole model run, something like 100 years at hourly intervals with a couple of roots to find in each step. Then I can run the whole thing 10000 times while I have lunch.

I’m not sure what you mean by specialisation and compile time optimisation exactly, do you mean that Roots.jl is not type stable?

I have ended up just using my own basic solvers, as the changes I need could be total carnage to the current structure of Roots.jl, and I don’t understand the other use cases I would be breaking - as discussed in this issue:

My algorithms might be mediocre but they have basically no overheads at all besides running the function, which more than makes up for it. Type stability >> algorithm!! Simple solvers also handle units and Dual numbers just fine, which is still an issue in some corner cases in Roots.jl (like Dual wrapped in a Unitful.Quantity…)

Instead of FunctionWrappers.jl I just wrap my custom find_zero() in a let block that re-declares the parameters, which seems to be fine.
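
One possible reading of the let-block remark, sketched under assumptions: when a closure captures a local variable that is assigned more than once, Julia boxes it, and rebinding it in a `let` gives the closure a variable that is assigned only once. `tiny_bisect`, `step_boxed` and `step_let` are hypothetical names for this sketch, not the code being described.

```julia
# Tiny fixed-iteration bisection, only here so there is something to call.
function tiny_bisect(f, a, b)
    for _ in 1:60
        m = (a + b) / 2
        if f(a) * f(m) <= 0
            b = m          # sign change is in [a, m]
        else
            a = m          # sign change is in [m, b]
        end
    end
    (a + b) / 2
end

function step_boxed(target)
    target = 2 * target                     # reassigned, so the closure captures a box
    tiny_bisect(x -> x^3 - target, 0.0, 10.0)
end

function step_let(target)
    target = 2 * target
    let target = target                     # fresh binding, assigned exactly once
        tiny_bisect(x -> x^3 - target, 0.0, 10.0)
    end
end

println(step_boxed(4.0) ≈ step_let(4.0))
```

Both functions return the same root; the difference would show up in `@code_warntype` and in allocations, not in the result.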

Edit: I’m happy to help test and contribute to these changes, and make sure they fit some specific use cases, like my Quantity{Dual{Float64}} setup, etc.


#12

The algorithm does not quite work I’m afraid:

julia> f(x) = 1 - x^3
f (generic function with 1 method)

julia> bisection(f,-0.8, 1.8)
(-0.8, -0.7999999999999998)

#13

How is this?

function bisect(fun, xl, xu, tolx, tolf)
    if (xl > xu)
        xl, xu = xu, xl
    end
    fl = fun(xl);
    fu = fun(xu);
    # Check for exact roots at the endpoints first; the assertion below would reject them.
    if fl == 0.0
        return xl, xl;
    end
    if fu == 0.0
        return xu, xu;
    end
    @assert fl*fu < 0.0 "Need to get a bracket"
    while true
        xr = (xu + xl) * 0.5 ; # bisect interval
        fr = fun(xr); # value at the midpoint
        if fr * fl < 0.0 # (fr < 0.0 && fl > 0.0) || (fr > 0.0 && fl < 0.0)
            xu, fu = xr, fr;# upper --> midpoint
        else
            xl, fl = xr, fr;# lower --> midpoint
        end
        if (abs(xu-xl) < tolx) || (abs(fr) <= tolf)
            break; # We are done
        end
    end
    return xl, xu;
end

#14

Just another flavor. Probably not type stable or magic.

function custom_bisection(f, cur_a, cur_b, cur_f_a, cur_f_b; abstol=10*eps())
  if isapprox(cur_a, cur_b, atol=abstol)
    isapprox(cur_f_a, 0.0, atol=abstol) && return cur_a
    isapprox(cur_f_b, 0.0, atol=abstol) && return cur_b

    return NaN
  end

  cur_c = ( cur_a + cur_b ) / 2.0
  
  cur_f_c = f(cur_c)

  is_bad_a = isinf(cur_f_a) || isnan(cur_f_a)
  is_bad_b = isinf(cur_f_b) || isnan(cur_f_b)
  is_bad_c = isinf(cur_f_c) || isnan(cur_f_c)

  @assert !is_bad_a || !is_bad_b

  if !is_bad_c
    if !is_bad_a && ( ( cur_f_a * cur_f_c ) <= 0 )
      return custom_bisection(f, cur_a, cur_c, cur_f_a, cur_f_c, abstol=abstol)
    end

    if !is_bad_b && ( ( cur_f_b * cur_f_c ) <= 0 )
      return custom_bisection(f, cur_c, cur_b, cur_f_c, cur_f_b, abstol=abstol)
    end
  end

  if is_bad_a
    return custom_bisection(f, cur_c, cur_b, cur_f_c, cur_f_b, abstol=abstol)
  else
    return custom_bisection(f, cur_a, cur_c, cur_f_a, cur_f_c, abstol=abstol)
  end
end

#15

This should be type stable (but didn’t test it thoroughly):

using Statistics: middle

"""
    bisection(f, a, b; fa = f(a), fb = f(b), ftol, wtol, maxiter)

Bisection algorithm for finding the root ``f(x) ≈ 0`` within the initial bracket
`[a,b]`.

Returns a named tuple

`(x = x, fx = f(x), isroot = ::Bool, iter = ::Int, ismaxiter = ::Bool)`.

Terminates when either

1. `abs(f(x)) < ftol` (`isroot = true`),
2. the width of the bracket is `≤wtol` (`isroot = false`),
3. `maxiter` iterations are reached (`isroot = false, ismaxiter = true`),

which are tested for in the above order. Therefore, care should be taken not to make `wtol` too large.

"""
function bisection(f, a::Real, b::Real; fa::Real = f(a), fb::Real = f(b),
                   ftol = √eps(), wtol = 0, maxiter = 100)
    @assert fa * fb ≤ 0 "initial values don't bracket zero"
    @assert isfinite(a) && isfinite(b)
    _bisection(f, float.(promote(a, b, fa, fb, ftol, wtol))..., maxiter)
end

function _bisection(f, a, b, fa, fb, ftol, wtol, maxiter)
    iter = 0
    abs(fa) < ftol && return (x = a, fx = fa, isroot = true, iter = iter, ismaxiter = false)
    abs(fb) < ftol && return (x = b, fx = fb, isroot = true, iter = iter, ismaxiter = false)
    while true
        iter += 1
        m = middle(a, b)
        fm = f(m)
        abs(fm) < ftol && return (x = m, fx = fm, isroot = true, iter = iter, ismaxiter = false)
        abs(b-a) ≤ wtol && return (x = m, fx = fm, isroot = false, iter = iter, ismaxiter = false)
        if fa * fm > 0
            a, fa = m, fm
        else
            b, fb = m, fm
        end
        iter == maxiter && return (x = m, fx = fm, isroot = false, iter = iter, ismaxiter = true)
    end
end

#16

Yeah, I forgot to consider the obvious case of a decreasing function. =/ Edited the original post.
Also replaced “bisect” with “middle” – didn’t know about that function before Tamas Papp’s post.

His post is also most complete with all the checks and options you’d want in the Roots library.

@Raf

An eighth of a second is too high! :wink: I’m looking at 1 second for my whole model run, something like 100 years at hourly intervals with a couple of roots to find in each step. Then I can run the whole thing 10000 times while I have lunch.

I’m not sure what you mean by specialisation and compile time optimisation exactly, do you mean that Roots.jl is not type stable?

It’s too high for me too.

What I mean about compile-time optimization is just not having any functions recompile when you change the function you’re root finding. Roots and Optim/NLSolversBase avoid recompilation by having types like this:

# Roots
struct DerivativeFree <: CallableFunction 
    f
end
# NLSolversBase
mutable struct OnceDifferentiable{TF, TDF, TX} <: AbstractObjective
    f # objective
    df # (partial) derivative of objective
    fdf # objective and (partial) derivative of objective
    F::TF # cache for f output
    DF::TDF # cache for df output
    x_f::TX # x used to evaluate f (stored in F)
    x_df::TX # x used to evaluate df (stored in DF)
    f_calls::Vector{Int}
    df_calls::Vector{Int}
end

Note, these all wrap functions with type unstable structs!
When you change the function you’re rootfinding / optimizing, the type of these structs does not change, so nothing recompiles.
Otherwise, if they were parameterized on the function (what I meant by “full specialization”):

struct DerivativeFree{F} <: CallableFunction 
    f::F
end

now the type would change every time you change or update the function, leading to lots of recompiling each time you swap it.

That’s something these authors prioritized, which is why I’d suggested FunctionWrappers.
If you call the same function a million times though, it still only compiles once, so like you I want that as an option too.
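
A small self-contained sketch of that trade-off. `LooseWrap` and `TightWrap` are hypothetical stand-ins for the unparameterized and parameterized wrappers, not types from Roots or NLSolversBase:

```julia
struct LooseWrap      # field type is Any: one struct type for every wrapped function
    f
end
struct TightWrap{F}   # the function's type is part of the struct's type
    f::F
end

(w::LooseWrap)(x) = w.f(x)
(w::TightWrap)(x) = w.f(x)

# Swapping the wrapped function leaves LooseWrap's type unchanged (nothing to recompile)...
println(typeof(LooseWrap(sin)) === typeof(LooseWrap(cos)))
# ...but changes TightWrap's type, so code using it is specialized again.
println(typeof(TightWrap(sin)) === typeof(TightWrap(cos)))

# What inference can say about a call with a Float64 argument.
println(Base.return_types(LooseWrap(sin), (Float64,)))
println(Base.return_types(TightWrap(sin), (Float64,)))
```

On recent Julia versions the last two lines are expected to report `Any` for the loose wrapper and `Float64` for the tight one, which is the instability that propagates into callers.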


#17

Note that in practice, you would almost always want to use Brent’s method for derivative-free univariate rootfinding.

I have one coded up, if there is interest I can wrap it up with some tests.


#18

There are a few issues here, as I see it, that I hope I have addressed in recent PRs (one pending):

  • @elrod points out the issue with specialization. The current “tuning” is for smaller times when calling the first time. For multiple calls, a pattern like this should help (it depends on a recent PR, not yet merged, that borrows ideas from @elrod’s earlier post):

function fz_nospec_many_calls(f, xs, M)
    state = Roots.init_state(M, f, first(xs))
    options = Roots.init_options(M, state)

    rts = similar(float.(xs))
    for (i, x0) in enumerate(xs)
        Roots.init_state!(state, M, f, x0)
        rts[i] = find_zero(M, f, options, state)
    end

    rts       
end

The find_zero method, when called as find_zero(f, x, M), calls @nospecialize(f) on the code path (though this can be avoided now, if desired), but when called as above does not. The init_state! method is new and avoids the overhead there. I’d appreciate any comment on that PR to see if any useful optimizations are left on the table.

  • bisection is slow. Yes, but hopefully for a good reason: the default bisection for floating point values is robust to user inputs. It defines the middle differently from a/2 + b/2 or a + (b-a)/2 or some such, which is sensitive to certain choices of a and b and can be really slow on, say, (0, prevfloat(Inf)). Instead it maps the two floating point values to unsigned integers and does bisection there. The calls to reinterpret are slower than a simple average, but have the advantage of always producing “exact” answers (where there is a sign change) in no more than 64 steps. The Roots.bisection64 method speeds this up as much as possible, I think. The Roots.bisection method will use a simpler method when the tolerances are nonzero, and this should be fast (on reasonable inputs) and is similar to the ones suggested here.

  • For @raf’s comment, there was a Roots.secant_method added in response to that question that speeds things up considerably, as it avoids the overhead, but likely isn’t as robust. (The Roots.bisection method was added too.)

  • For @Tamas_Papp’s comment, the Roots.a42 method should be a faster alternative to Brent’s method. As implemented it usually takes many fewer function calls than bisection, but the performance seems to be slower than expected. It could likely use some work. There is also a set of FalsePosition methods, but unlike bisection, convergence isn’t guaranteed.
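
The idea in the bisection bullet above, of bisecting over the integer bit patterns rather than the floating point values, can be sketched roughly as follows. This is a simplified illustration restricted to non-negative, finite brackets, with hypothetical names (`bits_middle`, `bits_bisection`); it is not the Roots.jl implementation.

```julia
# Midpoint in "bit space". For non-negative Float64 values the UInt64 bit patterns
# sort in the same order as the floats, so the averaged pattern lies between a and b.
function bits_middle(a::Float64, b::Float64)
    ua = reinterpret(UInt64, a)
    ub = reinterpret(UInt64, b)
    reinterpret(Float64, ua + ((ub - ua) >> 1))
end

function bits_bisection(f, a::Float64, b::Float64)
    @assert 0 <= a < b && isfinite(b) "this sketch only handles 0 <= a < b with b finite"
    fa, fb = f(a), f(b)
    @assert sign(fa) * sign(fb) < 0 "need a sign change"
    steps = 0
    while true
        m = bits_middle(a, b)
        m == a && break                 # a and b are adjacent floats: nothing left to split
        fm = f(m)
        steps += 1
        if fm == 0
            return m, m, steps          # landed exactly on a root
        elseif sign(fm) == sign(fa)
            a, fa = m, fm
        else
            b, fb = m, fm
        end
    end
    a, b, steps
end

h(x) = x^2 - 2
lo, hi, steps = bits_bisection(h, 0.0, prevfloat(Inf))
println((lo, hi, steps))
```

Because every step halves the distance between the two bit patterns, the bracket (0, prevfloat(Inf)) shrinks to two adjacent floats in at most 64 function evaluations, whereas averaging the values themselves would need on the order of a thousand halvings to come down from 1.8e308.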


#19

Do you have a reference I could read about this? Table 5.2 in the APS (1993) article does not suggest a large difference.


#20

Nothing beyond that. I should have written “as fast.” Anyways, the implementation matters too. A PR would be most welcome.