High precision summation functions not compatible with Optim.jl with autodiff=:forward

I am doing high dimensional Monte Carlo integration in a maximum likelihood estimation using Optim.jl with the autodiff = :forward option (based on ForwardDiff.jl). I use up to millions of quasi-random draws (e.g., Halton sequences and the like) for the numerical integration and thus need high precision summation to add up the simulated numbers. However, the current implementations of most high precision summation functions do not work with ForwardDiff, though I suppose they are fixable. I have limited knowledge of ForwardDiff and the related type constraints, so I am asking for advice here.

I am considering the following functions (listed from slow to fast):

  1. KahanSummation.sum_kbn(): generic Julia code.
  2. psum_kbn(): a multi-threaded extension of KahanSummation.sum_kbn() proposed by @goerch in this thread.
  3. AccurateArithmetic.sum_kbn() and AccurateArithmetic.sum_oro(): an implementation using llvmcall, which is close to the speed of the regular sum.
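
As a quick illustration of why a plain sum is not enough here (a standard toy case, not from my actual application):

using KahanSummation

x = [1.0, 1e16, -1e16]
sum(x)      # 0.0  (the 1.0 is absorbed by round-off in the naive sum)
sum_kbn(x)  # 1.0  (the compensation term recovers it)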

Currently, only KahanSummation.sum_kbn(), the slowest, is compatible with ForwardDiff.jl. According to the docs, ForwardDiff requires the target function to be composed of generic Julia functions. I am fairly sure psum_kbn() is generic Julia and that the problem lies in type conversions. I am less certain whether the other two functions from AccurateArithmetic count as ā€œgeneric Juliaā€ in ForwardDiff's sense.

Here is an MWE.

using StatsFuns, Distributions, HaltonSequences, Optim
using KahanSummation, AccurateArithmetic

data = [-0.755, -1.710, -0.891, -2.889, -0.881, 0.763, -0.365]
draws = 200 # potentially in millions for high dimension problems
halton = Halton(length=draws)  

function LLexample(algo, log_Ļƒįµ¤Ā², log_Ļƒįµ„Ā², e, LDSdraw)  # algo: summation algorithm
    Ļƒįµ„ = exp(0.5*log_Ļƒįµ„Ā²)
    Ļƒįµ¤ = exp(0.5*log_Ļƒįµ¤Ā²)
    u  = quantile(Normal(0, Ļƒįµ¤), 0.5 * LDSdraw .+ 0.5) 

    loglike = 0.0
    for i = 1:length(e)  
        temp = normpdf.(0, Ļƒįµ„, e[i] .+ u) 
        loglike += log(algo(temp)/length(LDSdraw))  # algo: summation algorithm
    end
    return -loglike
end
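
For reference, the working baseline uses KahanSummation.sum_kbn (the same call pattern as the failing cases below):

func = TwiceDifferentiable(vars -> LLexample(KahanSummation.sum_kbn, vars[1], vars[2], data, halton),
                           ones(2); autodiff = :forward)
res = optimize(func, [-0.1, -0.1], Newton())  # runs through: sum_kbn is generic Julia code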

Using AccurateArithmetic.sum_kbn (or, similarly, AccurateArithmetic.sum_oro), I got the following errors.

julia> func = TwiceDifferentiable(vars -> LLexample(AccurateArithmetic.sum_kbn, vars[1], vars[2], data, halton), ones(2);  autodiff = :forward);

julia> res = optimize(func, [-0.1, -0.1], Newton())
ERROR: MethodError: no method matching accumulate(::Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#17#18", Float64}, Float64, 2}}}, ::AccurateArithmetic.Summation.var"#1#2"{typeof(AccurateArithmetic.EFT.fast_two_sum)}, ::Val{:scalar}, ::Val{2}, ::Val{0})
Closest candidates are:
  accumulate(::Tuple{Vararg{AbstractArray{T}, A}}, ::F, ::Any, ::Val{Ushift}, ::Val{Prefetch}) where {F, A, T<:Union{Float32, Float64}, Ushift, Prefetch} at C:\Users\King\.julia\packages\AccurateArithmetic\449tA\src\Summation.jl:25
  accumulate(::Tuple{Vararg{AbstractArray{T}, A}}, ::F, ::Any, ::Val{Ushift}) where {F, A, T<:Union{Float32, Float64}, Ushift} at C:\Users\King\.julia\packages\AccurateArithmetic\449tA\src\Summation.jl:25      
  accumulate(::Tuple{Vararg{AbstractArray{T}, A}}, ::F, ::Any) where {F, A, T<:Union{Float32, Float64}} at C:\Users\King\.julia\packages\AccurateArithmetic\449tA\src\Summation.jl:25

To test psum_kbn(), the multi-threaded extension of KahanSummation's sum_kbn(), the patch KahanSummation_patch.jl is needed; it can be downloaded here (see also the discussion here).

using InitialValues, Folds
include("KahanSummation_patch.jl") # for psum_kbn()

psum_kbn(f, X) = singleprec(Folds.mapreduce(f, InitialValues.asmonoid(plus_kbn), X)) # multi-threaded KBN mapreduce; credit to @goerch
psum_kbn(X) = psum_kbn(identity, X)

Then,

julia> func = TwiceDifferentiable(vars -> LLexample(psum_kbn, vars[1], vars[2], data, halton), ones(2);  autodiff = :forward);

julia> res = optimize(func, [-0.1, -0.1], Newton())
ERROR: MethodError: convert(::Type{ForwardDiff.Dual{ForwardDiff.Tag{var"#15#16", Float64}, Float64, 2}}, ::TwicePrecisionN{ForwardDiff.Dual{ForwardDiff.Tag{var"#15#16", Float64}, Float64, 2}}) is ambiguous. Candidates:
  convert(::Type{T}, x::TwicePrecisionN) where T in Main at e:\temp\KahanSummation_patch.jl:74
  convert(::Type{ForwardDiff.Dual{T, V, N}}, x) where {T, V, N} in ForwardDiff at C:\Users\King\.julia\packages\ForwardDiff\PBzup\src\dual.jl:384
Possible fix, define
  convert(::Type{ForwardDiff.Dual{T, V, N}}, ::TwicePrecisionN) where {T, V, N}

Suggestions to fix these problems would be appreciated.

I believe the convert ambiguity above points to a problem in the patch. If I change singleprec to

singleprec(x::TwicePrecisionN{T}) where {T} = x.hi - x.nlo

the optimization seems to work.

Another remark: using KahanSummation and including the patch could be redundant.

> looks like a problem in the patch. If I change singleprec to
>
> singleprec(x::TwicePrecisionN{T}) where {T} = x.hi - x.nlo

It works, thanks a lot! With psum_kbn(), the simulation program now uses around 90% of the CPU. It runs very fast.

> Another remark: using KahanSummation and including the patch could be redundant.

I keep only the added code in KahanSummation_patch.jl (hence ā€œpatchā€) and leave the original functions to the package, loaded by using KahanSummation. It's true that everything could be kept together in a single package.

Couldnā€™t all of these be made to work with AD by defining the appropriate rule via DiffRules.jl (for ForwardDiff) and/or ChainRules.jl (for everything else)?

PS. Thereā€™s also Xsum.jl, which does exactly rounded summation, albeit in double precision only via an external C library. Again, you would need to define a rule for AD to work on this.

Thanks for helping. I used something similar to @HJW019ā€™s test case from here to compare

using KahanSummation, Xsum, Distributions, Random

xs = map(x -> 10.0^clamp(x, -35, 35), rand(Normal(0.0, 35.0), 1_000_000))
ys = -xs
zs  = shuffle(vcat(xs, ys))

println(sum(zs))
println(sum_kbn(zs))
println(xsum(zs))
println(sum(BigFloat.(zs)))

yielding, for example:

4.722366482869645e21
4.02653184e8
0.0
3.039467936904015380101779459404493157660519265569623764721481767869804002657474e-38

This could be useful indeed.

AFAIU KahanSummation should work with AD without defining additional rules, although @elrod mentioned that defining them could be an optimization. This patch, however, additionally allows easy parallelization via Folds.

Edit: I obviously forgot to mention the performance regression in KahanSummation on Julia 1.7.1.

Regarding AccurateArithmetic, you are right AFAIU.

I think ā€œgenericā€ here means: ā€œcan work with arbitrary types of numbers (including ForwardDiff.Dual)ā€. And as you found out, compensated algorithms implemented in AccurateArithmetic.jl are specialized for floating-point numbers (couldnā€™t see any other way to have efficient vectorization). And apart from the fact that itā€™s vectorized, AccurateArithmetic.sum_kbn should have no benefit over KahanSummation.sum_kbn, so that itā€™s probably good to use the latter if it composes well with ForwardDiff.

If, however, you define appropriate rules for AD to work on sum_kbn, then there might be an interest in using AccurateArithmetic for performance reasons, or Xsum for more accuracy.


Youā€™re probably already well aware of this, but having zs sum to exactly 0 makes it a very difficult test case, because the condition number of the summation of elements in zs is ā€œinfiniteā€ (in other words, there is no bound on the impact that a single round-off can have on the relative error for the overall computation).

In any case, if this is representative of real data(*), then exact summation algorithms (such as Xsum) are IMO the only way to get sensible results.


(*) as a real-world criterion, let's say when the summation has a condition number on the order of more than 1e15–1e30 (depending on the accuracy you want to get in the end)

Iā€™m not sure how to do it with DiffRules.jl, but extending AccurateArithmetic so that it can handle vectors of ForwardDiff.Dual numbers seems doable.

In the simple attempt below, the results seem to be correct, but I’m copying the values to a temporary vector in order to apply the vectorized algorithms from AccurateArithmetic to it. I suspect this will kill performance in a lot of cases. Also, I’m using naive summation to sum the partial derivatives: I expect this sum to be well conditioned, but maybe that assumption is not true.

Anyway, maybe this works in the real context:

using ForwardDiff
using AccurateArithmetic

function AccurateArithmetic.Summation.accumulate((xdual,)::Tuple{<:AbstractVector{T}}, args...) where {T<:ForwardDiff.Dual}
    values   = getfield.(xdual, :value)                                     # copy the primal values into a plain Vector
    value    = AccurateArithmetic.Summation.accumulate((values,), args...)  # vectorized compensated sum of the values
    partials = sum(xd.partials for xd in xdual)                             # naive sum of the partials
    T(value, partials)
end

julia> x = rand(10);

julia> g1 = ForwardDiff.gradient(sum, x)
10-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0

julia> g2 = ForwardDiff.gradient(sum_oro, x)
10-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0

Iā€™m not sure what would be the correct way to move forward if we want the computation to be as efficient as possible.

Cool! Thank you for helping. The following modified test

include("KahanSummation_patch.jl")

using Random, BenchmarkTools, InitialValues, Folds

psum_kbn(f, X) = singleprec(Folds.mapreduce(f, InitialValues.asmonoid(plus_kbn), X))
psum_kbn(X) = psum_kbn(identity, X)

using StatsFuns, Distributions, HaltonSequences, Optim
using AccurateArithmetic, ForwardDiff

function AccurateArithmetic.Summation.accumulate((xdual,)::Tuple{<:AbstractVector{T}}, args...) where {T<:ForwardDiff.Dual}
    values   = getfield.(xdual, :value)
    value    = AccurateArithmetic.Summation.accumulate((values,), args...)
    partials = sum(xd.partials for xd in xdual)
    T(value, partials)
end

const data = [-0.755, -1.710, -0.891, -2.889, -0.881, 0.763, -0.365]
const draws = 1_000_000 # potentially in millions for high dimension problems
const halton = Halton(length=draws)  

function LLexample(algo, log_Ļƒįµ¤Ā², log_Ļƒįµ„Ā², e, LDSdraw)  # algo: summation algorithm
    Ļƒįµ„ = exp(0.5*log_Ļƒįµ„Ā²)
    Ļƒįµ¤ = exp(0.5*log_Ļƒįµ¤Ā²)
    u  = quantile(Normal(0, Ļƒįµ¤), 0.5 * LDSdraw .+ 0.5) 

    loglike = 0.0
    for i = 1:length(e)  
        temp = normpdf.(0, Ļƒįµ„, e[i] .+ u) 
        loglike += log(algo(temp)/length(LDSdraw))  # algo: summation algorithm
    end
    return -loglike
end  

function test1()
    func = TwiceDifferentiable(vars -> LLexample(sum, vars[1], vars[2], data, halton), ones(2);  autodiff = :forward);
    res = @btime optimize($func, [-0.1, -0.1], Newton())
    # println(res)
end 

function test2()
    func = TwiceDifferentiable(vars -> LLexample(AccurateArithmetic.sum_kbn, vars[1], vars[2], data, halton), ones(2);  autodiff = :forward);
    res = @btime optimize($func, [-0.1, -0.1], Newton())
    # println(res)
end

function test3()
    func = TwiceDifferentiable(vars -> LLexample(sum_kbn, vars[1], vars[2], data, halton), ones(2);  autodiff = :forward);
    res = @btime optimize($func, [-0.1, -0.1], Newton())
    # println(res)
end

function test4()
    func = TwiceDifferentiable(vars -> LLexample(psum_kbn, vars[1], vars[2], data, halton), ones(2);  autodiff = :forward);
    res = @btime optimize($func, [-0.1, -0.1], Newton())
    # println(res)
end

test1(); test2(); test3(); test4();

seems to work and yields

  6.879 s (397 allocations: 4.14 GiB)
  13.252 s (252000999 allocations: 12.80 GiB)
  6.838 s (397 allocations: 4.14 GiB)
  6.679 s (13169 allocations: 4.14 GiB)

Interestingly I donā€™t see speedup through parallelized Kahan summation for this data set.

Donā€™t you want to sum the partials using compensated summation as well?

Correct. Trying

function AccurateArithmetic.Summation.accumulate((xdual,)::Tuple{<:AbstractVector{T}}, args...) where {T<:ForwardDiff.Dual}
    values   = getfield.(xdual, :value)
    value    = AccurateArithmetic.Summation.accumulate((values,), args...)
    partials = AccurateArithmetic.sum_kbn(xd.partials for xd in xdual)
    T(value, partials)
end

runs into the next problem:

ERROR: MethodError: no method matching default_ushift(::Base.Generator{Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 2}}, var"#5#6"}, ::AccurateArithmetic.Summation.var"#1#2"{typeof(AccurateArithmetic.EFT.fast_two_sum)})
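
AFAIU this fails because AccurateArithmetic's entry points only accept AbstractArrays of Float32/Float64 (see the candidate signatures in the first error above), not a generator of Partials. An untested sketch of a workaround, replacing the overload above, is to run the compensated sum once per partial component, materializing each component into a plain vector first:

using ForwardDiff, AccurateArithmetic

# Untested sketch: compensated summation of the value and of each partial
# component separately. Allocates one temporary Vector per component.
function AccurateArithmetic.Summation.accumulate((xdual,)::Tuple{<:AbstractVector{ForwardDiff.Dual{Tg,V,N}}},
                                                 args...) where {Tg,V,N}
    values = getfield.(xdual, :value)
    value  = AccurateArithmetic.Summation.accumulate((values,), args...)
    parts  = ntuple(i -> AccurateArithmetic.sum_kbn([xd.partials[i] for xd in xdual]), N)
    ForwardDiff.Dual{Tg,V,N}(value, ForwardDiff.Partials{N,V}(parts))
end

This is not free, since it still copies, but it would keep compensated accuracy for the derivatives as well.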