Struggling with creating a faster threaded version of a concurrent problem

Alright, I am so desperate that I am calling for help from people experienced in concurrent/threaded programming in Julia, since I am hitting a wall here. I can’t even provide an MWE, but I am getting more and more desperate, so I am considering sitting down and spending a day or two to boil it down. However, it’s really tricky with all those algorithms, since I think I am somehow standing right in the middle of “allocations+I/O vs. CPU”: running with multiple threads is slower (5% and more) than with a single one. All in all, my efforts to use more than one CPU in a single process are just a big fail. The very first naive approach, which used the evil threadid() mechanism for buffers, was about 20% faster, but sometimes segfaulted (still without any noticeably wrong results ;)). Anyways, the current approach uses a proper Channel implementation, without any gain in speed.

My hope is that either someone tosses me in the right direction or some lights turn on while I summarise my situation (EDIT after writing this topic: nope, it did not happen :laughing: maybe it needs more time).

Ah well, and I forgot to mention: one of my biggest frustrations at this point is that there is a very similar single-threaded C++ implementation which runs 2-3 times faster. My initial thought was that Julia should be able to beat it, if not on a single thread, then on multiple ones :see_no_evil:

To sum up what the code does: for a triggered event in a neutrino detector, the hits on the PMTs are filtered via coincidence-finder and clustering algorithms, and then for a predefined set of directions (muon trajectory candidates, O(1000)) a fit is performed, which consists of a rotation of all hits (using Rotations.jl) and a few linalg operations like SVD and some custom matrix inversions (it’s a linear problem and quite fast).

Since each direction can be solved independently (after copying the initial event hits, since those need to be mutated further, e.g. rotated and sorted), I thought it’s a trivially parallel problem: I simply divide the directions into chunks and utilise Channels with isolated solvers (XYTSolver). Anyways, as written above, threading brings no gain, only more overhead.

Maybe my problem is simply not suited for threads, or my data handling is totally suboptimal. I tried to reuse as many things as I can and reduced the number of allocations for a single event from 36000 to about 21000. The processing time on my M1 MacBook for a reference event is about 10 ms, with 1500 direction fits, so about 6.6 µs per direction.
Within these 6.6 µs, about 100 hits are rotated, filtered and clusterised, and an SVD and a matrix inversion of a 3x3 matrix are performed; that’s basically it. That’s roughly 14 allocations per direction, mainly coming from resizing the thread-associated buffers and from things like the SVD and triple matrix multiplications. I am not sure if I can reduce that further… Maybe the Channel approach introduces too much overhead compared to the amount of work done in a task?
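
To get a feeling for whether the task overhead matters, a rough per-task lower bound can be measured like this (a sketch, not from my code):

using BenchmarkTools

# rough cost of spawning and fetching a single task;
# compare with the ~6.6 µs of actual work done per direction
@btime fetch(Threads.@spawn nothing)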

julia> @benchmark muons = msfit($f.online.events[2])
BenchmarkTools.Trial: 512 samples with 1 evaluation.
 Range (min … max):  8.621 ms … 86.803 ms  ┊ GC (min … max): 0.00% … 88.71%
 Time  (median):     9.232 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.774 ms ±  6.349 ms  ┊ GC (mean ± σ):  5.71% ±  7.78%

                     ▁█▆▂
  ▃▂▂▁▁▂▁▄▅▃▃▃▅▅▆▅▅▄▄████▆▄▃▃▁▂▂▂▁▂▂▁▂▁▁▂▂▂▁▂▂▁▂▁▂▁▁▂▂▁▁▁▁▁▂ ▃
  8.62 ms        Histogram: frequency by time        10.4 ms <

 Memory estimate: 7.62 MiB, allocs estimate: 21597.

Running ProfileView does not reveal anything red, so I am pretty sure that the code is at least type-stable. Most of the time is spent in clusterize! and dense.jl / inv:

julia> @profview muons = msfit(f.online.events[2])

I know, it takes a lot of time and effort to understand foreign code and I am really sorry that I don’t have any MWE, but maybe someone spots some structural/design failure in my code, or simply confirms that this problem is not suitable for threading…

The repository where I currently develop this is here (it’s right now a little bit of a mess, since I am migrating many small projects into this larger one): src/scanfit.jl · muonscanfit · Tamas Gal / NeRCA.jl · GitLab

…and here are some of the relevant parts of the code.

Everything revolves around this HitR1 structure. I also tried a mutable version, but found that the immutable one works better (fewer allocations overall and faster runtime), even though I need to do lots of overwrites (using @set from Setfield.jl), see below.

struct HitR1 <: AbstractReducedHit
    dom_id::Int32
    pos::Position{Float64}  # FieldVector{3, Float64} from StaticArrays.jl
    t::Float64
    tot::Float64
    n::Int
    weight::Float64
end
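
For illustration, this is how such an overwrite looks with @set (a minimal sketch with made-up values):

using Setfield  # provides @set

hit = HitR1(1, Position(0.0, 0.0, 0.0), 10.0, 25.0, 1, 1.0)
moved = @set hit.pos = Position(1.0, 2.0, 3.0)  # builds a fresh HitR1; `hit` itself is untouched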

This is the function which is called for each event. The first part does some hit filtering (reducing thousands of hits to a Vector{HitR1} with a high signal probability). The second part passes the hits to scanfit(), which performs the fit in thousands of directions.

"""
Performs a Muon track prefit for a given set of hits (usually snapshot hits).
"""
function (msf::MuonScanfit)(hits::Vector{T}) where T<:KM3io.AbstractHit
    rhits = msf.coincidencebuilder(HitR1, msf.detector, hits)
    sort!(rhits)
    unique!(h->h.dom_id, rhits)
    clusterize!(rhits, Match3B(msf.params.roadwidth, msf.params.tmaxlocal))

    # from here on, we have a Vector{HitR1} (rhits) which need to undergo some
    # mutations for each of the ~1500 directions in the `scanfit()` function
    candidates = scanfit(msf.params, rhits, msf.directionset)

    isempty(candidates) && return candidates
    sort!(candidates, by=m->m.Q; rev=true)
    candidates[1:msf.params.nfits]
end

This is the actual scanfit routine, which prepares the Channel of XYTSolvers: structures which have their own internal buffers to copy the initial hits into, and which then do the linalg stuff to calculate the best-fit direction, position and time for the muon candidates.

"""
Performs the scanfit for each given direction and returns a
`Vector{MuonScanfitCandidate}` with all successful fits. The resulting vector can
be empty if none of the directions had enough hits to perform the algorithm.
"""
function scanfit(params::MuonScanfitParameters, rhits::Vector{T}, directionset::DirectionSet) where T<:AbstractReducedHit
    xytsolvers = Channel{XYTSolver}(Threads.nthreads())
    for _ in Threads.nthreads()
        put!(xytsolvers, XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ))
    end
    chunk_size = max(1, length(directionset.directions) ÷ Threads.nthreads())
    chunks = Iterators.partition(directionset.directions, chunk_size)

    tasks = map(chunks) do chunk
        Threads.@spawn begin
            xytsolver = take!(xytsolvers)
            results = map(c -> xytsolver(rhits, c, directionset.angular_separation), chunk)
            put!(xytsolvers, xytsolver)
            results
        end
    end
    collect(Iterators.flatten(fetch.(tasks)))
end

The XYTSolver itself has some internal buffers for the hits, the covariance matrix and the time-residuals vector, which are resized via resize!() during the solving procedure for each direction. Here is the definition with the inner constructor:


"""
A task worker which solves for x, y and t for a given set of hits and a direction.
"""
struct XYTSolver
    hits_buffer::Vector{HitR1}
    covmatrix::CovMatrix
    timeresvec::Vector{Float64}
    nmaxhits::Int
    matcher::Match1D
    est::Line1ZEstimator

    function XYTSolver(nmaxhits::Int, roadwidth::Float64, tmaxlocal::Float64, σ::Float64)
        new(Vector{HitR1}(), CovMatrix(nmaxhits, σ), Vector{Float64}(), nmaxhits, Match1D(roadwidth, tmaxlocal),
            Line1ZEstimator(Line1Z(Position(0, 0, 0), 0))
        )
    end
end

and last but not least, the function which performs the calculation for each of the ~1500 directions. This is the part which can be executed in parallel. The covmatrix and timeresvec fields are simple matrices and vectors.

As you see below, the hits are rotated and copied over to the s.hits_buffer and then resized/mutated afterwards (either directly via resize! or indirectly via clusterize!, which shrinks the vector). I tried to use views instead of resizing the vector but I got 30% more allocations and it was twice as slow compared to the resize approach.

function (s::XYTSolver)(hits::Vector{T}, dir::Direction{Float64}, α::Float64) where T<:AbstractReducedHit
    χ² = Inf
    R = rotator(dir)
    n_initial_hits = length(hits)
    resize!(s.hits_buffer, n_initial_hits)

    for (idx, hit) ∈ enumerate(hits) # rotate hits
        s.hits_buffer[idx] = @set hit.pos = R * hit.pos
    end

    if n_initial_hits > s.nmaxhits
        sort!(s.hits_buffer; by=timetoz, alg=PartialQuickSort(s.nmaxhits))
        resize!(s.hits_buffer, s.nmaxhits)
    end

    clusterize!(s.hits_buffer, s.matcher)

    hits = s.hits_buffer  # just for convenience
    n_final_hits = length(hits)

    n_final_hits <= s.est.NUMBER_OF_PARAMETERS && return MuonScanfitCandidate(Position(0, 0, 0), dir, 0, -Inf, 0)

    NDF = n_final_hits - s.est.NUMBER_OF_PARAMETERS
    N = hitcount(hits)
    sort!(hits)

    try
        estimate!(s.est, hits)
    catch ex
        # isa(ex, SingularSVDException) && @warn "Singular SVD"
        return MuonScanfitCandidate(Position(0, 0, 0), dir, 0, -Inf, 0)
    end

    update!(s.covmatrix, s.est.model.pos, hits, α)
    V = view(s.covmatrix.M, 1:n_final_hits, 1:n_final_hits)

    n_final_hits > length(s.timeresvec) && resize!(s.timeresvec, n_final_hits)
    timeresvec!(s.timeresvec, s.est.model, hits)

    V⁻¹ = inv(V)
    Y = view(s.timeresvec, 1:n_final_hits)  # only take the relevant part of the buffer
    χ² = transpose(Y) * V⁻¹ * Y
    fit_pos = R \ s.est.model.pos

    MuonScanfitCandidate(fit_pos, dir, s.est.model.t, quality(χ², N, NDF), NDF)
end

Really, any kind of help is appreciated.

Sorry, it’s late for me and I am on mobile only, so I have just 2 small comments to make:

  1. You mentioned that you have lots of small matrices and it seems you don’t use StaticArrays.jl. This can improve speed/allocations for small arrays (up to ~100 elements, I think) by quite a lot; see the sketch after this list.
  2. I don’t understand why you are using a Channel in the function scanfit. You could just partition the inputs and then spawn a Task for each chunk, which allocates the solver and uses it to solve all problems of that chunk.
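
For point 1, a generic illustration of the difference (a sketch, not specific to your code):

using StaticArrays, BenchmarkTools

v = [1.0, 2.0, 3.0]
s = SVector(1.0, 2.0, 3.0)
@btime $v + $v  # allocates a new Vector on the heap
@btime $s + $s  # returns a stack-allocated SVector, zero allocations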

Maybe I’ll find some time tomorrow to have a more detailed look at your code.


I was not clear enough I guess, but it’s no wonder, given that I don’t have a nice MWE :wink:

I do have small matrices but only one for each direction and that is being mutated as well. I am already using MMatrix from StaticArrays for that:


mutable struct Line1ZEstimator
    model::Line1Z
    V::MMatrix{3, 3, Float64, 9}
    NUMBER_OF_PARAMETERS::Int
    MINIMAL_SVD_WEIGHT::Float64
    function Line1ZEstimator(model::Line1Z)
        V = zero(MMatrix{3, 3, Float64, 9})
        new(model, V, 3, 1.0e-4)
    end
end

Regarding the Channel: you are probably right, I am still discovering the “proper way” of threading in Julia and I am not really familiar with channels/tasks yet. The channel approach was the one which someone recommended in another thread on a somewhat similar problem.
However, I still get segfaults on an M1 mac sometimes, just like for the threadid() approach in the early naive multi-threaded versions of the code (which were a little bit faster).

Anyways, I’ll try to spawn tasks instead of using Channels and report back. Thanks for your time @abraemer !

When trying multithreading, remember to set BLAS.set_num_threads(1) (at least when using OpenBLAS, the default). More info on BLAS threads
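
For reference, a minimal sketch:

using LinearAlgebra

BLAS.set_num_threads(1)  # avoid oversubscription from OpenBLAS threads × Julia threads
BLAS.get_num_threads()   # verify the setting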

I just noticed that you only ever create a single instance of XYTSolver in that function, because you wrote for _ in Threads.nthreads() and not for _ in 1:Threads.nthreads(). This explains why multithreading made things worse for you.
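
Numbers are their own one-element iterators in Julia, so the loop body runs exactly once, no matter what nthreads() returns:

julia> for i in 4; @show i; end  # a number iterates as a single element
i = 4

julia> for i in 1:4; @show i; end  # the intended range
i = 1
i = 2
i = 3
i = 4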

Anyways, I would imagine the function something like this:

"""
Performs the scanfit for each given direction and returns a
`Vector{MuonScanfitCandidate}` with all successful fits. The resulting vector can
be empty if none of the directions had enough hits to perform the algorithm.
"""
function scanfit(params::MuonScanfitParameters, rhits::Vector{T}, directionset::DirectionSet) where T<:AbstractReducedHit
    chunk_size = max(1, length(directionset.directions) ÷ Threads.nthreads())
    chunks = Iterators.partition(directionset.directions, chunk_size)

    tasks = map(chunks) do chunk
        Threads.@spawn begin
           # create solver here
            xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
            results = [xytsolver(rhits, c, directionset.angular_separation) for c in chunk]
            results
        end
    end
    # I think Iterators.flatten tends to cause inference problems and I prefer mapreduce, but that might just be a preference of style
    return mapreduce(fetch, vcat, tasks) # maybe mixed up the order of fetch and vcat
end

Holy… I cannot believe that I overlooked that for _ in ... typo, this is unbelievable. I spent a full week on benchmarking allocations and whatnot, and theorised with others (experienced programmers, though not Julia experts) about allocation overhead, memory I/O bottlenecks, context switching etc… oh dear. Sometimes I hate the fact that numbers are iterable :wink:

Alright, I added 1: (I have not tried the bare-task approach yet, but I’ll do) and here are the results:

single thread (a sample file with about 9000 events):

░ tamasgal@silentbox:NeRCA.jl  muonscanfit ●● ஃ v1.9.3 took 5s
░ 07:40:19 > julia --threads=1 --project=. scripts/muonscanfit.jl -a KM3NeT_00000133_20221025.detx -i /Volumes/Hispeed/Experiments/KM3NeT/ARCA21v8.1/mcv8.1.gsg_numu-CCHEDIS_1e2-1e8GeV.sirene.jterbr00013339.47.root -o muonscanfit.h5
  Activating project at `~/Dev/NeRCA.jl`
Loading libraries...
Progress: 100%|████████████████████████████████████████████| Time: 0:04:51
Total number of events processed: 8737
Number of muon candidates: 104844
Failed reconstructions: 0

with 4 threads

░ tamasgal@silentbox:NeRCA.jl  muonscanfit ●● ஃ v1.9.3 took 5m 22s
░ 07:46:06 > julia --threads=4 --project=. scripts/muonscanfit.jl -a KM3NeT_00000133_20221025.detx -i /Volumes/Hispeed/Experiments/KM3NeT/ARCA21v8.1/mcv8.1.gsg_numu-CCHEDIS_1e2-1e8GeV.sirene.jterbr00013339.47.root -o muonscanfit.h5
  Activating project at `~/Dev/NeRCA.jl`
Loading libraries...
Progress: 100%|████████████████████████████████████████████| Time: 0:04:11
Total number of events processed: 8737
Number of muon candidates: 104844
Failed reconstructions: 0

[36603] signal (11.2): Segmentation fault: 11
in expression starting at none:0

The gain is very little, but at least I am back to something noticeable, instead of being slower by 5%.

I still have the issue with the segmentation fault and I am not sure if it’s my fault or one of the remaining multithreading bugs in Julia. I found a few related issues on GitHub but those were fixed. I guess I will need to open another one.

Yes, I have read about that issue and set that to 1 already, but thanks for the reminder!

Indeed, in the current code I have:

return collect(Iterators.flatten(fetch.(tasks)))::Vector{MuonScanfitCandidate}

but your mapreduce is much nicer. I also removed the Channel thing, but the difference in the overall performance is negligible (about 5sec in total time, but only measured once).

So the only remaining questions are about the Segfault and of course: is there anything I could do to get a bit more out of threading.

Btw. I have never seen a segfault on an Intel Xeon machine with this code, only on my M1 mac, so I guess it’s a bug in Julia.


Well, we need a minimal working example (MWE) for that…

Yes, for that I’ll definitely create one :wink:

Did you check how much time the garbage collector needs? With Julia 1.9 this is often the bottleneck for multi-threaded programs (it might be better with 1.10). If it is the bottleneck, better use multi-processing instead.
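
One quick way to check the GC fraction (a sketch, using the reference event from above):

julia> stats = @timed msfit(f.online.events[2]);

julia> stats.gctime / stats.time  # fraction of wall time spent in GC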

And don’t forget: before going multi-threaded, reduce the allocations as much as possible and try to fully use SIMD, e.g. with LoopVectorization.jl.
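
A generic illustration of LoopVectorization.jl (not tied to this code base):

using LoopVectorization

function sumsq(x)
    s = 0.0
    @turbo for i in eachindex(x)  # @turbo unrolls and SIMD-vectorizes the loop
        s += x[i] * x[i]
    end
    s
end

sumsq(rand(1000))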


I reduced the total number of allocations to around 14 per parallel task; I could not manage to go lower.

GC time is 0.00% in median and about 6.7% in mean, but I’ll try with 1.10 as well.

I already looked at SIMD optimisations but could not improve anything. I’ll try harder…

So I think that without something runnable, people here won’t be able to help you (well, except by pointing at random things and making general statements). For starters it does not need to be an MWE, just something easy to set up and run in some form.

Concerning multithreading: I think your workloads are quite short right now, so the overhead of threading (spawning a Task and scheduling it) might outweigh the speedup. I found this old thread where the overhead of Threads.@threads is O(10µs). To get good speedups in short sections there are different libraries such as Polyester.jl (sketch below). But I personally think we should try to optimize the single-threaded case first to get back to the speed of your C++ implementation.
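
For completeness, a minimal sketch of Polyester.jl's low-overhead threading (a generic example, untested against your code):

using Polyester

# @batch reuses a pool of lightweight tasks, so the per-chunk scheduling
# overhead is far below that of a freshly spawned Task
function axpy_batched!(y, a, x)
    @batch for i in eachindex(x, y)
        y[i] += a * x[i]
    end
    y
end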

Of course, finding the segfault would be very interesting! A quick and easy check would be removing all @inbounds from your code base. If the segfault still occurs, then it really might be a bug in Julia, which will be harder to isolate.
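
Alternatively, you can override every @inbounds without touching the code via the --check-bounds flag (using your invocation from above, arguments elided):

julia --check-bounds=yes --threads=4 --project=. scripts/muonscanfit.jl <args...>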


Thanks, I already tried to utilise Polyester.jl but did not manage to fit it in. I will reinvestigate that as well.

So if anyone is eager to run the code, I prepared something which hopefully has a low entry threshold. It requires adding our registry, since two dependencies are not in the official Julia registry.

Setup:

julia> using Pkg; Pkg.Registry.add(RegistrySpec(url = "https://git.km3net.de/common/julia-registry"))

and then clone the Git repo and check out the muonscanfit-threading-demo branch:

git clone https://git.km3net.de/tgal/NeRCA.jl.git
cd NeRCA.jl
git checkout muonscanfit-threading-demo
julia --project=. -e "using Pkg; Pkg.instantiate()"

Here is an example how to reconstruct a single event, after loading the event file and the detector calibration:


julia> using NeRCA, KM3io, KM3NeTTestData
[ Info: Precompiling NeRCA [89b7c20c-a96a-11e9-35df-35fba0891eb2]

julia> f = ROOTFile(datapath("online", "muon_cc_daqevents.root"))
ROOTFile{OnlineTree (250 events, 0 summaryslices), OfflineTree (0 events)}

julia> det = Detector(datapath("detx", "KM3NeT_00000133_20221025.detx"))
Detector 133 (v5) with 21 strings and 399 modules.

julia> msfit = MuonScanfit(Detector(datapath("detx", "KM3NeT_00000133_20221025.detx")))
MuonScanfit with a coarse scan of 7.0ᵒ and a fine scan of 0.5ᵒ.

julia> raw_hits = f.online.events[2].snapshot_hits;  # there are 250 events, I use the second one for reference

julia> @benchmark muons = msfit($raw_hits)
[ Info: Loading BenchmarkTools ...
BenchmarkTools.Trial: 627 samples with 1 evaluation.
 Range (min … max):  7.525 ms …  14.623 ms  ┊ GC (min … max): 0.00% … 41.17%
 Time  (median):     7.723 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.976 ms ± 766.122 μs  ┊ GC (mean ± σ):  2.65% ±  6.60%

  ▆█▆▇▆▄▁▂▂▄▂▂▂
  █████████████▇▇▅▁▄▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▇▇▄▇▇▅▅▄▇▇▇▄▅▅▅▅▄▄ █
  7.52 ms      Histogram: log(frequency) by time      10.6 ms <

 Memory estimate: 3.93 MiB, allocs estimate: 11550.

It’s late, but it seems that I got some speedup using this version of scanfit:

using ChunkSplitters
function scanfit(
    params::MuonScanfitParameters,
    rhits::Vector{T},
    directionset::DirectionSet;
    nchunks = Threads.nthreads()
) where T<:AbstractReducedHit
    results = [ NeRCA.MuonScanfitCandidate[] for _ in 1:nchunks ]
    Threads.@sync for (irange, ichunk) in chunks(directionset.directions, nchunks)
        Threads.@spawn for i in irange
            xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
            c = directionset.directions[i]
            push!(results[ichunk], xytsolver(rhits, c, directionset.angular_separation))
        end
    end
    return vcat(results...)
end

A possible reason for the speedup is that in the original version the type of results was not being inferred (that can be checked with @code_warntype msfit(raw_hits)). Here I’ve explicitly initialized the results array as a vector of NeRCA.MuonScanfitCandidate vectors.

With the previous version I got:

julia> @benchmark muons = msfit($raw_hits)
BenchmarkTools.Trial: 352 samples with 1 evaluation.
 Range (min … max):   9.704 ms … 82.258 ms  ┊ GC (min … max):  0.00% … 82.68%
 Time  (median):     13.118 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   14.256 ms ±  6.557 ms  ┊ GC (mean ± σ):  10.33% ± 15.46%

  ▇▅▅▂▄▄▅▇█▆▄▅▂                                                
  ██████████████▇▁▄▄▄▆▁▁▁▁▁▁▁▁▁▁▁▁▄▆▁▁▁▁▁▁▁▄▄▁▄▆▁▄▇▄▁▁▁▄▄▄▄█▄ ▇
  9.7 ms       Histogram: log(frequency) by time      35.6 ms <

 Memory estimate: 16.22 MiB, allocs estimate: 37800.

and with the new one:

julia> @benchmark muons = msfit($raw_hits)
BenchmarkTools.Trial: 524 samples with 1 evaluation.
 Range (min … max):  6.389 ms … 26.336 ms  ┊ GC (min … max):  0.00% … 54.47%
 Time  (median):     8.015 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   9.533 ms ±  3.711 ms  ┊ GC (mean ± σ):  14.63% ± 18.76%

   ▃▃▃▄▅█▅▅▄▃▂▂                                      ▂        
  ▆█████████████▇▁▅▄▆▅▁▁▅▁▄▁▄▄▁▁▁▁▁▁▁▁▁▁▁▄▁▄▁▁▅▆▆▇▇█▇███▇▄▅▄ ▇
  6.39 ms      Histogram: log(frequency) by time     19.8 ms <

 Memory estimate: 38.96 MiB, allocs estimate: 47203.

Although I’m not completely sure the results are equivalent. I’m using 4 threads here.

As a side note (though it didn’t change performance here): you have a try ... catch inside one of the inner functions, and that can cause performance problems. I think it would be better if, instead of throwing an error, you returned nothing from the inner function and handled the error at a higher level.
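
A generic sketch of that pattern (hypothetical trysolve, not from the repo):

using LinearAlgebra

# return `nothing` from the hot inner function instead of throwing,
# and branch at the call site
function trysolve(A, b)
    abs(det(A)) < 1e-12 && return nothing  # singular: signal failure without an exception
    A \ b
end

x = trysolve([1.0 0.0; 0.0 0.0], [1.0, 2.0])
x === nothing && @info "singular system, skipping candidate"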


Thanks for the runnable example :slight_smile:

Baseline

  • Essentially all time is spent inside XYTSolver (~90%)
  • Low-hanging fruit: replace χ² = transpose(Y) * V⁻¹ * Y with χ² = dot(Y, V⁻¹, Y) to save an allocation and gain a bit more speed
  • Now:
    • clusterize!(s.hits_buffer, s.matcher) takes 50% of the total time,
    • estimate!(s.est, hits) takes 22%
    • V⁻¹ = inv(V) takes 10%

In reverse order:

V⁻¹ = inv(V)

  • Don’t just compute the inverse matrix and then apply it to a single vector. Better do a “division”, as V⁻¹ * Y == V \ Y, so we get χ² = dot(Y, V \ Y) (small demonstration below)
    7.5% speedup
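
A small demonstration of the equivalence (stand-in matrices, not the real covariance):

using LinearAlgebra

V = let A = rand(5, 5); Hermitian(A' * A + I) end  # stand-in for the covariance matrix
Y = rand(5)

χ²_a = transpose(Y) * inv(V) * Y  # materialises inv(V) plus a temporary vector
χ²_b = dot(Y, V \ Y)              # one solve, no explicit inverse
χ²_a ≈ χ²_b                       # true up to round-off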

estimate

Almost all time is spent inside invert!. StaticArrays.jl has no specialized svd, but there is a specialized eigen, so use that instead. Then, instead of multiplying the matrices to get the pseudoinverse, just return the individual components, because we can apply them to a vector with just matrix-vector products. Also, I replaced your y₀, y₁, y₂ with yvec = zeros(MVector{3}), which also simplifies the construction of the model.
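
The idea in a nutshell (minimal sketch; the eigenvalue clamping of invert2! is omitted here):

using LinearAlgebra

A = rand(3, 3)
H = Hermitian(A' * A)            # stand-in for est.V
evals, evecs = eigen(H)          # H == evecs * Diagonal(evals) * evecs'
w = inv.(evals)                  # invert2! additionally clamps tiny eigenvalues to 0
y = rand(3)
x = evecs * (w .* (evecs' * y))  # == H \ y via three matrix-vector products
x ≈ H \ y                        # true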

clusterize!

This one I find a bit trickier. Maybe there are algorithmic optimizations, but I saw nothing obvious. Most time is spent inside the match, as there are quadratically many calls. I found that calling time inside the match functions was the bottleneck, so I precompute all the times and also pass the time corresponding to each hit into match. That got me a ~60% speedup.
Another low-hanging fruit is to reuse the Clique inside XYTSolver to save some allocations and gain a very minor speedup.

I think @Imiq saw a type instability introduced by Iterators.flatten. I tried Imiq’s version and it was a lot slower than my suggested version with mapreduce (and it needs another dependency).

Summary

All in all I achieved a speedup of 2.3x. clusterize! still takes ~60% of the time (with 2/3 of that spent inside match), the V \ Y takes ~15%, and estimate!/invert2! also takes ~15%.

hits.jl
struct Hit <: KM3io.AbstractHit
    t::Float64
    tot::Float64
end
Base.isless(lhs::Hit, rhs::Hit) = lhs.t < rhs.t

struct HitL0 <: KM3io.AbstractHit
    channel_id::Int8
    t::Float64
    tot::Float64
    pos::Position{Float64}
    dir::Direction{Float64}
end
Base.isless(lhs::HitL0, rhs::HitL0) = time(lhs) < time(rhs)

abstract type AbstractSpecialHit <: KM3io.AbstractHit end  # TODO: bad naming ;)
abstract type AbstractCoincidenceHit <: AbstractSpecialHit end
abstract type AbstractReducedHit <: AbstractSpecialHit end

struct HitL1 <: AbstractCoincidenceHit
    dom_id::Int32
    hits::Vector{HitL0}
end
# function Base.time(h::HitL1)
#     # n = length(h)
#     # n > length(SLEWS_L1) && return SLEWS_L1[end]
#     # time(first(h.hits)) - SLEWS_L1[n-1]
#     time(first(h.hits))
# end
# position(h::HitL1) = first(h.hits).pos
# function tot(h::HitL1)
#     combined_hit = combine(h.hits)
#     combined_hit.tot
# end

struct HitL2 <: AbstractCoincidenceHit
    dom_id::Int32
    hits::Vector{HitL0}
end
HitL1(m::DetectorModule, hits) = HitL1(m.id, hits)
HitL2(m::DetectorModule, hits) = HitL2(m.id, hits)
Base.length(c::AbstractCoincidenceHit) = length(c.hits)
Base.eltype(::AbstractCoincidenceHit) = HitL0
function Base.iterate(c::AbstractCoincidenceHit, state=1)
    @inbounds state > length(c) ? nothing : (c.hits[state], state+1)
end

struct HitR0 <: AbstractReducedHit
    t::Float64
    tot::Float64
    channel_id::Int8
end
Base.isless(lhs::HitR0, rhs::HitR0) = time(lhs) < time(rhs)

struct HitR1 <: AbstractReducedHit
    dom_id::Int32
    pos::Position{Float64}
    t::Float64
    tot::Float64
    n::Int
    weight::Float64
end
Base.isless(lhs::HitR1, rhs::HitR1) = lhs.dom_id == rhs.dom_id ? time(lhs) < time(rhs) : lhs.dom_id < rhs.dom_id
const HitR2 = HitR1
function HitR1(dom_id::Integer, hits::Vector{HitL0})
    combined_hit = combine(hits)
    h = first(hits)
    count = weight = length(hits)
    HitR1(dom_id, h.pos, combined_hit.t, combined_hit.tot, count, weight)
end
# function HitR1(m::DetectorModule, hit::HitL1)
#     count = weight = length(hit)
#     HitR1(m.id, position(hit), time(hit), tot(hit), count, weight)
# end
weight(h::HitR1) = h.weight

starttime(hit) = time(hit)
endtime(hit) = time(hit) + hit.tot

"""
Return the total number of hits for a collection of reduced hits.
"""
hitcount(hits::AbstractArray{T}) where T<:AbstractReducedHit = sum(h.n for h in hits)

"""
Combine snapshot and triggered hits to a single hits-vector.

This should be used to transfer the trigger information to the
snapshot hits from a DAQEvent. The triggered hits are a subset
of the snapshot hits.

"""
function combine(snapshot_hits::Vector{KM3io.SnapshotHit}, triggered_hits::Vector{KM3io.TriggeredHit})
    triggermasks = Dict{Tuple{UInt8, Int32, Int32, UInt8}, Int64}()
    for hit ∈ triggered_hits
        triggermasks[(hit.channel_id, hit.dom_id, hit.t, hit.tot)] = hit.trigger_mask
    end
    n = length(snapshot_hits)
    hits = sizehint!(Vector{TriggeredHit}(), n)
    for hit in snapshot_hits
        channel_id = hit.channel_id
        dom_id = hit.dom_id
        t = hit.t
        tot = hit.tot
        triggermask = get(triggermasks, (channel_id, dom_id, t, tot), 0)
        push!(hits, TriggeredHit(dom_id, channel_id, t, tot, triggermask))
    end
    hits
end


"""
Create a `Vector` with hits contributing to `n`-fold coincidences within a time
window of Δt.
"""
function nfoldhits(hits::Vector{T}, Δt, n) where {T<:KM3io.AbstractDAQHit}
    hit_map = modulemap(hits)
    chits = Vector{T}()
    for (dom_id, dom_hits) ∈ hit_map
        bag = Vector{T}()
        push!(bag, dom_hits[1])
        t0 = dom_hits[1].t
        for hit in dom_hits[2:end]
            if hit.t - t0 > Δt
                if length(bag) >= n
                    append!(chits, bag)
                end
                bag = Vector{T}()
            end
            push!(bag, hit)
            t0 = hit.t
        end
    end
    return chits
end


"""
Calculate the multiplicities for a given time window. Two arrays are
are returned, one contains the multiplicities, the second one the IDs
of the coincidence groups.
The hits should be sorted by time and then by dom_id.
"""
function count_multiplicities(hits::Vector{T}, tmax=20) where {T<:KM3io.AbstractHit}
    n = length(hits)
    mtp = ones(Int32, n)
    cid = zeros(Int32, n)
    idx0 = 1
    _mtp = 1
    _cid = idx0
    t0 = hits[idx0].t
    dom_id = hits[idx0].dom_id
    for i in 2:n
        hit = hits[i]
        if hit.dom_id != dom_id
            mtp[idx0:i-1] .= _mtp
            cid[idx0:i-1] .= _cid
            dom_id = hit.dom_id
            t0 = hit.t
            _mtp = 1
            _cid += 1
            idx0 = i
            continue
        end
        Δt = hit.t - t0
        if Δt > tmax
            mtp[idx0:i] .= _mtp
            cid[idx0:i] .= _cid
            _mtp = 0
            _cid += 1
            idx0 = i
            t0 = hit.t
        end
        _mtp += 1
    end
    mtp[idx0:end] .= _mtp
    cid[idx0:end] .= _cid
    mtp, cid
end

"""
Counts the multiplicities and modifies the .multiplicity field of the hits.
Important: the hits have to be sorted by time and then by DOM ID first.
"""
function count_multiplicities!(hits::Vector{KM3io.XCalibratedHit}, tmax=20)
    _mtp = 0
    _cid = 0
    t0 = 0
    dom_id = 0
    hit_buffer = Vector{XCalibratedHit}()

    function process_buffer()
        while !isempty(hit_buffer)
            _hit = pop!(hit_buffer)
            _hit.multiplicity.count = _mtp
            _hit.multiplicity.id = _cid
        end
    end

    function reset()
        _mtp = 1
        _cid += 1
    end

    for hit ∈ hits
        if length(hit_buffer) == 0
            reset()
            push!(hit_buffer, hit)
            t0 = hit.t
            dom_id = hit.dom_id
            continue
        end
        if hit.dom_id != dom_id
            process_buffer()
            push!(hit_buffer, hit)
            t0 = hit.t
            dom_id = hit.dom_id
            reset()
            continue
        end
        Δt = hit.t - t0
        if Δt > tmax
            process_buffer()
            push!(hit_buffer, hit)
            t0 = hit.t
            reset()
        else
            push!(hit_buffer, hit)
            _mtp += 1
        end
    end
    if length(hit_buffer) > 0
        process_buffer()
    end
    return
end


"""
Categorise hits by DU and put them into a dictionary of DU=>Vector{Hit}.

Caveat: this function is not typesafe, only suited for high-level analysis (like plots).
"""
@inline duhits(hits::Vector{T}) where {T<:KM3io.XCalibratedHit} = categorize(:du, hits)


"""
Return a vector of hits with ToT >= `tot`.
"""
function totcut(hits::Vector{T}, tot) where {T<:KM3io.AbstractDAQHit}
    return filter(h->h.tot >= tot, hits)
end


"""
Returns the estimated number of photoelectrons for a given ToT.
"""
function nphes(tot)
    if tot <= 20
        return 0.0
    end
    if tot <= 26
        return 1.0
    end
    if tot < 170
        return 1.0 + (tot - 26)/(1/0.28)
    end
    return 40.0 + (255 - tot)*2.0
end


"""

Creates a map (`Dict{Int32, Vector{T}}`) from a flat `Vector{T}` split up based
on the `dom_id` of each element. A typical use is to split up a vector of hits
by their optical module IDs.

This function is similar to `categorize(:dom_id, Vector{T})` but this method
is completely typesafe.

"""
function modulemap(hits::Vector{T}) where T
    out = Dict{Int32, Vector{T}}()
    for hit ∈ hits
        if !(hit.dom_id ∈ keys(out))
            out[hit.dom_id] = T[]
        end
        push!(out[hit.dom_id], hit)
    end
    out
end

"""
Calibrates hits.
"""
function KM3io.calibrate(T::Type{HitR1}, det::Detector, hits)
    rhits = sizehint!(Vector{T}(), length(hits))
    for hit ∈ hits
        pmt = det[hit.dom_id][hit.channel_id]
        t = hit.t + pmt.t₀
        push!(rhits, T(hit.dom_id, pmt.pos, t, hit.tot, 1, 0))
    end
    rhits
end
function KM3io.calibrate(T::Type{HitL0}, m::DetectorModule, hits)
    chits = sizehint!(Vector{T}(), length(hits))
    for hit ∈ hits
        pmt = m[hit.channel_id]
        t = hit.t + pmt.t₀
        push!(chits, T(hit.channel_id, t, hit.tot, pmt.pos, pmt.dir))
    end
    chits
end


"""
Combines several hits into a single one by taking the earliest start time,
then latest endtime and a ToT which spans over the time range.
"""
function combine(hits::Vector{T}) where T <: KM3io.AbstractHit
    isempty(hits) && return Hit(0.0, 0.0)

    t1 = starttime(first(hits))
    t2 = endtime(first(hits))

    @inbounds for hit ∈ hits[2:end]
        _t1 = starttime(hit)
        _t2 = endtime(hit)
        if t1 > _t1
            t1 = _t1
        end
        if t2 < _t2
            t2 = _t2
        end
    end
    Hit(t1, t2 - t1)
end


struct L1BuilderParameters
    Δt::Float64
    combine::Bool
end

struct L1Builder
    params::L1BuilderParameters
end

"""
Find coincidences within the time window `Δt` of the initialised `params`. The return
value is a vector of `L1Hit`s.
"""
function (b::L1Builder)(::Type{H}, det::Detector, hits::Vector{T}; combine=true) where {T, H}
    out = H[]
    mm = modulemap(hits)
    for (m, module_hits) ∈ mm
        _findL1!(out, det[m], module_hits, b.params.Δt, combine)
    end
    out
end
(b::L1Builder)(det::Detector, hits) = b(HitL1, det, hits)
function _findL1!(out::Vector{H}, m::DetectorModule, hits, Δt, combine::Bool) where H <: AbstractSpecialHit
    n = length(hits)
    n < 2 && return out

    chits = calibrate(HitL0, m, hits)
    sort!(chits)

    ref_idx = 1  # starting with the first hit, obviously
    idx = 2      # first comparison is the second hit
    while ref_idx <= n
        restart = false
        idx = ref_idx + 1
        while(idx <= n+1 || restart) # n+1 to do the final loop after the last hit
            if idx > n || (time(chits[idx]) - time(chits[ref_idx]) > Δt)
                end_idx = idx - 1
                # check if we have gathered some hits
                if ref_idx != end_idx
                    coincident_hits = [chits[i] for i ∈ ref_idx:end_idx]
                    # push!(out, H(m, HitL1(m.id, coincident_hits)))
                    push!(out, H(m.id, coincident_hits))
                    if combine
                        ref_idx = end_idx
                    end
                end
                restart = true
                break
            end
            idx += 1
        end
        ref_idx += 1
    end
    out
end


struct L2BuilderParameters
    n_hits::Int
    Δt::Float64
    ctmin::Float64
end

struct L2Builder
    params::L2BuilderParameters
end

function (b::L2Builder)(hits)
    out = []
    for hit ∈ hits
    end

    error("Not implemented yet")
end


const SLEWS_L1 = SVector(
    +0.00, +0.39, +0.21, -0.59, -1.15,
    -1.59, -1.97, -2.30, -2.56, -2.89,
    -3.12, -3.24, -3.56, -3.69, -4.00,
    -4.10, -4.16, -4.49, -4.71, -4.77,
    -4.81, -4.87, -4.88, -4.83, -5.21,
    -5.06, -5.27, -5.18, -5.24, -5.79,
    -6.78, -6.24
)

function KM3io.slew(h::HitR1)
    h.n > length(SLEWS_L1) && return SLEWS_L1[end]
    SLEWS_L1[h.n]
end


"""
Calculates the time to reach the z-position of the `hit` along the z-axis.
"""
function timetoz(hit)
    time(hit) * KM3io.Constants.C - hit.pos.z
end


abstract type AbstractMatcher end

"""
3D match criterion with road width, intended for muon signals.

Origin: B. Bakker, "Trigger studies for the Antares and KM3NeT detector.",
Master thesis, University of Amsterdam. With modifications from Jpp (M. de Jong).
"""
mutable struct Match3B <: AbstractMatcher
    const roadwidth::Float64
    const tmaxextra::Float64
    x::Float64
    y::Float64
    z::Float64
    d::Float64
    t::Float64
    dmin::Float64
    dmax::Float64
    d₂::Float64

    const D0::Float64
    const D1::Float64
    const D2::Float64
    const D02::Float64
    const D12::Float64
    const D22::Float64
    const Rs2::Float64
    const Rst::Float64
    const Rt::Float64
    const R2::Float64

    function Match3B(roadwidth, tmaxextra=0.0)
        tt2 = KM3io.Constants.TAN_THETA_C_WATER^2

        D0 = roadwidth
        D1 = roadwidth * 2.0
        # calculation D2 in thesis is wrong, here correct (taken from Jpp/JMatch3B)
        D2 = roadwidth * 0.5 * sqrt(tt2 + 10.0 + 9.0 / tt2)

        D02 = D0 * D0
        D12 = D1 * D1
        D22 = D2 * D2

        R = roadwidth
        Rs = R * KM3io.Constants.SIN_THETA_C_WATER

        R2 = R * R;
        Rs2 = Rs * Rs;
        Rst = Rs * KM3io.Constants.TAN_THETA_C_WATER
        Rt = R * KM3io.Constants.TAN_THETA_C_WATER

        new(
            roadwidth, tmaxextra, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            D0, D1, D2, D02, D12, D22, Rs2, Rst, Rt, R2,
        )
    end
end
Base.show(io::IO, m::Match3B) = print(io, "Match3B($(m.roadwidth), $(m.tmaxextra))")

function (m::Match3B)(hit1, hit2, time1, time2)
      m.x = hit1.pos.x - hit2.pos.x
      m.y = hit1.pos.y - hit2.pos.y
      m.z = hit1.pos.z - hit2.pos.z
      m.d₂ = m.x * m.x + m.y * m.y + m.z * m.z
      m.t = abs(time1 - time2)

      if (m.d₂ < m.D02)
        m.dmax = √m.d₂ * KM3io.Constants.INDEX_OF_REFRACTION_WATER
      else
        m.dmax = √(m.d₂ - m.Rs2) + m.Rst
      end

      m.t > m.dmax * KM3io.Constants.C_INVERSE + m.tmaxextra && return false

      if m.d₂ > m.D22
        m.dmin = √(m.d₂ - m.R2) - m.Rt
      elseif m.d₂ > m.D12
        m.dmin = √(m.d₂ - m.D12)
      else
        return true
      end

      m.t >= m.dmin * KM3io.Constants.C_INVERSE - m.tmaxextra
end

"""
Simple Cherenkov matcher for muon signals. The muon is assumed to travel parallel
to the Z-axis.
"""
mutable struct Match1D <: AbstractMatcher
    const roadwidth::Float64  # maximal road width [ns]
    const tmaxextra::Float64  # maximal extra time [ns]
    const tmax::Float64
    x::Float64
    y::Float64
    z::Float64
    d::Float64
    t::Float64

    function Match1D(roadwidth, tmaxextra=0.0)
        tmax = 0.5 * roadwidth * KM3io.Constants.TAN_THETA_C_WATER * KM3io.Constants.C_INVERSE  +  tmaxextra
        new(roadwidth, tmaxextra, tmax, 0.0, 0.0, 0.0, 0.0, 0.0)
    end
end

function (m::Match1D)(hit1, hit2, time1, time2)

      m.z = hit1.pos.z - hit2.pos.z
      #m.t = abs(time(hit1) - time(hit2) - m.z * KM3io.Constants.C_INVERSE)
      m.t = abs(time1 - time2 - m.z * KM3io.Constants.C_INVERSE)

      m.t > m.tmax && return false

      x = hit1.pos.x - hit2.pos.x
      y = hit1.pos.y - hit2.pos.y
      d = √(x^2 + y^2)

      if d <= m.roadwidth/2
          return m.t <=  d  * KM3io.Constants.TAN_THETA_C_WATER * KM3io.Constants.C_INVERSE  +  m.tmaxextra
      elseif d <= m.roadwidth
          return m.t <= (m.roadwidth - d) * KM3io.Constants.TAN_THETA_C_WATER * KM3io.Constants.C_INVERSE  +  m.tmaxextra
      end

      false
end

@inline function swap!(arr::AbstractArray, i::Int, j::Int)
    arr[i], arr[j] = arr[j], arr[i]
    arr
end

"""
Clique clusterizer which takes a matcher algorithm like `Match3B` as input.
"""
struct Clique{T<:AbstractMatcher}
    match::T
    weights::Vector{Float64}
    times::Vector{Float64}
    Clique(m::T) where T = new{T}(m, Float64[], Float64[])
end

"""
Applies the clique clusterization algorithm and leaves only the best matching
hits in the input array.
"""
clusterize!(hits::AbstractArray{T}, m::AbstractMatcher) where T<:AbstractSpecialHit = clusterize!(hits, Clique(m))
function clusterize!(hits::AbstractArray{T}, c::Clique) where T<:AbstractSpecialHit
    N = length(hits)
    N == 0 && return hits
    times = c.times

    resize!(c.weights, N)
    resize!(c.times, N)

    @inbounds for i ∈ 1:N
        c.weights[i] = weight(hits[i])
        times[i] = time(hits[i])
    end

    @inbounds for i ∈ 1:N
        @inbounds for j ∈ i:N
            j == i && continue
            if c.match(hits[i], hits[j], times[i], times[j])
                c.weights[i] += weight(hits[j])
                c.weights[j] += weight(hits[i])
            end
        end
    end

    # Remove hit with the smallest weight of associated hits.
    # This procedure stops when the weight of associated hits
    # is equal to the maximal weight of (remaining) hits.
    n = N
    # @show N
    @inbounds while true
        j = 1
        W = c.weights[j]
        # @show W

        @inbounds for i ∈ 2:n
            if c.weights[i] < c.weights[j]
                j = i
            elseif c.weights[i] > W
                W = c.weights[i]
            end
        end
        # end condition
        c.weights[j] == W && return resize!(hits, n)

        # Swap the selected hit to end.
        swap!(hits, j, n)
        swap!(c.weights, j, n)
        swap!(times, j, n)

        # Decrease weight of associated hits for each associated hit.
        @inbounds for i ∈ 1:n
            c.weights[n] <= weight(hits[n]) && break
            if c.match(hits[i], hits[n], times[i], times[n])
                c.weights[i] -= weight(hits[n])
                c.weights[n] -= weight(hits[i])
            end
        end

        n -= 1
    end
end
scanfit.jl
Base.@kwdef struct MuonScanfitParameters
    tmaxlocal::Float64 = 18.0  # [ns]
    roadwidth::Float64 = 200.0  # [m]
    nmaxhits::Int = 50  # maximum number of hits to use
    nfits::Int = 1
    nprefits::Int = 10
    σ::Float64 = 5.0  # [ns]
    α₁::Float64 = 7.0  # grid angle of the coarse scan
    α₂::Float64 = 0.5  # grid angle of the fine scan
    θ::Float64 = 3.5  # opening angle of the fine-scan cone
end


"""

A container of directions with additional information about their median
angular separation.

"""
struct DirectionSet
    directions::Vector{Direction{Float64}}
    angular_separation::Float64
end

struct MuonScanfit
    params::MuonScanfitParameters
    detector::Detector
    coarsedirections::DirectionSet
    coincidencebuilder::L1Builder
    function MuonScanfit(params::MuonScanfitParameters, detector::Detector)
        coincidencebuilder = L1Builder(L1BuilderParameters(params.tmaxlocal, false))
        new(params, detector, DirectionSet(fibonaccisphere(params.α₁), params.α₁), coincidencebuilder)
    end
end
MuonScanfit(det::Detector) = MuonScanfit(MuonScanfitParameters(), det)
function Base.show(io::IO, m::MuonScanfit)
    print(io, "$(typeof(m)) with a coarse scan of $(m.params.α₁)ᵒ and a fine scan of $(m.params.α₂)ᵒ.")
end

"""
Performs a Muon track fit for a given event.
"""
(msf::MuonScanfit)(event::DAQEvent) = msf(event.snapshot_hits)

"""
Performs a Muon track fit for a given set of hits (usually snapshot hits).
"""
function (msf::MuonScanfit)(hits::Vector{T}) where T<:KM3io.AbstractHit

    rhits = msf.coincidencebuilder(HitR1, msf.detector, hits)

    sort!(rhits)
    unique!(h->h.dom_id, rhits)

    clusterize!(rhits, Match3B(msf.params.roadwidth, msf.params.tmaxlocal))

    # First round on 4π
    candidates = scanfit(msf.params, rhits, msf.coarsedirections)
    isempty(candidates) && return candidates
    sort!(candidates, by=m->m.Q; rev=true)

    # Second round on directed cones pointing towards the previous best directions
    # TODO: currently disabled until all the allocations are minimised
    # here, reusing a Vector{Direction} (attached to msf as buffer) might be a good idea.
    # By doing so, we need `fibonaccicone!` and `fibonaccisphere!` as mutating functions
    # directions = Vector{Vector{Direction{Float64}}}()
    # for idx in 1:min(msf.params.nprefits, length(candidates))
    #     most_likely_dir = candidates[idx].dir
    #     push!(directions, fibonaccicone(most_likely_dir, msf.params.α₂, msf.params.θ))
    # end
    # directionset = DirectionSet(vcat(directions...), msf.params.α₂)
    # candidates = scanfit(msf.params, rhits, directionset)

    # isempty(candidates) && return candidates
    # sort!(candidates, by=m->m.Q; rev=true)

    candidates[1:msf.params.nfits]
end

"""

Performs the scanfit for each given direction and returns a
`Vector{MuonScanfitCandidate}` with all successful fits. The resulting vector can
be empty if none of the directions had enough hits to perform the algorithm.

"""
# function scanfit(
#     params::MuonScanfitParameters,
#     rhits::Vector{T},
#     directionset::DirectionSet;
#     nchunks = Threads.nthreads()
# ) where T<:AbstractReducedHit
#     results = [ NeRCA.MuonScanfitCandidate[] for _ in 1:nchunks ]
#     Threads.@sync for (irange, ichunk) in chunks(directionset.directions, nchunks)
#         Threads.@spawn for i in irange
#             xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
#             c = directionset.directions[i]
#             push!(results[ichunk], xytsolver(rhits, c, directionset.angular_separation))
#         end
#     end
#     return vcat(results...)
# end
function scanfit(params::MuonScanfitParameters, rhits::Vector{T}, directionset::DirectionSet) where T<:AbstractReducedHit
    chunk_size = max(1, length(directionset.directions) ÷ Threads.nthreads())
    chunks = Iterators.partition(directionset.directions, chunk_size)

    tasks = map(chunks) do chunk
        Threads.@spawn begin
            xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
            results = [xytsolver(rhits, c, directionset.angular_separation) for c in chunk]
            results
        end
    end

    mapreduce(fetch, vcat, tasks)
end

struct MuonScanfitCandidate
    pos::Position{Float64}
    dir::Direction{Float64}
    t::Float64
    Q::Float64
    NDF::Int
end

"""
The quality of the fit, the larger the better, as used in e.g. Jpp.
"""
quality(χ², N, NDF) = N  -  0.25 * χ² / NDF

abstract type EstimatorModel end

"""

A straight line parallel to the z-axis.

"""
struct Line1Z <: EstimatorModel
    pos::Position{Float64}
    t::Float64
end
distance(l::Line1Z, pos::Position) = √distancesquared(l, pos)
distancesquared(l::Line1Z, pos::Position) = (pos.x - posx(l))^2 + (pos.y - posy(l))^2
posx(l::Line1Z) = l.pos.x
posy(l::Line1Z) = l.pos.y
posz(l::Line1Z) = l.pos.z
posz(l::Line1Z, pos::Position) = l.pos.z - distance(l, pos) / KM3io.Constants.TAN_THETA_C_WATER
"""

Calculate the Cherenkov arrival time for a given position.

"""
function Base.time(lz::Line1Z, pos::Position)
    v = pos - lz.pos
    R = √(v.x*v.x + v.y*v.y)
    lz.t + (v.z + R * KM3io.Constants.KAPPA_WATER) * KM3io.Constants.C_INVERSE
end


struct SingularSVDException <: Exception
    message::String
end

mutable struct Line1ZEstimator
    model::Line1Z
    V::MMatrix{3, 3, Float64, 9}
    NUMBER_OF_PARAMETERS::Int
    MINIMAL_SVD_WEIGHT::Float64
    function Line1ZEstimator(model::Line1Z)
        V = zero(MMatrix{3, 3, Float64, 9})
        new(model, V, 3, 1.0e-4)
    end
end
posx(est::Line1ZEstimator) = posx(est.model)
posy(est::Line1ZEstimator) = posy(est.model)
posz(est::Line1ZEstimator) = posz(est.model)

function reset!(est::Line1ZEstimator)
    est.V .= 0.0
    est
end

# TODO: generalise for "data" using traits
function estimate!(est::Line1ZEstimator, hits)
    N = length(hits)

    N < est.NUMBER_OF_PARAMETERS && error("Not enough data points, $N points, but we require at least $(est.NUMBER_OF_PARAMETERS)")

    W = 1.0 / N

    pos = sum(h.pos for h ∈ hits) * W
    t = 0.0
    lz = Line1Z(pos, t)

    t₀ = sum(time(h) for h ∈ hits) * W * KM3io.Constants.C

    reset!(est)

    yvec = zeros(MVector{3})
    hit₀ = first(hits)
    xi = hit₀.pos.x - posx(lz)
    yi = hit₀.pos.y - posy(lz)
    ti = (time(hit₀) * KM3io.Constants.C - t₀ - hit₀.pos.z + posz(lz)) / KM3io.Constants.KAPPA_WATER

    # starting from the second hit and including the first in the last iteration
    @inbounds for idx ∈ 2:N+1
        hit = idx > N ? first(hits) : hits[idx]
        xj = hit.pos.x - posx(lz)
        yj = hit.pos.y - posy(lz)
        tj = (time(hit) * KM3io.Constants.C - t₀ - hit.pos.z + posz(lz)) / KM3io.Constants.KAPPA_WATER

        dx = xj - xi
        dy = yj - yi
        dt = ti - tj  # opposite sign

        y = (xj + xi) * dx + (yj + yi) * dy + (tj + ti) * dt

        dx *= 2
        dy *= 2
        dt *= 2

        est.V[1, 1] += dx * dx
        est.V[1, 2] += dx * dy
        est.V[1, 3] += dx * dt
        est.V[2, 2] += dy * dy
        est.V[2, 3] += dy * dt
        est.V[3, 3] += dt * dt

        yvec[1] += dx * y
        yvec[2] += dy * y
        yvec[3] += dt * y

        xi = xj
        yi = yj
        ti = tj
    end

    t₀ *= KM3io.Constants.C_INVERSE

    @inbounds begin
        est.V[2, 1] = est.V[1, 2]
        est.V[3, 1] = est.V[1, 3]
        est.V[3, 2] = est.V[2, 3]
    end


    # Hermitian is needed for typestability!
    wvec, evecs = invert2!(Hermitian(est.V), est.MINIMAL_SVD_WEIGHT)
    yvec2 = (evecs' * yvec)
    yvec2 .*= wvec
    mul!(yvec, evecs, yvec2)
    #yvec = evecs * (diagm(wvec) * (evecs' * yvec))

    @inbounds begin
        est.model = Line1Z(
            Position(
                pos.x + yvec[1],
                pos.y + yvec[2],
                posz(lz)
            ),
            yvec[3] * KM3io.Constants.KAPPA_WATER * KM3io.Constants.C_INVERSE + t₀
        )
    end

    est
end

"""
Invert matrix in-place with a given precision (clamps eigenvalues to 0 below that).
"""
function invert!(V, precision)
    F = svd(V)

    abs(F.S[2]) <  precision * abs(F.S[1]) && throw(SingularSVDException("$F.S"))

    w1 = abs(F.S[1])
    w2 = abs(F.S[2])
    w3 = abs(F.S[3])
    w = max(w1, w2, w3) * precision

    @inbounds for idx in eachindex(F.S)
        F.S[idx] = abs(F.S[idx]) >= w ? 1.0 / F.S[idx] : 0.0
    end

    mul!(V, F.U, diagm(F.S) * F.Vt)
end

@inline function invert2!(V, precision)
    evals, evecs = eigen(V)

    abs(evals[2]) <  precision * abs(evals[1]) && throw(SingularSVDException("$evals"))

    w = maximum(abs, evals) * precision

    wvec = ifelse.(abs.(evals) .>= w, inv.(evals), zero(float(eltype(evals))))

    return wvec, evecs
    #mul!(V, evecs, diagm(wvec) * evecs')
end


struct Variance <: FieldVector{4, Float64}
    x::Float64
    y::Float64
    v::Float64
    w::Float64
end

struct CovMatrix
    M::Matrix{Float64}
    V::Vector{Variance}
    σ::Float64
    CovMatrix(N::Int, σ::Float64) = new(MMatrix{N, N, Float64, N*N}(undef), MVector{N, Variance}(undef), σ)
end

# TODO: generalise hits parameter
function update!(C::CovMatrix, pos::Position, hits, α::Float64)
    N = length(hits)

    ta = deg2rad(α)
    ct = cos(ta)
    st = sin(ta)

    for (idx, hit) ∈ enumerate(hits)
        dx, dy, dz = hit.pos - pos
        R = √(dx^2 + dy^2)

        x = y = ta * KM3io.Constants.KAPPA_WATER * KM3io.Constants.C_INVERSE
        v = w = ta * KM3io.Constants.C_INVERSE

        if R != 0.0
          x *= dx / R
          y *= dy / R
        end

        x *= (dz * ct - dx * st)
        y *= (dz * ct - dy * st)
        v *= -(dx * ct + dz * st)
        w *= -(dy * ct + dz * st)

        C.V[idx] = Variance(x, y, v, w)
    end

    @inbounds for i ∈ 1:N
        @inbounds for j ∈ 1:i
            C.M[i, j] = C.V[i] ⋅ C.V[j]
            C.M[j, i] = C.M[i, j]
        end
        C.M[i, i] = C.V[i] ⋅ C.V[i] + C.σ^2
    end
    C
end

# TODO: generalise hits parameter
timeresvec(lz::Line1Z, hits) = [time(hit) - time(lz, hit.pos) for hit ∈ hits]
function timeresvec!(v::AbstractArray{Float64}, lz::Line1Z, hits)
    for (idx, hit) ∈ enumerate(hits)
        v[idx] = time(hit) - time(lz, hit.pos)
    end
    v
end


"""
A task worker which solves for x, y and t for a given set of hits and a direction.
"""
struct XYTSolver
    hits_buffer::Vector{HitR1}
    covmatrix::CovMatrix
    timeresvec::Vector{Float64}
    nmaxhits::Int
    clique::Clique{Match1D}
    est::Line1ZEstimator

    function XYTSolver(nmaxhits::Int, roadwidth::Float64, tmaxlocal::Float64, σ::Float64)
        new(Vector{HitR1}(), CovMatrix(nmaxhits, σ), Vector{Float64}(), nmaxhits, Clique(Match1D(roadwidth, tmaxlocal)),
            Line1ZEstimator(Line1Z(Position(0, 0, 0), 0))
        )
    end
end

function (s::XYTSolver)(hits::Vector{T}, dir::Direction{Float64}, α::Float64) where T<:AbstractReducedHit
    χ² = Inf
    R = rotator(dir)
    n_initial_hits = length(hits)
    resize!(s.hits_buffer, n_initial_hits)

    for (idx, hit) ∈ enumerate(hits) # rotate hits
        s.hits_buffer[idx] = @set hit.pos = R * hit.pos
    end

    if n_initial_hits > s.nmaxhits
        sort!(s.hits_buffer; by=timetoz, alg=PartialQuickSort(s.nmaxhits))
        resize!(s.hits_buffer, s.nmaxhits)
    end

    clusterize!(s.hits_buffer, s.clique)

    hits = s.hits_buffer  # just for convenience
    n_final_hits = length(hits)

    n_final_hits <= s.est.NUMBER_OF_PARAMETERS && return MuonScanfitCandidate(Position(0, 0, 0), dir, 0, -Inf, 0)

    NDF = n_final_hits - s.est.NUMBER_OF_PARAMETERS
    N = hitcount(hits)
    sort!(hits)

    try
        estimate!(s.est, hits)
    catch ex
        # isa(ex, SingularSVDException) && @warn "Singular SVD"
        return MuonScanfitCandidate(Position(0, 0, 0), dir, 0, -Inf, 0)
    end

    # TODO: consider creating a "pos()" getter for everything
    update!(s.covmatrix, s.est.model.pos, hits, α)
    # TODO: this is really ugly... make update!() return the view itself maybe?
    V = view(s.covmatrix.M, 1:n_final_hits, 1:n_final_hits)

    n_final_hits > length(s.timeresvec) && resize!(s.timeresvec, n_final_hits)
    # TODO: better name for this function
    timeresvec!(s.timeresvec, s.est.model, hits)

    #V⁻¹ = inv(V)
    Y = view(s.timeresvec, 1:n_final_hits)  # only take the relevant part of the buffer
    χ² = dot(Y, V \ Y)  # V⁻¹ * Y == V \ Y
    fit_pos = R \ s.est.model.pos

    MuonScanfitCandidate(fit_pos, dir, s.est.model.t, quality(χ², N, NDF), NDF)
end

I have to thank you for all your investigations and the extremely valuable feedback!

I already spent a lot of time on optimising the linalg stuff but you still found a lot of improvements, I am really amazed :slight_smile:

Here are some benchmarks after applying your suggestions step by step (starting with replacing inv…). This also contains the second round, where I do the extra fits in the “best fit directions”:

julia> @benchmark muons = msfit($raw_hits)
BenchmarkTools.Trial: 169 samples with 1 evaluation.
 Range (min … max):  24.892 ms … 145.685 ms  ┊ GC (min … max): 0.00% … 80.70%
 Time  (median):     27.258 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   29.623 ms ±  11.494 ms  ┊ GC (mean ± σ):  3.69% ±  8.31%

  ▆▇▃▄█▄▂▄▄
  █████████▆▇▅▁▅▁▇▁▅▁▁▁▁▁▁▁▁▅▁▅▅▅▁▁▁▁▁▁▅▁▅▁▁▁▁▅▁▁▅▁▁▅▁▁▁▁▁▁▁▁▆ ▅
  24.9 ms       Histogram: log(frequency) by time        56 ms <

 Memory estimate: 15.47 MiB, allocs estimate: 37414.

julia> @benchmark muons = msfit($raw_hits)
BenchmarkTools.Trial: 194 samples with 1 evaluation.
 Range (min … max):  23.188 ms … 65.265 ms  ┊ GC (min … max): 0.00% … 61.39%
 Time  (median):     25.319 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   25.823 ms ±  4.242 ms  ┊ GC (mean ± σ):  0.80% ±  4.41%

   ▂█▄▃ ▂▅▆▇▄
  ███████████▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▅▁▁▅▁▁▁▁▅▁▅▁▅▁▁▁▁▆▁▁▅▁▁▁▁▁▅ ▅
  23.2 ms      Histogram: log(frequency) by time      40.4 ms <

 Memory estimate: 7.90 MiB, allocs estimate: 33464.

julia> @benchmark muons = msfit($raw_hits)
BenchmarkTools.Trial: 314 samples with 1 evaluation.
 Range (min … max):  15.488 ms …  18.293 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.964 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   15.935 ms ± 385.545 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▇▆▂            ▁██▇
  ▆███▅▅▅▅▁▁▁▆█▅▅▆█████▅▃▄▁▄▂▃▂▁▁▂▂▂▃▂▃▁▃▁▂▁▁▁▁▃▂▂▃▃▁▁▂▁▁▃▁▁▁▂ ▃
  15.5 ms         Histogram: frequency by time         17.2 ms <

 Memory estimate: 2.87 MiB, allocs estimate: 15689.

julia> @benchmark muons = msfit($raw_hits)
BenchmarkTools.Trial: 543 samples with 1 evaluation.
 Range (min … max):  8.444 ms … 46.880 ms  ┊ GC (min … max): 0.00% … 81.25%
 Time  (median):     8.881 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.217 ms ±  2.201 ms  ┊ GC (mean ± σ):  1.32% ±  4.77%

  ▆▂█▃   ▂
  ████▇▆██▇▅▄▄▃▃▃▂▁▂▁▂▁▁▁▁▂▁▁▁▂▂▁▁▂▁▁▁▁▁▁▁▂▁▂▁▁▁▁▂▁▁▁▂▁▁▁▁▁▂ ▃
  8.44 ms        Histogram: frequency by time          15 ms <

 Memory estimate: 3.34 MiB, allocs estimate: 20111.

and of course the numerical results are unchanged :wink: Here are some preliminary statistics of the angular error (MC truth vs. reconstructed muon direction from muon neutrino charged current events above 100GeV), without any quality cuts:

[Screenshot: angular error distribution]

Most of the improvements can also be applied to the C++ code, so it would be interesting to see the comparison with the changes applied to that code as well, but I’ll leave that as a boring exercise for some future generation of astroparticle physicists :wink:

Many thanks for all the insights, and especially for your in-depth analysis of the code @abraemer


Will that mapreduce(fetch, vcat, tasks) work as tasks finish, without waiting for all of them to end? I’m trying to understand why that version is faster than the manual syncing.

For instance I’m surprised that your version is faster than this one:

 function scanfit(
     params::MuonScanfitParameters,
     rhits::Vector{T},
     directionset::DirectionSet
 ) where T<:AbstractReducedHit
     results = Vector{NeRCA.MuonScanfitCandidate}(undef, length(directionset.directions))
     Threads.@threads for i in eachindex(directionset.directions)
         xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params. σ)
         c = directionset.directions[i]
         results[i] = xytsolver(rhits, c, directionset.angular_separation)
     end
     return results
 end

(I wouldn’t expect it to be slower, but not faster either, at least not significantly; yet effectively it is running faster)

ps: Anyway, the OP can localize the type instability inside the scanfit function in the mapreduce version by asserting the result type with:

    mapreduce(fetch, vcat, tasks)::Vector{NeRCA.MuonScanfitCandidate}
end

probably that’s a good practice, to avoid problems.

Your second version does not work in chunks right? It creates a new XYTSolver for each entry in directionset.directions.

I think your previous version just struggles with the type instability from the vcat(results...), but I haven’t checked, so I might be dead wrong :slight_smile:
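
A generic illustration of the splat issue (not benchmarked against the actual code):

results = [rand(3) for _ in 1:4]
flat1 = vcat(results...)       # splatting a runtime-length vector hampers inference
flat2 = reduce(vcat, results)  # fixed arity, inferable, and with a fast specialized method
flat1 == flat2                 # true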

Conceptually what this code does:

  • Split the range into chunks.
  • Spawn a Task (which will be worked on by some thread) for each chunk.
  • Each task allocates an XYTSolver once and then uses it to solve all instances within its chunk.
  • mapreduce(fetch, vcat, tasks) then goes through the tasks one by one, waits for each to finish and vcats the results as it goes.

Tbh I didn’t test multithreading performance, but single-threaded, your first variant took 2.5x the time of my mapreduce variant.


Oh yes, I didn’t see that, even in the previous version. The correct chunked version with ChunkSplitters, putting the creation of the xytsolver outside the inner loop, would be:

 using ChunkSplitters
 function scanfit(
     params::MuonScanfitParameters,
     rhits::Vector{T},
     directionset::DirectionSet;
     nchunks = Threads.nthreads()
 ) where T<:AbstractReducedHit
     results = [ NeRCA.MuonScanfitCandidate[] for _ in 1:nchunks ]
     Threads.@sync for (irange, ichunk) in chunks(directionset.directions, nchunks)
         Threads.@spawn begin
             xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
             for i in irange
                 c = directionset.directions[i]
                 push!(results[ichunk], xytsolver(rhits, c, directionset.angular_separation))
             end
          end
     end
     return vcat(results...)
 end

and that recovers the performance you have in your version.

And this avoids the vcat thus allocating a little less (probably not important):

code
 using ChunkSplitters
 function scanfit(
     params::MuonScanfitParameters,
     rhits::Vector{T},
     directionset::DirectionSet;
     nchunks = Threads.nthreads()
 ) where T<:AbstractReducedHit
     results = Vector{NeRCA.MuonScanfitCandidate}(undef, length(directionset.directions))
     Threads.@sync for (irange, _) in chunks(directionset.directions, nchunks)
         Threads.@spawn begin
             xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
             for i in irange
                 c = directionset.directions[i]
                 results[i] = xytsolver(rhits, c, directionset.angular_separation)
             end
          end
     end
     return results
 end