Terminating sum early

I am calculating a sum of values returned by a function, which is supposed to be type stable. The function may be costly, and returns values that are either finite or -Inf. In case of -Inf, I would like to return early.

MWE (where the function is of course not costly; this first version does not return early):

observation_logdensity(μ, x) = x > μ ? -abs2(x - μ) : oftype(x - μ, -Inf)  # -Inf in the promoted type of x and μ

sample_logdensity1(μ, xs) = sum(observation_logdensity(μ, x) for x in xs)

Now I am wondering how to program what I want so that the accumulator does not change type. The function can return any <: Real type (in a type-stable way), in particular ForwardDiff.Dual, so I would like a way to figure out the accumulator type. This version does not, so the type inevitably changes:

function sample_logdensity2(μ, xs)
    total = 0           # in general, I don't know the return type of observation_logdensity
    for x in xs
        total += observation_logdensity(μ, x)
        total == -Inf && return total # terminate early
    end
    total
end

This version should figure it out (if observation_logdensity is type stable), but it is cumbersome:

function sample_logdensity3(μ, xs)
    isempty(xs) && return -Inf
    total = observation_logdensity(μ, first(xs))
    for x in Iterators.drop(xs, 1)
        total += observation_logdensity(μ, x)
        total == -Inf && return total # terminate early
    end
    total
end

I wonder if there is a more compact, idiomatic way to write this. xs can be any iterable.

You basically need something like mapreduce with a sentinel, right?

mapreduce(x -> observation_logdensity(μ, x), +, xs; sentinel = x -> x == -Inf)

Is this function hypothetical? I could not find this signature in Base.

You can use mapreduce with a wrapper of + that errors when one of the terms is not finite, and otherwise adds them. Then wrap the whole thing in a try catch block.
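
A minimal sketch of that idea, assuming the observation_logdensity above (the names SaturatedSum, saturating_add and sample_logdensity_trycatch are made up for illustration, not an existing API):

# Hypothetical sketch: saturating_add throws as soon as a term is -Inf,
# and the try/catch turns that into an early -Inf result.
struct SaturatedSum <: Exception end

saturating_add(acc, x) = x == -Inf ? throw(SaturatedSum()) : acc + x

function sample_logdensity_trycatch(μ, xs)
    try
        mapreduce(x -> observation_logdensity(μ, x), saturating_add, xs)
    catch e
        e isa SaturatedSum ? -Inf : rethrow()
    end
end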

About the type of the result, you can try the solution in KissThreading.jl (see mohamed82008/KissThreading.jl on GitHub, at commit 372a8a599d8116f09ba7c1702b1dff9169dac95c).

Sorry! Hypothetical… I could not find it either. :confused:

Maybe it could serve as inspiration for enhancing Base, or for some additional package.

EDIT: This is probably not what you want, give me a sec.

How about

function mapreduceuntil(op, unit, xs; sentinel = x -> false)
    u = iterate(xs)
    u === nothing && error("empty reduction")
    y = unit(u[1])                # accumulator starts at e.g. zero(first element)
    while u !== nothing
        x, s = u
        sentinel(x) && break      # stop before folding in a sentinel element
        y = op(y, x)
        u = iterate(xs, s)
    end
    y
end

mapreduceuntil(+, zero, [1.0,2.0,Inf], sentinel=isinf)

Then wrap the whole thing in a try catch block.

Nooo… don’t do this please :slight_smile:


I did not look carefully, but I guess you want to stop once the accumulator overflows, and keep Inf as the result.

function mapreduceuntil(f, op, unit, xs; sentinel = x -> false, saturation = y -> false)
    u = iterate(xs)
    u === nothing && error("empty reduction")

    # handle the first element separately so the accumulator gets the right type
    x, s = f(u[1]), u[2]
    y = unit(x)
    sentinel(x) && return y
    saturation(y) && return y
    y = op(x, y)
    u = iterate(xs, s)

    while u !== nothing
        x, s = f(u[1]), u[2]
        sentinel(x) && break      # stop on a sentinel element
        saturation(y) && break    # stop once the accumulator saturates
        y = op(x, y)
        u = iterate(xs, s)
    end
    y
end

mapreduceuntil(identity, +, zero, [1.0,2.0,Inf, Inf], saturation=isinf)
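
Applied to the original problem, this might look like the following (sample_logdensity4 is just an illustrative name; note that f is still evaluated once more after the accumulator hits -Inf, since the saturation check runs before the next op):

# Sketch: stop as soon as the accumulated log density saturates at -Inf.
sample_logdensity4(μ, xs) =
    mapreduceuntil(x -> observation_logdensity(μ, x), +, zero, xs;
                   saturation = y -> y == -Inf)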

There are two separate issues here. One is to terminate a reduction early, the other is to write type-stable code that is not cumbersome.

I think Julia could really benefit from a @typeof macro that would compute the return type of an expression without actually evaluating it. Then one could write total = zero(@typeof(observation_logdensity(μ, first(xs)))), which would make sample_logdensity2 as type-stable and as fast as sample_logdensity3.

I’m not sure if it’s possible to create such a macro in an efficient way using Base.return_types, or if it would require some extra magic.
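
For what it is worth, something along those lines can already be approximated with Base.promote_op, an internal, inference-based helper rather than public API, so take this only as a sketch (and note that eltype(xs) may be Any for a generic iterable):

# Sketch only: Base.promote_op is internal and relies on inference;
# sample_logdensity2b is just an illustrative name.
function sample_logdensity2b(μ, xs)
    T = Base.promote_op(observation_logdensity, typeof(μ), eltype(xs))
    total = zero(T)                    # concrete accumulator if inference succeeds
    for x in xs
        total += observation_logdensity(μ, x)
        total == -Inf && return total  # terminate early
    end
    total
end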

How about this? I think it will always iterate through xs, but won’t call f once it’s hit -Inf:

function mysum(f::Function, xs)
    flag::Bool = true
    # flip the flag once -Inf is seen; the generator's filter then skips the rest of xs
    setflag(y) = begin y == -Inf && (flag = false); y end
    sum(setflag(f(x)) for x in xs if flag)
end

f5(x) = begin @show(x); x>5 ? -Inf : Float64(x) end

mysum(f5, 1:20) ## shows x = 1, x = 2, ..., x = 6; returns -Inf

Very clever, thanks. But I am always unsure about generators and scoping, is this allowed?

Sorry to bother with nomenclature, but your function is more reduce than mapreduce.

And Tamas also needs to apply observation_logdensity to every element before summing it.

As the docs say:

mapreduce(f, op, A::AbstractArray; dims=:, [init])

  Evaluates to the same as reduce(op, map(f, A); dims=dims, init=init), but is generally faster because the intermediate array is avoided.

So with your function Tamas has to write something like:

mysum(μ, xs) = mapreduceuntil(+, zero, map(x -> observation_logdensity(μ, x), xs); sentinel=isinf)
mysum(μ, [1.0,2.0,Inf])

…which is generally slower because the intermediate array is needed.
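
A sketch of how the intermediate array could be avoided by passing a generator instead, reusing the mapreduceuntil(op, unit, xs; sentinel) version from above (mysum_gen is just an illustrative name):

# Sketch: the generator avoids materializing the mapped array.
mysum_gen(μ, xs) =
    mapreduceuntil(+, zero, (observation_logdensity(μ, x) for x in xs);
                   sentinel = isinf)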

There is also the problem that setflag depends on the partial sum and not on the element.

Isn’t the doc a little misleading?

But is it true? Couldn’t we write it as:

reduce(op, (f(i) for i in A); dims=dims, init=init)

A benchmark suggests that it is indeed a better way:

julia> @btime reduce(+, i*i for i in [1,2,3])
  36.302 ns (2 allocations: 128 bytes)
14

julia> @btime reduce(+, map(i->i*i, [1,2,3]))
  65.933 ns (3 allocations: 240 bytes)
14

although mapreduce still looks better:

julia> @btime mapreduce(i->i*i, +, [1,2,3])
  31.027 ns (1 allocation: 112 bytes)
14