Chi2 (chisq) fit non-convergence using Optimization.jl

using CairoMakie, FHist
using Optimization, OptimizationOptimJL, OptimizationMultistartOptimization

hist_mjj12_rosa = Hist1D(; binedges=0:2:1000,
    bincounts=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 633.0, 588.0, 649.0, 564.0, 588.0, 586.0, 533.0, 489.0, 495.0, 465.0, 464.0, 462.0, 468.0, 419.0, 426.0, 432.0, 408.0, 343.0, 362.0, 316.0, 346.0, 301.0, 323.0, 294.0, 296.0, 269.0, 286.0, 285.0, 250.0, 259.0, 256.0, 247.0, 242.0, 227.0, 217.0, 230.0, 221.0, 208.0, 200.0, 189.0, 183.0, 191.0, 172.0, 166.0, 164.0, 134.0, 164.0, 160.0, 162.0, 120.0, 147.0, 135.0, 129.0, 114.0, 127.0, 117.0, 104.0, 112.0, 112.0, 112.0, 92.0, 106.0, 101.0, 99.0, 106.0, 84.0, 78.0, 83.0, 94.0, 81.0, 84.0, 71.0, 88.0, 92.0, 81.0, 62.0, 87.0, 51.0, 74.0, 66.0, 73.0, 80.0, 73.0, 59.0, 57.0, 62.0, 61.0, 75.0, 49.0, 54.0, 56.0, 59.0, 64.0, 58.0, 55.0, 50.0, 44.0, 47.0, 43.0, 39.0, 42.0, 46.0, 38.0, 41.0, 36.0, 39.0, 37.0, 39.0, 42.0, 37.0, 34.0, 27.0, 40.0, 41.0, 22.0, 26.0, 37.0, 31.0, 28.0, 30.0, 29.0, 22.0, 36.0, 38.0, 30.0, 26.0, 26.0, 11.0, 21.0, 25.0, 24.0, 26.0, 21.0, 18.0, 25.0, 15.0, 19.0, 23.0, 23.0, 20.0, 14.0, 11.0, 19.0, 9.0, 25.0, 18.0, 12.0, 19.0, 13.0, 6.0, 10.0, 14.0, 12.0, 9.0, 22.0, 11.0, 12.0, 13.0, 7.0, 18.0, 10.0, 19.0, 12.0, 16.0, 9.0, 14.0, 12.0, 14.0, 7.0, 11.0, 13.0, 15.0, 11.0, 8.0, 14.0, 2.0, 10.0, 13.0, 9.0, 13.0, 10.0, 10.0, 8.0, 15.0, 8.0, 4.0, 6.0, 8.0, 8.0, 3.0, 7.0, 6.0, 10.0, 5.0, 7.0, 4.0, 4.0, 2.0, 13.0, 3.0, 9.0, 8.0, 5.0, 3.0, 3.0, 8.0, 8.0, 6.0, 5.0, 3.0, 2.0, 6.0, 2.0, 1.0, 3.0, 5.0, 1.0, 4.0, 5.0, 5.0, 8.0, 2.0, 6.0, 0.0, 5.0, 7.0, 5.0, 3.0, 3.0, 6.0, 5.0, 2.0, 3.0, 3.0, 1.0, 1.0, 3.0, 2.0, 7.0, 2.0, 2.0, 3.0, 5.0, 2.0, 2.0, 2.0, 5.0, 6.0, 3.0, 3.0, 2.0, 5.0, 2.0, 4.0, 2.0, 4.0, 2.0, 2.0, 2.0, 3.0, 2.0, 3.0, 1.0, 1.0, 4.0, 6.0, 5.0, 4.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0, 3.0, 1.0, 4.0, 0.0, 1.0, 1.0, 3.0, 
1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 4.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.0, 2.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 0.0, 0.0, 1.0, 3.0, 0.0, 1.0, 2.0, 2.0, 3.0, 2.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 3.0, 1.0, 1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 2.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    sumw2=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 633.0, 588.0, 649.0, 564.0, 588.0, 586.0, 533.0, 489.0, 495.0, 465.0, 464.0, 462.0, 468.0, 419.0, 426.0, 432.0, 408.0, 343.0, 362.0, 316.0, 346.0, 301.0, 323.0, 294.0, 296.0, 269.0, 286.0, 285.0, 250.0, 259.0, 256.0, 247.0, 242.0, 227.0, 217.0, 230.0, 221.0, 208.0, 200.0, 189.0, 183.0, 191.0, 172.0, 166.0, 164.0, 134.0, 164.0, 160.0, 162.0, 120.0, 147.0, 135.0, 129.0, 114.0, 127.0, 117.0, 104.0, 112.0, 112.0, 112.0, 92.0, 106.0, 101.0, 99.0, 106.0, 84.0, 78.0, 83.0, 94.0, 81.0, 84.0, 71.0, 88.0, 92.0, 81.0, 62.0, 87.0, 51.0, 74.0, 66.0, 73.0, 80.0, 73.0, 59.0, 57.0, 62.0, 61.0, 75.0, 49.0, 54.0, 56.0, 59.0, 64.0, 58.0, 55.0, 50.0, 44.0, 47.0, 43.0, 39.0, 42.0, 46.0, 38.0, 41.0, 36.0, 39.0, 37.0, 39.0, 42.0, 37.0, 34.0, 27.0, 40.0, 41.0, 22.0, 26.0, 37.0, 31.0, 28.0, 30.0, 29.0, 22.0, 36.0, 38.0, 30.0, 26.0, 26.0, 11.0, 21.0, 25.0, 24.0, 26.0, 21.0, 18.0, 25.0, 15.0, 19.0, 23.0, 23.0, 20.0, 14.0, 11.0, 19.0, 9.0, 25.0, 18.0, 12.0, 19.0, 13.0, 6.0, 10.0, 14.0, 12.0, 9.0, 22.0, 11.0, 12.0, 13.0, 7.0, 18.0, 10.0, 19.0, 12.0, 16.0, 9.0, 14.0, 12.0, 14.0, 7.0, 11.0, 13.0, 15.0, 11.0, 8.0, 14.0, 2.0, 10.0, 13.0, 9.0, 13.0, 10.0, 10.0, 8.0, 15.0, 8.0, 4.0, 6.0, 8.0, 8.0, 3.0, 7.0, 6.0, 10.0, 5.0, 7.0, 4.0, 4.0, 2.0, 13.0, 3.0, 9.0, 8.0, 5.0, 3.0, 3.0, 8.0, 8.0, 6.0, 5.0, 3.0, 2.0, 6.0, 2.0, 1.0, 3.0, 5.0, 1.0, 4.0, 5.0, 5.0, 8.0, 2.0, 6.0, 0.0, 5.0, 7.0, 5.0, 3.0, 3.0, 6.0, 5.0, 2.0, 3.0, 3.0, 1.0, 1.0, 3.0, 2.0, 7.0, 2.0, 2.0, 3.0, 5.0, 2.0, 2.0, 2.0, 5.0, 6.0, 3.0, 3.0, 2.0, 5.0, 2.0, 4.0, 2.0, 4.0, 2.0, 2.0, 2.0, 3.0, 2.0, 3.0, 1.0, 1.0, 4.0, 6.0, 5.0, 4.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0, 3.0, 1.0, 4.0, 0.0, 1.0, 1.0, 3.0, 1.0, 
2.0, 1.0, 3.0, 1.0, 1.0, 4.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 2.0, 1.0, 2.0, 3.0, 3.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 0.0, 0.0, 1.0, 3.0, 0.0, 1.0, 2.0, 2.0, 3.0, 2.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 3.0, 1.0, 1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0, 2.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
)

# Rebind the name to the restricted histogram; everything below sees only this version.
# NOTE(review): presumably restrict(130, 500) keeps the bins whose x range lies within [130, 500] — confirm against FHist docs.
hist_mjj12_rosa = hist_mjj12_rosa |> restrict(130, 500)

# reference fit parameters; f_rosa below hard-codes entries 2-5 and leaves only the normalization p[1] free
p0_rosa = [1.0, 16.0608, 8.10131, 0.794944, 0.0223786]
"""
    f_rosa(x, p)

Evaluate the reference shape with fixed shape constants and a free
normalization `p[1]`. `x` may be a scalar or a collection; the formula is
applied element-wise. Only the first entry of `p` is read.
"""
function f_rosa(x, p)
    amplitude = p[1]
    # Scalar kernel: same operation order as the element-wise formula,
    # with log(x) evaluated once and reused.
    kernel(xi) = begin
        lx = log(xi)
        amplitude * (1 - xi)^16.0608 * xi^(-(8.10131 + 0.794944 * lx + 0.0223786 * abs2(lx)))
    end
    return kernel.(x)
end


# Divisor used to scale the bin centers before they are passed to the fit functions.
const sqrt_s = 13000

"""
    ATLAS_f_5p(x, p)

Five-parameter fit function
`p[1] * (1 - x)^p[2] * x^(-(p[3] + p[4]*log(x) + p[5]*log(x)^2))`,
applied element-wise when `x` is a collection. `p` must provide at least
five indexable entries.
"""
function ATLAS_f_5p(x, p)
    # Read all five parameters up front so a too-short `p` fails before any evaluation.
    amp, pow, c0, c1, c2 = p[1], p[2], p[3], p[4], p[5]
    kernel(xi) = begin
        lx = log(xi)
        amp * (1 - xi)^pow * xi^(-(c0 + c1 * lx + c2 * abs2(lx)))
    end
    return kernel.(x)
end

# Bin centers scaled by 1/sqrt_s for use as the fit-function argument
# (the fit formula takes log(x) and (1 - x), so it needs 0 < x < 1).
# `hist_mjj12_rosa` is the only histogram defined in this script; using it here
# keeps `xs` aligned bin-for-bin with `bincounts(hist_mjj12_rosa)` used further down.
xs = bincenters(hist_mjj12_rosa)/sqrt_s
# Unscaled bin centers, in the histogram's own x units.
xs_mjj = bincenters(hist_mjj12_rosa)

"""
    chi2(os, cs, σs)

Chi-square loss: the sum over bins of `(os - cs)^2 / σs^2`, evaluated
element-wise. A zero entry in `σs` puts a zero in the denominator for that bin.
"""
function chi2(os, cs, σs)
    squared_residuals = abs2.(os .- cs)
    variances = abs2.(σs)
    return sum(squared_residuals ./ variances)
end

"""
    optim_fit(func, hist0, p0)

Fit `func(x, p)` to the histogram `hist0` by minimizing `chi2` between the
bin contents of `normalize(hist0, width=false)` and the function values at the
bin centers divided by `sqrt_s`. `p0` is passed to the problem as the initial
parameter vector; every parameter is boxed to `[0, 100]`.

Returns whatever `solve` returns (the fitted parameters are in `.u`).
"""
function optim_fit(func, hist0, p0)
    normed = normalize(hist0, width=false)
    centers = bincenters(normed) ./ sqrt_s
    observed = bincounts(normed)
    errors = binerrors(normed)

    # Objective in the (parameters, hyperparameters) form; the second argument is unused.
    objective(params, _) = chi2(observed, func(centers, params), errors)

    lower = zero(p0)
    upper = zero(p0) .+ 100
    problem = OptimizationProblem(
        OptimizationFunction(objective, AutoForwardDiff()),
        p0;
        lb = lower, ub = upper,
    )

    # NOTE(review): TikTak is a multistart method; it presumably draws its starting
    # points from the [lb, ub] box rather than from p0 — confirm against the
    # MultistartOptimization docs. If so, the initial guess has no influence here,
    # and much of the [0, 100]^5 box may give non-finite objective values for
    # func = ATLAS_f_5p (very large exponents on x).
    return solve(problem, MultistartOptimization.TikTak(100), BFGS(); maxiters=4000)
end

# Step 1: fit only the normalization. f_rosa reads just p[1], so p0 is a 1-element vector.
sol1 = optim_fit(f_rosa, hist_mjj12_rosa, ones(1))
ys_rosap1 = f_rosa.(xs, Ref(sol1.u))
# Rescale the curve so that its sum equals integral(hist_mjj12_rosa).
ys_rosap1 *= (integral(hist_mjj12_rosa) / sum(ys_rosap1))

# Step 2: fit all five parameters, passing the fitted normalization followed by
# the reference shape parameters p0_rosa[2:end] as the initial vector.
sol5 = optim_fit(ATLAS_f_5p, hist_mjj12_rosa, [sol1.u; p0_rosa[2:end]])
ys_p5_raw = ATLAS_f_5p.(xs, Ref(sol5.u))
"""
output:
retcode: Default
u: 5-element Vector{Float64}:
 64.84375
 27.34375
 19.53125
 46.09375
  2.34375
"""

# Rescale the 5-parameter curve so that its sum equals integral(hist_mjj12_rosa),
# the same rescaling applied to ys_rosap1.
ys_p5 = ys_p5_raw * (integral(hist_mjj12_rosa) / sum(ys_p5_raw))


ys_rosa_input = bincounts(hist_mjj12_rosa)
# Per-bin uncertainties of the same (un-normalized) histogram the counts come from.
# This name was used below without ever being defined in the script.
σs_rosa = binerrors(hist_mjj12_rosa)
@show chi2(ys_rosa_input, ys_rosap1, σs_rosa)
@show chi2(ys_rosa_input, ys_p5, σs_rosa)

output:

chi2(ys_rosa_input, ys_rosap1, σs_rosa) = 273.22202140699414
chi2(ys_rosa_input, ys_p5, σs_rosa) = 2.2672859805143725e7

From inspection, the reference parameters (p0_rosa) clearly fit the data up to the normalization constant, which we re-fit using the function f_rosa(x, p).

Question

The sol5 fit is bad:

Why does the 5-parameter fit not converge, even when it is started at the known solution?

Seems like it’s not an Optimization.jl issue but a MultistartOptimization.jl issue? Did you try other solvers? Did you try localizing it to MultistartOptimization.jl?