Memoized inverse function through bisection

Hey,

I am trying to implement a memoized version of an inversion algorithm by bisection, but I currently hit a StackOverflowError and I cannot understand why:

struct MemoizedInverse{Tf,Tstore}
    f::Tf
    store::Tstore
    function MemoizedInverse(f,lx,rx)
        store = Dict(
            x => f(x) for x in promote(lx,rx)
        )
        return new{typeof(f),typeof(store)}(f,store)
    end
end
function get_val(X::MemoizedInverse,x)
    if x ∈ keys(X.store)
        return X.store[x]
    else
        y = X.f(x)
        X.store[x] = y
        return y
    end
end
function (X::MemoizedInverse)(y)
    return X(y,minimum(keys(X.store)),maximum(keys(X.store)))
end
function (X::MemoizedInverse)(y, lx, rx)

    @show lx,rx
    # check middle point: 
    mx = (lx + rx)/2
    tol = 0.1 # sqrt(eps(float(eltype(y))))
    if abs(lx - mx) < tol
        return mx
    end

    ly = get_val(X,lx)
    ry = get_val(X,rx)
    increasing = ly <= ry

    y_in = increasing ? (ly <= y <= ry) : (ry <= y <= ly)
    if !y_in
        throw(ArgumentError("$y is not inside [$ly,$ry] and thus [$lx, $rx] is not a valid bracketing interval"))
    end

    # recurse: 
    my = get_val(X,mx)
    if (increasing && (y > my)) || (!increasing && (y < my))
        return X(y,mx,rx)
    else
        return X(y,lx,mx)
    end
end



f(x) = exp(-x)
finv = MemoizedInverse(f, 0, Inf)

# finv(10)
finv(rand())

Moreover, I am not sure this Dict approach is the cleanest; maybe a tree structure would be more performant? My first goal is to make it work, but then the objective is performance ^^

Edit: I updated the code a bit but still hit the same issue. Even with absurdly large tolerances I still get the stack overflow.

I think you get the overflow because your right border is Inf, and thus your midpoint is also Inf, so you never actually narrow your window down.
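
For example (just to illustrate, with the values from your MemoizedInverse(f,0,Inf) call):

lx, rx = 0.0, Inf
mx = (lx + rx) / 2       # Inf, so the bracket never shrinks
abs(lx - mx) < 0.1       # false, and it stays false on every recursion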

And yes, I think for performance you will need a tree data structure for sure. While Dict has constant-time lookup, it actually has a rather large constant factor, as hashing is somewhat expensive. If you instead use some tree data structure to narrow down the interval, I think you can be much quicker.

You are right of course, I am staying at Inf… which is stupid.

Thanks, I’ll try to fix this first and then I’ll take a look at a better data structure. Do you have an idea of where I can find one?

I am probably not the best guy to ask since I am not a computer scientist and this is probably a “solved” problem.

I would propose just a simple binary tree, where each node stores the function value at a point (the midpoint of some interval) and has up to two child nodes. The root node contains the value at the midpoint of your initial search interval. For searching you just do a simple comparison with the function value to figure out whether the value you seek is to the left or the right of the center, and then walk to the corresponding child node (if it does not exist, create it and save it for future use). So (assuming everything is precomputed) at every level you essentially just perform a comparison and then chase a pointer.

I am not sure whether there are good implementations that one can just reuse. DataStructures.jl seems to only feature more sophisticated tree data structures. But a simple implementation of this algorithm shouldn’t be too hard, so an external dependency is maybe unnecessary.
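
Something like this minimal sketch (BisectNode and invert! are just names I made up here, and it assumes f is increasing on the initial interval):

mutable struct BisectNode{T}
    lo::T
    hi::T
    x::T          # midpoint of [lo, hi]
    fx::T         # cached f(x)
    left::Union{BisectNode{T},Nothing}
    right::Union{BisectNode{T},Nothing}
end

function BisectNode(f, lo::T, hi::T) where {T}
    x = (lo + hi) / 2
    return BisectNode{T}(lo, hi, x, f(x), nothing, nothing)
end

# walk down the tree, creating children lazily, until the bracket is
# narrower than tol; every stored f-value is reused by later queries
function invert!(node::BisectNode, f, y; tol = 1e-8)
    while node.hi - node.lo > tol
        if y < node.fx                 # target lies in the left half
            node.left === nothing && (node.left = BisectNode(f, node.lo, node.x))
            node = node.left
        else                           # target lies in the right half
            node.right === nothing && (node.right = BisectNode(f, node.x, node.hi))
            node = node.right
        end
    end
    return node.x
end

root = BisectNode(exp, 0.0, 10.0)   # e.g. invert exp on [0, 10]
invert!(root, exp, 2.0)             # ≈ log(2)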

DO NOT use a simple binary-tree-based structure. It’s not performant due to poor locality, and implementing a performant dynamic sorted container is hard, but you DO NOT need it. For functions like e^(x^2), looking a value up is probably not faster than computing it. (Computing is fast, memory is slow…) So storing a decision tree, statically or dynamically, wouldn’t be my recommendation. I know this sounds unintuitive, but looking up something like e^(x^2) from memory is probably slower than computing it, especially with a large tree. If you compute the inverse of A SINGLE VALUE multiple times, you can memoize the inverse itself. This memoization might be faster or slower; you need to test it out yourself. Also, you can try precomputing an array of bounds that you select by rounding. (For example, if you want f^-1(1.22), you look up f^-1(1.2) and f^-1(1.3) and then do a bisection search between them.)
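
For the per-value memoization, a tiny sketch could look like this (memoize and the toy inv_f are only illustrative; whether the Dict lookup beats recomputation is exactly what you have to benchmark):

function memoize(g)
    cache = Dict{Float64,Float64}()
    return y -> get!(() -> g(y), cache, y)
end

inv_f(y) = -log(y)                 # stand-in for your bisection-based inverse of exp(-x)
cached_inv_f = memoize(inv_f)

cached_inv_f(0.5)   # computed and stored
cached_inv_f(0.5)   # served from the cache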

However, if what you want is a fast, generic inverse function, please note that Julia has a good automatic differentiation ecosystem, and you might be able to implement Newton-Raphson without much loss of generality. (Though be sure to test it out, because Newton-Raphson may fail to converge.)
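
For example, a hedged sketch using ForwardDiff.jl (newton_inverse is a made-up name; it assumes f is smooth, and convergence is not guaranteed, so test it on your actual f):

using ForwardDiff

function newton_inverse(f, y; x0 = 1.0, tol = 1e-12, maxiter = 100)
    x = float(x0)
    for _ in 1:maxiter
        r = f(x) - y
        abs(r) < tol && return x
        x -= r / ForwardDiff.derivative(f, x)   # Newton step on f(x) - y
    end
    error("Newton iteration did not converge for y = $y")
end

newton_inverse(x -> exp(-x), 0.5)   # ≈ log(2)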


This is indeed an interesting comment. Checking the values in memory might take longer than computing them… except if the computation of f itself is really demanding (which might well be the case in my applications).

However, I know in my application that f goes from [0,Inf] to [0,1], so maybe I need to store f^-1(0.1), f^-1(0.2), … up to f^-1(0.9) so that I can divide into ten and then bisect? Or into a hundred and then bisect? Or into N, with N choosable by the user? I will test a few things and come back.
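
Something like this rough sketch, maybe (GridInverse and the other names are made up; it assumes an increasing f on a finite interval, so my decreasing f on [0,Inf] would first need the flip or change of variables mentioned below):

struct GridInverse{F}
    f::F
    xs::Vector{Float64}   # grid points in x, increasing
    ys::Vector{Float64}   # f.(xs), increasing because f is assumed increasing
end

function GridInverse(f, a::Real, b::Real; N = 100)
    xs = collect(range(a, b; length = N + 1))
    return GridInverse(f, xs, f.(xs))
end

function (G::GridInverse)(y; tol = 1e-10)
    # locate the bucket whose endpoint images bracket y, then bisect inside it
    i = clamp(searchsortedlast(G.ys, y), 1, length(G.ys) - 1)
    lo, hi = G.xs[i], G.xs[i + 1]
    while hi - lo > tol
        mid = (lo + hi) / 2
        G.f(mid) < y ? (lo = mid) : (hi = mid)
    end
    return (lo + hi) / 2
end

g = GridInverse(x -> 1 - exp(-x), 0.0, 20.0; N = 10)
g(0.5)   # ≈ log(2)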

So @Tarny_GG_Channie says that essentially your route (bisection + memoization) should only be used if your function evaluation is very expensive and cannot be differentiated, which I agree with.

For your interval problem: instead of inverting f(x)=exp(-x) on [0,Inf] you could just switch variables, e.g. y(x)=1/(x+1); then you can invert the composition f(x(y)) on (0,1] and map back to your original interval via x(y) = 1/y - 1.
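
A minimal sketch of that substitution (x_of_y, g and invert_f are just illustrative names; it relies on f(x)=exp(-x) being monotone, so that the transformed function is increasing on (0,1]):

f(x) = exp(-x)

x_of_y(y) = 1/y - 1        # maps (0, 1] back onto [0, Inf)
g(y) = f(x_of_y(y))        # increasing on (0, 1], since f is decreasing in x

function invert_f(target; tol = 1e-10)
    lo, hi = 0.0, 1.0      # g is never evaluated at the left endpoint
    while hi - lo > tol
        mid = (lo + hi) / 2
        g(mid) < target ? (lo = mid) : (hi = mid)
    end
    return x_of_y((lo + hi) / 2)
end

invert_f(0.5)   # ≈ log(2)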

By the way: Is there a reason you are implementing this yourself and don’t want to use one of Julia’s optimization or root-finding libraries? (Function inversion is essentially finding the value x such that g(x) = abs(f(x)-y) is minimal or finding the root of f(x)-y.)

Well, I am currently using Roots.jl for that, but since the function might be expensive and the number of calls will be huge, I was wondering if I could do better since I “know” a bit about the function.

Your approach here also does not use additional knowledge about the function, just caching, right? You can probably achieve close to the same thing by using a cache for the functions you want to invert in conjunction with Roots.jl.
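
Roughly like this, as a sketch only (cached and invert are illustrative names, the finite upper bound 50.0 stands in for your infinite domain, and with a cheap stand-in f the caching will not actually pay off; that depends on your real f):

using Roots

function cached(f)
    store = Dict{Float64,Float64}()
    return x -> get!(() -> f(x), store, x)   # reuse f(x) across inversions
end

f(x) = exp(-x)                  # stand-in for the expensive function
cf = cached(f)

# invert f at y by bracketed root finding on cf(x) - y
invert(y; bracket = (0.0, 50.0)) = find_zero(x -> cf(x) - y, bracket, Bisection())

invert(0.5)   # ≈ log(2); later calls reuse any midpoints already evaluated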

In any case, unless you have seen that the number of function evaluations is indeed a bottleneck, you probably should not spend the effort trying to optimize them :slightly_smiling_face:

Premature optimization is the root of all evil.
~ Donald Knuth


Is this a continuation of this topic? If so, your function is a CDF, just vertically translated by a constant, so you could use the fact that it’s nondecreasing to shave off an instruction or two from the hot loop in the bisection method code. That could maybe be a performance improvement, if the evaluation of the function is cheap enough.

It is the same project but not the same function; this one is not a CDF :).

Although now that I think of it, this one is a survival function, and thus I could use the previous code… on 1-f.
