# Terminating sum early

I am calculating a sum of values returned by a function, which is supposed to be type stable. The function may be costly, and returns values that are either finite or -Inf. In case of -Inf, I would like to return early.

MWE (which is of course not costly, the first version does not return early):

observation_logdensity(μ, x) = x > μ ? -abs2(x - μ) : oftype(x - μ, -Inf)

sample_logdensity1(μ, xs) = sum(observation_logdensity(μ, x) for x in xs)


Now I am wondering how to program what I want, so that the accumulator does not change type. The function can return a <: Real type (in a type-stable way), especially ForwardDiff.Dual. I would like to have a way to figure out the accumulator type. This version does not, so the type changes inevitably:

function sample_logdensity2(μ, xs)
    total = 0           # in general, I don't know the return type of observation_logdensity
    for x in xs
        total += observation_logdensity(μ, x)
    end
    total
end


This version should figure it out (if observation_logdensity is type stable), but it is cumbersome:

function sample_logdensity3(μ, xs)
    isempty(xs) && return -Inf
    total = observation_logdensity(μ, first(xs))
    for x in Iterators.drop(xs, 1)
        total += observation_logdensity(μ, x)
    end
    total
end


I wonder if there is a way to write it in a more compact way idiomatically. xs can be any iterable.

You basically need something like mapreduce with sentinel, right?

mapreduce(observation_logdensity, +, xs; sentinel=(x->x==-Inf))


Is this function hypothetical? I could not find this signature in Base.

You can use a mapreduce with a wrapper of + that errors when one of the terms is not finite, otherwise adds them. Then wrap the whole thing in a try catch block.
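A minimal sketch of that idea, in case it helps. The names NonFiniteTerm, checked_add and sum_or_neginf are made up for illustration and are not in Base. Note that the fallback value is a plain Float64 -Inf, so this version does not solve the accumulator-type question, and try/catch has its own cost:

```julia
# Hypothetical exception type, used only to unwind the reduction early.
struct NonFiniteTerm <: Exception end

# Addition that throws as soon as either operand is not finite.
function checked_add(a, b)
    (isfinite(a) && isfinite(b)) || throw(NonFiniteTerm())
    return a + b
end

function sum_or_neginf(f, xs)
    try
        return mapreduce(f, checked_add, xs)
    catch err
        # Only swallow our own marker exception; anything else propagates.
        err isa NonFiniteTerm || rethrow()
        return -Inf
    end
end
```

For the MWE this would be called as sum_or_neginf(x -> observation_logdensity(μ, x), xs). Depending on how mapreduce walks the collection, f may still be evaluated on an element or two after the first -Inf.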

Sorry! Hypothetical… I could not find it either.

Maybe inspiration for enhancing or for some additional package.

EDIT: This is probably not what you want, give me a sec.

function mapreduceuntil(op, unit, xs; sentinel=(x->false))
    u = iterate(xs)
    u === nothing && error("empty reduction")
    y = unit(u[1])
    while u !== nothing
        x, s = u
        sentinel(x) && break
        y = op(y, x)
        u = iterate(xs, s)
    end
    y
end

mapreduceuntil(+, zero, [1.0,2.0,Inf], sentinel=isinf)



Did not look carefully, but I guess you want to stop after overflow and keep Inf as result.

function mapreduceuntil(f, op, unit, xs; sentinel=(x->false), saturation=(y->false))
    u = iterate(xs)
    u === nothing && error("empty reduction")

    x, s = f(u[1]), u[2]
    y = unit(x)
    sentinel(x) && return y
    saturation(y) && return y
    y = op(x, y)
    u = iterate(xs, s)

    while u !== nothing
        x, s = f(u[1]), u[2]
        sentinel(x) && break
        saturation(y) && break
        y = op(x, y)
        u = iterate(xs, s)
    end
    y
end

mapreduceuntil(identity, +, zero, [1.0,2.0,Inf, Inf], saturation=isinf)


There are two separate issues here. One is to terminate a reduction early, the other is to write type-stable code that is not cumbersome.
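Both issues can be handled in one explicit loop. The following is only a sketch, and sum_until_neginf is a made-up name. It assumes, as in the original question, that f is type stable and returns either finite values or -Inf, and it takes the accumulator type from the first value instead of guessing it:

```julia
function sum_until_neginf(f, xs)
    next = iterate(xs)
    next === nothing &&
        throw(ArgumentError("cannot pick an accumulator type for an empty collection"))
    x, state = next
    total = f(x)                  # the first value fixes the accumulator type
    while total != -Inf           # finite + (-Inf) == -Inf, so checking the total suffices
        next = iterate(xs, state)
        next === nothing && break
        x, state = next
        total += f(x)             # f is not called again once total is -Inf
    end
    return total
end
```

With this, the MWE could be written as sample_logdensity(μ, xs) = sum_until_neginf(x -> observation_logdensity(μ, x), xs). It only relies on iterate, so xs can be any iterable.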

I think Julia could really benefit from a @typeof macro that would compute the return type of an expression without actually evaluating it. So one could write total = zero(@typeof(observation_logdensity(μ, first(xs)))) and that would make sample_logdensity2 as type-stable and as fast as sample_logdensity3.

I’m not sure if it’s possible to create such a macro in an efficient way using Base.return_types, or if it would require some extra magic.
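As a function rather than a macro, something close to this can be sketched with Base.promote_op, which asks inference for a return type without calling the function. This is an assumption-heavy sketch: Base.promote_op is internal rather than documented public API, its result depends on what the compiler manages to infer, and sample_logdensity_inferred is a made-up name. The MWE is repeated here (with the -Inf branch converted via oftype(x - μ, -Inf)) only to keep the block self-contained:

```julia
# Variant of the MWE from the first post.
observation_logdensity(μ, x) = x > μ ? -abs2(x - μ) : oftype(x - μ, -Inf)

function sample_logdensity_inferred(μ, xs)
    # Inferred return type; may be Any or a Union if inference gives up.
    T = Base.promote_op(observation_logdensity, typeof(μ), eltype(xs))
    isconcretetype(T) || error("could not infer a concrete return type, got $T")
    total = zero(T)               # accumulator starts with the right type
    for x in xs
        total += observation_logdensity(μ, x)
    end
    return total
end
```

Unlike sample_logdensity3 this also handles an empty xs, but it relies on eltype(xs) being informative, which is not guaranteed for arbitrary iterables.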

How about this? I think it will always iterate through xs, but won’t call f once it’s hit -Inf:

function mysum(f::Function, xs)
    flag::Bool = true
    setflag(y) = begin y == -Inf && (flag = false); y end
    sum(setflag(f(x)) for x in xs if flag)
end

f5(x) = begin @show(x); x>5 ? -Inf : Float64(x) end

mysum(f5, 1:20) ## x = 1.0, x = 2.0, ..., x = 6.0; return -Inf


Very clever, thanks. But I am always unsure about generators and scoping, is this allowed?
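On the scoping question: as far as I understand it, the body of a generator behaves like a closure, so inside a function it can read and assign local variables of the enclosing function. A small sketch to illustrate (count_calls and tally are made-up names):

```julia
function count_calls(xs)
    calls = 0
    # tally captures the local variable calls and assigns to it.
    tally(x) = (calls += 1; x)
    total = sum(tally(x) for x in xs)
    return total, calls
end

count_calls(1:3)   # (6, 3)
```

One caveat: a captured variable that is reassigned gets boxed, which can hurt type inference in hot code. Storing the flag in a Ref and updating it with flag[] = false is a common alternative.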

Sorry to bother with nomenclature but your function is more reduce than mapreduce.

And Tamas also needs to apply observation_logdensity to every element before summing.

As doc says:

mapreduce(f, op, A::AbstractArray; dims=:, [init])

Evaluates to the same as reduce(op, map(f, A); dims=dims, init=init), but is generally faster because the intermediate array is avoided.


So with your function Tamas has to write something like:

mysum(μ, xs) = mapreduceuntil(+, zero, map(x -> observation_logdensity(μ, x), xs), sentinel=isinf)
mysum(μ, [1.0,2.0,Inf])


…which is generally slower because the intermediate array is needed.

There is a problem: setflag depends on the partial sum and not on the element.

Isn’t the doc a little misleading?

But is it true? Couldn’t we write it as:

reduce(op, f(i) for i in A; dims=dims, init=init)


A benchmark suggests that this is the better way:

julia> @btime reduce(+, i*i for i in [1,2,3])
36.302 ns (2 allocations: 128 bytes)
14

julia> @btime reduce(+, map(i->i*i, [1,2,3]))
65.933 ns (3 allocations: 240 bytes)
14


although mapreduce still looks better:

julia> @btime mapreduce(i->i*i, +, [1,2,3])
31.027 ns (1 allocation: 112 bytes)
14