Thanks for the runnable example
Baseline
- Essentially all of the time (~90%) is spent inside `XYTSolver`.
- Low-hanging fruit: replace `χ² = transpose(Y) * V⁻¹ * Y` with `χ² = dot(Y, V⁻¹, Y)` to save an allocation and gain a bit more speed (see the sketch after this list).
- Now `clusterize!(s.hits_buffer, s.matcher)` takes 50% of the total time, `estimate!(s.est, hits)` takes 22%, and `V⁻¹ = inv(V)` takes 10%.
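For reference, a minimal self-contained sketch of that first change with stand-in data (`Y` and `V` here only play the roles of the time residuals and the covariance matrix from `XYTSolver`):

```julia
using LinearAlgebra

A = randn(50, 50)
V = A * A' + I        # SPD stand-in for the covariance matrix
Y = randn(50)         # stand-in for the time residuals
Vinv = inv(V)

χ²_old = transpose(Y) * Vinv * Y   # allocates the intermediate vector Vinv * Y
χ²_new = dot(Y, Vinv, Y)           # three-argument dot, no intermediate vector
@assert χ²_old ≈ χ²_new
```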
In reverse order:
`V⁻¹ = inv(V)`

Don't compute the full inverse matrix just to apply it to a single vector. Do a "division" instead: `V⁻¹ * Y == V \ Y`, so we get `χ² = dot(Y, V \ Y)`. That gives a 7.5% speedup (a small sketch follows).
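Concretely, with the same kind of stand-in data as above (in the solver `V` is the view into the covariance buffer and `Y` the time-residual view):

```julia
using LinearAlgebra

A = randn(50, 50)
V = A * A' + I        # SPD stand-in for the covariance matrix
Y = randn(50)         # stand-in for the time residuals

χ²_inv = dot(Y, inv(V) * Y)   # forms the full inverse just to apply it to one vector
χ²_div = dot(Y, V \ Y)        # factorize and solve instead, no explicit inverse
@assert χ²_inv ≈ χ²_div
```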
`estimate!`

Almost all of the time is spent inside `invert!`. StaticArrays.jl has no specialized `svd`, but there is a specialized `eigen`, so use that instead. Then, instead of multiplying the matrices together to form the pseudoinverse, just return the individual components, because applying it to a vector only needs matrix-vector products (see the sketch below). I also replaced your `y₀, y₁, y₂` with `yvec = zeros(MVector{3})`, which also simplifies the construction of the model.
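In isolation the pattern looks roughly like this, for a 3×3 `Hermitian` static matrix (the case StaticArrays has a specialized `eigen` for); the actual version is `invert2!` plus the application in `estimate!` in scanfit.jl below:

```julia
using LinearAlgebra, StaticArrays

V = Hermitian(MMatrix{3,3}(4.0, 1.0, 0.0,
                           1.0, 3.0, 0.5,
                           0.0, 0.5, 2.0))
y = MVector(1.0, 2.0, 3.0)
precision = 1.0e-4

evals, evecs = eigen(V)                  # specialized for small static Hermitian matrices
w = maximum(abs, evals) * precision
wvec = ifelse.(abs.(evals) .>= w, inv.(evals), 0.0)  # clamp tiny eigenvalues to zero

# Apply the pseudoinverse with matrix-vector products only, instead of
# forming evecs * diagm(wvec) * evecs' explicitly.
x = evecs * (wvec .* (evecs' * y))
```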
`clusterize!`

This one I find a bit trickier. Maybe there are algorithmic optimizations, but I saw nothing obvious. Most of the time is spent inside `match`, since there are quadratically many calls. Calling `time` inside the match functions turned out to be the bottleneck, so I precompute all the times once and also pass the time corresponding to each hit into `match`. That got me a ~60% speedup (sketch below).

Another low-hanging fruit is to reuse the `Clique` inside `XYTSolver` to save some allocations and gain a very minor speedup.
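Stripped of the `Clique` bookkeeping, the change is just the following (toy `ToyHit`/`toymatch` stand-ins, not the real `Match3B`; the real version lives in `clusterize!` below):

```julia
# Toy illustration: precompute the hit times once instead of calling time()
# inside the O(N²) match loop.
struct ToyHit
    t::Float64
end
Base.time(h::ToyHit) = h.t

toymatch(h1, h2, t1, t2) = abs(t1 - t2) < 20.0   # time() is no longer called in here

hits = [ToyHit(1000.0 * rand()) for _ in 1:100]
times = [time(h) for h in hits]                  # O(N) time() calls instead of O(N²)

npairs = count(toymatch(hits[i], hits[j], times[i], times[j])
               for i in eachindex(hits) for j in i+1:length(hits))
```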
I think @Imiq saw a type instability introduced by `Iterators.flatten`. I tried Imiq's version and it was a lot slower than my suggested version with `mapreduce` (and it needs another dependency).
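For reference, the `mapreduce` pattern I mean looks like this in isolation (stand-in per-chunk work; the real version is in `scanfit` in scanfit.jl below):

```julia
# Collect per-chunk results from spawned tasks. mapreduce(fetch, vcat, tasks)
# keeps the result a concrete Vector; flattening lazy iterators over the
# fetched chunks was type-unstable in comparison.
chunks = Iterators.partition(1:100, 10)
tasks = map(chunks) do chunk
    Threads.@spawn [i^2 for i in chunk]   # stand-in per-chunk work
end
results = mapreduce(fetch, vcat, tasks)   # Vector{Int} of length 100
```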
Summary
All in all I achieved a speedup of 2.3x. `clusterize!` still takes ~60% of the time (with 2/3 of that spent inside `match`), `V \ Y` takes ~15%, and `estimate!`/`invert2!` also takes ~15%.
hits.jl
struct Hit <: KM3io.AbstractHit
t::Float64
tot::Float64
end
Base.isless(lhs::Hit, rhs::Hit) = lhs.t < rhs.t
struct HitL0 <: KM3io.AbstractHit
channel_id::Int8
t::Float64
tot::Float64
pos::Position{Float64}
dir::Direction{Float64}
end
Base.isless(lhs::HitL0, rhs::HitL0) = time(lhs) < time(rhs)
abstract type AbstractSpecialHit <: KM3io.AbstractHit end # TODO: bad naming ;)
abstract type AbstractCoincidenceHit <: AbstractSpecialHit end
abstract type AbstractReducedHit <: AbstractSpecialHit end
struct HitL1 <: AbstractCoincidenceHit
dom_id::Int32
hits::Vector{HitL0}
end
# function Base.time(h::HitL1)
# # n = length(h)
# # n > length(SLEWS_L1) && return SLEWS_L1[end]
# # time(first(h.hits)) - SLEWS_L1[n-1]
# time(first(h.hits))
# end
# position(h::HitL1) = first(h.hits).pos
# function tot(h::HitL1)
# combined_hit = combine(h.hits)
# combined_hit.tot
# end
struct HitL2 <: AbstractCoincidenceHit
dom_id::Int32
hits::Vector{HitL0}
end
HitL1(m::DetectorModule, hits) = HitL1(m.id, hits)
HitL2(m::DetectorModule, hits) = HitL2(m.id, hits)
Base.length(c::AbstractCoincidenceHit) = length(c.hits)
Base.eltype(::AbstractCoincidenceHit) = HitL0
function Base.iterate(c::AbstractCoincidenceHit, state=1)
@inbounds state > length(c) ? nothing : (c.hits[state], state+1)
end
struct HitR0 <: AbstractReducedHit
t::Float64
tot::Float64
channel_id::Int8
end
Base.isless(lhs::HitR0, rhs::HitR0) = time(lhs) < time(rhs)
struct HitR1 <: AbstractReducedHit
dom_id::Int32
pos::Position{Float64}
t::Float64
tot::Float64
n::Int
weight::Float64
end
Base.isless(lhs::HitR1, rhs::HitR1) = lhs.dom_id == rhs.dom_id ? time(lhs) < time(rhs) : lhs.dom_id < rhs.dom_id
const HitR2 = HitR1
function HitR1(dom_id::Integer, hits::Vector{HitL0})
combined_hit = combine(hits)
h = first(hits)
count = weight = length(hits)
HitR1(dom_id, h.pos, combined_hit.t, combined_hit.tot, count, weight)
end
# function HitR1(m::DetectorModule, hit::HitL1)
# count = weight = length(hit)
# HitR1(m.id, position(hit), time(hit), tot(hit), count, weight)
# end
weight(h::HitR1) = h.weight
starttime(hit) = time(hit)
endtime(hit) = time(hit) + hit.tot
"""
Return the total number of hits for a collection of reduced hits.
"""
hitcount(hits::AbstractArray{T}) where T<:AbstractReducedHit = sum(h.n for h in hits)
"""
Combine snapshot and triggered hits to a single hits-vector.
This should be used to transfer the trigger information to the
snapshot hits from a DAQEvent. The triggered hits are a subset
of the snapshot hits.
"""
function combine(snapshot_hits::Vector{KM3io.SnapshotHit}, triggered_hits::Vector{KM3io.TriggeredHit})
triggermasks = Dict{Tuple{UInt8, Int32, Int32, UInt8}, Int64}()
for hit ∈ triggered_hits
triggermasks[(hit.channel_id, hit.dom_id, hit.t, hit.tot)] = hit.trigger_mask
end
n = length(snapshot_hits)
hits = sizehint!(Vector{TriggeredHit}(), n)
for hit in snapshot_hits
channel_id = hit.channel_id
dom_id = hit.dom_id
t = hit.t
tot = hit.tot
triggermask = get(triggermasks, (channel_id, dom_id, t, tot), 0)
push!(hits, TriggeredHit(dom_id, channel_id, t, tot, triggermask))
end
hits
end
"""
Create a `Vector` with hits contributing to `n`-fold coincidences within a time
window of Δt.
"""
function nfoldhits(hits::Vector{T}, Δt, n) where {T<:KM3io.AbstractDAQHit}
hit_map = modulemap(hits)
chits = Vector{T}()
for (dom_id, dom_hits) ∈ hit_map
bag = Vector{T}()
push!(bag, dom_hits[1])
t0 = dom_hits[1].t
for hit in dom_hits[2:end]
if hit.t - t0 > Δt
if length(bag) >= n
append!(chits, bag)
end
bag = Vector{T}()
end
push!(bag, hit)
t0 = hit.t
end
end
return chits
end
"""
Calculate the multiplicities for a given time window. Two arrays are
returned: one contains the multiplicities, the other the IDs of the
coincidence groups.
The hits should be sorted by time and then by dom_id.
"""
function count_multiplicities(hits::Vector{T}, tmax=20) where {T<:KM3io.AbstractHit}
n = length(hits)
mtp = ones(Int32, n)
cid = zeros(Int32, n)
idx0 = 1
_mtp = 1
_cid = idx0
t0 = hits[idx0].t
dom_id = hits[idx0].dom_id
for i in 2:n
hit = hits[i]
if hit.dom_id != dom_id
mtp[idx0:i-1] .= _mtp
cid[idx0:i-1] .= _cid
dom_id = hit.dom_id
t0 = hit.t
_mtp = 1
_cid += 1
idx0 = i
continue
end
Δt = hit.t - t0
if Δt > tmax
mtp[idx0:i] .= _mtp
cid[idx0:i] .= _cid
_mtp = 0
_cid += 1
idx0 = i
t0 = hit.t
end
_mtp += 1
end
mtp[idx0:end] .= _mtp
cid[idx0:end] .= _cid
mtp, cid
end
"""
Counts the multiplicities and modifies the .multiplicity field of the hits.
Important: the hits have to be sorted by time and then by DOM ID first.
"""
function count_multiplicities!(hits::Vector{KM3io.XCalibratedHit}, tmax=20)
_mtp = 0
_cid = 0
t0 = 0
dom_id = 0
hit_buffer = Vector{XCalibratedHit}()
function process_buffer()
while !isempty(hit_buffer)
_hit = pop!(hit_buffer)
_hit.multiplicity.count = _mtp
_hit.multiplicity.id = _cid
end
end
function reset()
_mtp = 1
_cid += 1
end
for hit ∈ hits
if length(hit_buffer) == 0
reset()
push!(hit_buffer, hit)
t0 = hit.t
dom_id = hit.dom_id
continue
end
if hit.dom_id != dom_id
process_buffer()
push!(hit_buffer, hit)
t0 = hit.t
dom_id = hit.dom_id
reset()
continue
end
Δt = hit.t - t0
if Δt > tmax
process_buffer()
push!(hit_buffer, hit)
t0 = hit.t
reset()
else
push!(hit_buffer, hit)
_mtp += 1
end
end
if length(hit_buffer) > 0
process_buffer()
end
return
end
"""
Categorise hits by DU and put them into a dictionary of DU=>Vector{Hit}.
Caveat: this function is not typesafe, only suited for high-level analysis (like plots).
"""
@inline duhits(hits::Vector{T}) where {T<:KM3io.XCalibratedHit} = categorize(:du, hits)
"""
Return a vector of hits with ToT >= `tot`.
"""
function totcut(hits::Vector{T}, tot) where {T<:KM3io.AbstractDAQHit}
return filter(h->h.tot >= tot, hits)
end
"""
Returns the estimated number of photoelectrons for a given ToT.
"""
function nphes(tot)
if tot <= 20
return 0.0
end
if tot <= 26
return 1.0
end
if tot < 170
return 1.0 + (tot - 26)/(1/0.28)
end
return 40.0 + (255 - tot)*2.0
end
"""
Creates a map (`Dict{Int32, Vector{T}}`) from a flat `Vector{T}` split up based
on the `dom_id` of each element. A typical use is to split up a vector of hits
by their optical module IDs.
This function is similar to `categorize(:dom_id, Vector{T})` but this method
is completely typesafe.
"""
function modulemap(hits::Vector{T}) where T
out = Dict{Int32, Vector{T}}()
for hit ∈ hits
if !(hit.dom_id ∈ keys(out))
out[hit.dom_id] = T[]
end
push!(out[hit.dom_id], hit)
end
out
end
"""
Calibrates hits.
"""
function KM3io.calibrate(T::Type{HitR1}, det::Detector, hits)
rhits = sizehint!(Vector{T}(), length(hits))
for hit ∈ hits
pmt = det[hit.dom_id][hit.channel_id]
t = hit.t + pmt.t₀
push!(rhits, T(hit.dom_id, pmt.pos, t, hit.tot, 1, 0))
end
rhits
end
function KM3io.calibrate(T::Type{HitL0}, m::DetectorModule, hits)
chits = sizehint!(Vector{T}(), length(hits))
for hit ∈ hits
pmt = m[hit.channel_id]
t = hit.t + pmt.t₀
push!(chits, T(hit.channel_id, t, hit.tot, pmt.pos, pmt.dir))
end
chits
end
"""
Combines several hits into a single one by taking the earliest start time and
the latest end time, with a ToT spanning the whole time range.
"""
function combine(hits::Vector{T}) where T <: KM3io.AbstractHit
isempty(hits) && return Hit(0.0, 0.0)
t1 = starttime(first(hits))
t2 = endtime(first(hits))
@inbounds for hit ∈ hits[2:end]
_t1 = starttime(hit)
_t2 = endtime(hit)
if t1 > _t1
t1 = _t1
end
if t2 < _t2
t2 = _t2
end
end
Hit(t1, t2 - t1)
end
struct L1BuilderParameters
Δt::Float64
combine::Bool
end
struct L1Builder
params::L1BuilderParameters
end
"""
Find coincidences within the time window `Δt` of the initialised `params`. The return
value is a vector of `HitL1`s.
"""
function (b::L1Builder)(::Type{H}, det::Detector, hits::Vector{T}; combine=true) where {T, H}
out = H[]
mm = modulemap(hits)
for (m, module_hits) ∈ mm
_findL1!(out, det[m], module_hits, b.params.Δt, combine)
end
out
end
(b::L1Builder)(det::Detector, hits) = b(HitL1, det, hits)
function _findL1!(out::Vector{H}, m::DetectorModule, hits, Δt, combine::Bool) where H <: AbstractSpecialHit
n = length(hits)
n < 2 && return out
chits = calibrate(HitL0, m, hits)
sort!(chits)
ref_idx = 1 # starting with the first hit, obviously
idx = 2 # first comparison is the second hit
while ref_idx <= n
restart = false
idx = ref_idx + 1
while(idx <= n+1 || restart) # n+1 to do the final loop after the last hit
if idx > n || (time(chits[idx]) - time(chits[ref_idx]) > Δt)
end_idx = idx - 1
# check if we have gathered some hits
if ref_idx != end_idx
coincident_hits = [chits[i] for i ∈ ref_idx:end_idx]
# push!(out, H(m, HitL1(m.id, coincident_hits)))
push!(out, H(m.id, coincident_hits))
if combine
ref_idx = end_idx
end
end
restart = true
break
end
idx += 1
end
ref_idx += 1
end
out
end
struct L2BuilderParameters
n_hits::Int
Δt::Float64
ctmin::Float64
end
struct L2Builder
params::L2BuilderParameters
end
function (b::L2Builder)(hits)
out = []
for hit ∈ hits
end
error("Not implemented yet")
end
const SLEWS_L1 = SVector(
+0.00, +0.39, +0.21, -0.59, -1.15,
-1.59, -1.97, -2.30, -2.56, -2.89,
-3.12, -3.24, -3.56, -3.69, -4.00,
-4.10, -4.16, -4.49, -4.71, -4.77,
-4.81, -4.87, -4.88, -4.83, -5.21,
-5.06, -5.27, -5.18, -5.24, -5.79,
-6.78, -6.24
)
function KM3io.slew(h::HitR1)
h.n > length(SLEWS_L1) && return SLEWS_L1[end]
SLEWS_L1[h.n]
end
"""
Calculates the time to reach the z-position of the `hit` along the z-axis.
"""
function timetoz(hit)
time(hit) * KM3io.Constants.C - hit.pos.z
end
abstract type AbstractMatcher end
"""
3D match criterion with road width, intended for muon signals.
Origin: B. Bakker, "Trigger studies for the Antares and KM3NeT detector.",
Master thesis, University of Amsterdam. With modifications from Jpp (M. de Jong).
"""
mutable struct Match3B <: AbstractMatcher
const roadwidth::Float64
const tmaxextra::Float64
x::Float64
y::Float64
z::Float64
d::Float64
t::Float64
dmin::Float64
dmax::Float64
d₂::Float64
const D0::Float64
const D1::Float64
const D2::Float64
const D02::Float64
const D12::Float64
const D22::Float64
const Rs2::Float64
const Rst::Float64
const Rt::Float64
const R2::Float64
function Match3B(roadwidth, tmaxextra=0.0)
tt2 = KM3io.Constants.TAN_THETA_C_WATER^2
D0 = roadwidth
D1 = roadwidth * 2.0
# the calculation of D2 in the thesis is wrong; this is the corrected version (taken from Jpp/JMatch3B)
D2 = roadwidth * 0.5 * sqrt(tt2 + 10.0 + 9.0 / tt2)
D02 = D0 * D0
D12 = D1 * D1
D22 = D2 * D2
R = roadwidth
Rs = R * KM3io.Constants.SIN_THETA_C_WATER
R2 = R * R
Rs2 = Rs * Rs
Rst = Rs * KM3io.Constants.TAN_THETA_C_WATER
Rt = R * KM3io.Constants.TAN_THETA_C_WATER
new(
roadwidth, tmaxextra, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
D0, D1, D2, D02, D12, D22, Rs2, Rst, Rt, R2,
)
end
end
Base.show(io::IO, m::Match3B) = print(io, "Match3B($(m.roadwidth), $(m.tmaxextra))")
function (m::Match3B)(hit1, hit2, time1, time2)
m.x = hit1.pos.x - hit2.pos.x
m.y = hit1.pos.y - hit2.pos.y
m.z = hit1.pos.z - hit2.pos.z
m.d₂ = m.x * m.x + m.y * m.y + m.z * m.z
m.t = abs(time1 - time2)
if (m.d₂ < m.D02)
m.dmax = √m.d₂ * KM3io.Constants.INDEX_OF_REFRACTION_WATER
else
m.dmax = √(m.d₂ - m.Rs2) + m.Rst
end
m.t > m.dmax * KM3io.Constants.C_INVERSE + m.tmaxextra && return false
if m.d₂ > m.D22
m.dmin = √(m.d₂ - m.R2) - m.Rt
elseif m.d₂ > m.D12
m.dmin = √(m.d₂ - m.D12)
else
return true
end
m.t >= m.dmin * KM3io.Constants.C_INVERSE - m.tmaxextra
end
"""
Simple Cherenkov matcher for muon signals. The muon is assumed to travel parallel
to the Z-axis.
"""
mutable struct Match1D <: AbstractMatcher
const roadwidth::Float64 # maximal road width [ns]
const tmaxextra::Float64 # maximal extra time [ns]
const tmax::Float64
x::Float64
y::Float64
z::Float64
d::Float64
t::Float64
function Match1D(roadwidth, tmaxextra=0.0)
tmax = 0.5 * roadwidth * KM3io.Constants.TAN_THETA_C_WATER * KM3io.Constants.C_INVERSE + tmaxextra
new(roadwidth, tmaxextra, tmax, 0.0, 0.0, 0.0, 0.0, 0.0)
end
end
function (m::Match1D)(hit1, hit2, time1, time2)
m.z = hit1.pos.z - hit2.pos.z
#m.t = abs(time(hit1) - time(hit2) - m.z * KM3io.Constants.C_INVERSE)
m.t = abs(time1 - time2 - m.z * KM3io.Constants.C_INVERSE)
m.t > m.tmax && return false
x = hit1.pos.x - hit2.pos.x
y = hit1.pos.y - hit2.pos.y
d = √(x^2 + y^2)
if d <= m.roadwidth/2
return m.t <= d * KM3io.Constants.TAN_THETA_C_WATER * KM3io.Constants.C_INVERSE + m.tmaxextra
elseif d <= m.roadwidth
return m.t <= (m.roadwidth - d) * KM3io.Constants.TAN_THETA_C_WATER * KM3io.Constants.C_INVERSE + m.tmaxextra
end
false
end
@inline function swap!(arr::AbstractArray, i::Int, j::Int)
arr[i], arr[j] = arr[j], arr[i]
arr
end
"""
Clique clusterizer which takes a matcher algorithm like `Match3B` as input.
"""
struct Clique{T<:AbstractMatcher}
match::T
weights::Vector{Float64}
times::Vector{Float64}
Clique(m::T) where T = new{T}(m, Float64[], Float64[])
end
"""
Applies the clique clusterization algorithm and leaves only the best matching
hits in the input array.
"""
clusterize!(hits::AbstractArray{T}, m::AbstractMatcher) where T<:AbstractSpecialHit = clusterize!(hits, Clique(m))
function clusterize!(hits::AbstractArray{T}, c::Clique) where T<:AbstractSpecialHit
N = length(hits)
N == 0 && return hits
times = c.times
resize!(c.weights, N)
resize!(c.times, N)
@inbounds for i ∈ 1:N
c.weights[i] = weight(hits[i])
times[i] = time(hits[i])
end
@inbounds for i ∈ 1:N
@inbounds for j ∈ i:N
j == i && continue
if c.match(hits[i], hits[j], times[i], times[j])
c.weights[i] += weight(hits[j])
c.weights[j] += weight(hits[i])
end
end
end
# Remove hit with the smallest weight of associated hits.
# This procedure stops when the weight of associated hits
# is equal to the maximal weight of (remaining) hits.
n = N
# @show N
@inbounds while true
j = 1
W = c.weights[j]
# @show W
@inbounds for i ∈ 2:n
if c.weights[i] < c.weights[j]
j = i
elseif c.weights[i] > W
W = c.weights[i]
end
end
# end condition
c.weights[j] == W && return resize!(hits, n)
# Swap the selected hit to end.
swap!(hits, j, n)
swap!(c.weights, j, n)
swap!(times, j, n)
# Decrease weight of associated hits for each associated hit.
@inbounds for i ∈ 1:n
c.weights[n] <= weight(hits[n]) && break
if c.match(hits[i], hits[n], times[i], times[n])
c.weights[i] -= weight(hits[n])
c.weights[n] -= weight(hits[i])
end
end
n -= 1
end
end
scanfit.jl
Base.@kwdef struct MuonScanfitParameters
tmaxlocal::Float64 = 18.0 # [ns]
roadwidth::Float64 = 200.0 # [m]
nmaxhits::Int = 50 # maximum number of hits to use
nfits::Int = 1
nprefits::Int = 10
σ::Float64 = 5.0 # [ns]
α₁::Float64 = 7.0 # grid angle of the coarse scan
α₂::Float64 = 0.5 # grid angle of the fine scan
θ::Float64 = 3.5 # opening angle of the fine-scan cone
end
"""
A container of directions with additional information about their median
angular separation.
"""
struct DirectionSet
directions::Vector{Direction{Float64}}
angular_separation::Float64
end
struct MuonScanfit
params::MuonScanfitParameters
detector::Detector
coarsedirections::DirectionSet
coincidencebuilder::L1Builder
function MuonScanfit(params::MuonScanfitParameters, detector::Detector)
coincidencebuilder = L1Builder(L1BuilderParameters(params.tmaxlocal, false))
new(params, detector, DirectionSet(fibonaccisphere(params.α₁), params.α₁), coincidencebuilder)
end
end
MuonScanfit(det::Detector) = MuonScanfit(MuonScanfitParameters(), det)
function Base.show(io::IO, m::MuonScanfit)
print(io, "$(typeof(m)) with a coarse scan of $(m.params.α₁)ᵒ and a fine scan of $(m.params.α₂)ᵒ.")
end
"""
Performs a Muon track fit for a given event.
"""
(msf::MuonScanfit)(event::DAQEvent) = msf(event.snapshot_hits)
"""
Performs a Muon track fit for a given set of hits (usually snapshot hits).
"""
function (msf::MuonScanfit)(hits::Vector{T}) where T<:KM3io.AbstractHit
rhits = msf.coincidencebuilder(HitR1, msf.detector, hits)
sort!(rhits)
unique!(h->h.dom_id, rhits)
clusterize!(rhits, Match3B(msf.params.roadwidth, msf.params.tmaxlocal))
# First round on 4π
candidates = scanfit(msf.params, rhits, msf.coarsedirections)
isempty(candidates) && return candidates
sort!(candidates, by=m->m.Q; rev=true)
# Second round on directed cones pointing towards the previous best directions
# TODO: currently disabled until all the allocations are minimised
# here, reusing a Vector{Direction} (attached to msf as buffer) might be a good idea.
# By doing so, we need `fibonaccicone!` and `fibonaccisphere!` as mutating functions
# directions = Vector{Vector{Direction{Float64}}}()
# for idx in 1:min(msf.params.nprefits, length(candidates))
# most_likely_dir = candidates[idx].dir
# push!(directions, fibonaccicone(most_likely_dir, msf.params.α₂, msf.params.θ))
# end
# directionset = DirectionSet(vcat(directions...), msf.params.α₂)
# candidates = scanfit(msf.params, rhits, directionset)
# isempty(candidates) && return candidates
# sort!(candidates, by=m->m.Q; rev=true)
candidates[1:msf.params.nfits]
end
"""
Performs the scanfit for each given direction and returns a
`Vector{MuonScanfitCandidate}` with all successful fits. The resulting vector can
be empty if none of the directions had enough hits to perform the algorithm.
"""
# function scanfit(
# params::MuonScanfitParameters,
# rhits::Vector{T},
# directionset::DirectionSet;
# nchunks = Threads.nthreads()
# ) where T<:AbstractReducedHit
# results = [ NeRCA.MuonScanfitCandidate[] for _ in 1:nchunks ]
# Threads.@sync for (irange, ichunk) in chunks(directionset.directions, nchunks)
# Threads.@spawn for i in irange
# xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
# c = directionset.directions[i]
# push!(results[ichunk], xytsolver(rhits, c, directionset.angular_separation))
# end
# end
# return vcat(results...)
# end
function scanfit(params::MuonScanfitParameters, rhits::Vector{T}, directionset::DirectionSet) where T<:AbstractReducedHit
chunk_size = max(1, length(directionset.directions) ÷ Threads.nthreads())
chunks = Iterators.partition(directionset.directions, chunk_size)
tasks = map(chunks) do chunk
Threads.@spawn begin
xytsolver = XYTSolver(params.nmaxhits, params.roadwidth, params.tmaxlocal, params.σ)
results = [xytsolver(rhits, c, directionset.angular_separation) for c in chunk]
results
end
end
mapreduce(fetch, vcat, tasks)
end
struct MuonScanfitCandidate
pos::Position{Float64}
dir::Direction{Float64}
t::Float64
Q::Float64
NDF::Int
end
"""
The quality of the fit, the larger the better, as used in e.g. Jpp.
"""
quality(χ², N, NDF) = N - 0.25 * χ² / NDF
abstract type EstimatorModel end
"""
A straight line parallel to the z-axis.
"""
struct Line1Z <: EstimatorModel
pos::Position{Float64}
t::Float64
end
distance(l::Line1Z, pos::Position) = √distancesquared(l, pos)
distancesquared(l::Line1Z, pos::Position) = (pos.x - posx(l))^2 + (pos.y - posy(l))^2
posx(l::Line1Z) = l.pos.x
posy(l::Line1Z) = l.pos.y
posz(l::Line1Z) = l.pos.z
posz(l::Line1Z, pos::Position) = l.pos.z - distance(l, pos) / KM3io.Constants.TAN_THETA_C_WATER
"""
Calculate the Cherenkov arrival time for a given position.
"""
function Base.time(lz::Line1Z, pos::Position)
v = pos - lz.pos
R = √(v.x*v.x + v.y*v.y)
lz.t + (v.z + R * KM3io.Constants.KAPPA_WATER) * KM3io.Constants.C_INVERSE
end
struct SingularSVDException <: Exception
message::String
end
mutable struct Line1ZEstimator
model::Line1Z
V::MMatrix{3, 3, Float64, 9}
NUMBER_OF_PARAMETERS::Int
MINIMAL_SVD_WEIGHT::Float64
function Line1ZEstimator(model::Line1Z)
V = zero(MMatrix{3, 3, Float64, 9})
new(model, V, 3, 1.0e-4)
end
end
posx(est::Line1ZEstimator) = posx(est.model)
posy(est::Line1ZEstimator) = posy(est.model)
posz(est::Line1ZEstimator) = posz(est.model)
function reset!(est::Line1ZEstimator)
est.V .= 0.0
est
end
# TODO: generalise for "data" using traits
function estimate!(est::Line1ZEstimator, hits)
N = length(hits)
N < est.NUMBER_OF_PARAMETERS && error("Not enough data points, $N points, but we require at least $(est.NUMBER_OF_PARAMETERS)")
W = 1.0 / N
pos = sum(h.pos for h ∈ hits) * W
t = 0.0
lz = Line1Z(pos, t)
t₀ = sum(time(h) for h ∈ hits) * W * KM3io.Constants.C
reset!(est)
yvec = zeros(MVector{3})
hit₀ = first(hits)
xi = hit₀.pos.x - posx(lz)
yi = hit₀.pos.y - posy(lz)
ti = (time(hit₀) * KM3io.Constants.C - t₀ - hit₀.pos.z + posz(lz)) / KM3io.Constants.KAPPA_WATER
# starting from the second hit and including the first in the last iteration
@inbounds for idx ∈ 2:N+1
hit = idx > N ? first(hits) : hits[idx]
xj = hit.pos.x - posx(lz)
yj = hit.pos.y - posy(lz)
tj = (time(hit) * KM3io.Constants.C - t₀ - hit.pos.z + posz(lz)) / KM3io.Constants.KAPPA_WATER
dx = xj - xi
dy = yj - yi
dt = ti - tj # opposite sign
y = (xj + xi) * dx + (yj + yi) * dy + (tj + ti) * dt
dx *= 2
dy *= 2
dt *= 2
est.V[1, 1] += dx * dx
est.V[1, 2] += dx * dy
est.V[1, 3] += dx * dt
est.V[2, 2] += dy * dy
est.V[2, 3] += dy * dt
est.V[3, 3] += dt * dt
yvec[1] += dx * y
yvec[2] += dy * y
yvec[3] += dt * y
xi = xj
yi = yj
ti = tj
end
t₀ *= KM3io.Constants.C_INVERSE
@inbounds begin
est.V[2, 1] = est.V[1, 2]
est.V[3, 1] = est.V[1, 3]
est.V[3, 2] = est.V[2, 3]
end
# Hermitian is needed for typestability!
wvec, evecs = invert2!(Hermitian(est.V), est.MINIMAL_SVD_WEIGHT)
yvec2 = (evecs' * yvec)
yvec2 .*= wvec
mul!(yvec, evecs, yvec2)
#yvec = evecs * (diagm(wvec) * (evecs' * yvec))
@inbounds begin
est.model = Line1Z(
Position(
pos.x + yvec[1],
pos.y + yvec[2],
posz(lz)
),
yvec[3] * KM3io.Constants.KAPPA_WATER * KM3io.Constants.C_INVERSE + t₀
)
end
est
end
"""
Invert matrix in-place with a given precision (clamps eigenvalues to 0 below that).
"""
function invert!(V, precision)
F = svd(V)
abs(F.S[2]) < precision * abs(F.S[1]) && throw(SingularSVDException("$(F.S)"))
w1 = abs(F.S[1])
w2 = abs(F.S[2])
w3 = abs(F.S[3])
w = max(w1, w2, w3) * precision
@inbounds for idx in eachindex(F.S)
F.S[idx] = abs(F.S[idx]) >= w ? 1.0 / F.S[idx] : 0.0
end
mul!(V, F.U, diagm(F.S) * F.Vt)
end
@inline function invert2!(V, precision)
evals, evecs = eigen(V)
abs(evals[2]) < precision * abs(evals[1]) && throw(SingularSVDException("$evals"))
w = maximum(abs, evals) * precision
wvec = ifelse.(abs.(evals) .>= w, inv.(evals), zero(float(eltype(evals))))
return wvec, evecs
#mul!(V, evecs, diagm(wvec) * evecs')
end
struct Variance <: FieldVector{4, Float64}
x::Float64
y::Float64
v::Float64
w::Float64
end
struct CovMatrix
M::Matrix{Float64}
V::Vector{Variance}
σ::Float64
CovMatrix(N::Int, σ::Float64) = new(MMatrix{N, N, Float64, N*N}(undef), MVector{N, Variance}(undef), σ)
end
# TODO: generalise hits parameter
function update!(C::CovMatrix, pos::Position, hits, α::Float64)
N = length(hits)
ta = deg2rad(α)
ct = cos(ta)
st = sin(ta)
for (idx, hit) ∈ enumerate(hits)
dx, dy, dz = hit.pos - pos
R = √(dx^2 + dy^2)
x = y = ta * KM3io.Constants.KAPPA_WATER * KM3io.Constants.C_INVERSE
v = w = ta * KM3io.Constants.C_INVERSE
if R != 0.0
x *= dx / R
y *= dy / R
end
x *= (dz * ct - dx * st)
y *= (dz * ct - dy * st)
v *= -(dx * ct + dz * st)
w *= -(dy * ct + dz * st)
C.V[idx] = Variance(x, y, v, w)
end
@inbounds for i ∈ 1:N
@inbounds for j ∈ 1:i
C.M[i, j] = C.V[i] ⋅ C.V[j]
C.M[j, i] = C.M[i, j]
end
C.M[i, i] = C.V[i] ⋅ C.V[i] + C.σ^2
end
C
end
# TODO: generalise hits parameter
timeresvec(lz::Line1Z, hits) = [time(hit) - time(lz, hit.pos) for hit ∈ hits]
function timeresvec!(v::AbstractArray{Float64}, lz::Line1Z, hits)
for (idx, hit) ∈ enumerate(hits)
v[idx] = time(hit) - time(lz, hit.pos)
end
v
end
"""
A task worker which solves for x, y and t for a given set of hits and a direction.
"""
struct XYTSolver
hits_buffer::Vector{HitR1}
covmatrix::CovMatrix
timeresvec::Vector{Float64}
nmaxhits::Int
clique::Clique{Match1D}
est::Line1ZEstimator
function XYTSolver(nmaxhits::Int, roadwidth::Float64, tmaxlocal::Float64, σ::Float64)
new(Vector{HitR1}(), CovMatrix(nmaxhits, σ), Vector{Float64}(), nmaxhits, Clique(Match1D(roadwidth, tmaxlocal)),
Line1ZEstimator(Line1Z(Position(0, 0, 0), 0))
)
end
end
function (s::XYTSolver)(hits::Vector{T}, dir::Direction{Float64}, α::Float64) where T<:AbstractReducedHit
χ² = Inf
R = rotator(dir)
n_initial_hits = length(hits)
resize!(s.hits_buffer, n_initial_hits)
for (idx, hit) ∈ enumerate(hits) # rotate hits
s.hits_buffer[idx] = @set hit.pos = R * hit.pos
end
if n_initial_hits > s.nmaxhits
sort!(s.hits_buffer; by=timetoz, alg=PartialQuickSort(s.nmaxhits))
resize!(s.hits_buffer, s.nmaxhits)
end
clusterize!(s.hits_buffer, s.clique)
hits = s.hits_buffer # just for convenience
n_final_hits = length(hits)
n_final_hits <= s.est.NUMBER_OF_PARAMETERS && return MuonScanfitCandidate(Position(0, 0, 0), dir, 0, -Inf, 0)
NDF = n_final_hits - s.est.NUMBER_OF_PARAMETERS
N = hitcount(hits)
sort!(hits)
try
estimate!(s.est, hits)
catch ex
# isa(ex, SingularSVDException) && @warn "Singular SVD"
return MuonScanfitCandidate(Position(0, 0, 0), dir, 0, -Inf, 0)
end
# TODO: consider creating a "pos()" getter for everything
update!(s.covmatrix, s.est.model.pos, hits, α)
# TODO: this is really ugly... make update!() return the view itself maybe?
V = view(s.covmatrix.M, 1:n_final_hits, 1:n_final_hits)
n_final_hits > length(s.timeresvec) && resize!(s.timeresvec, n_final_hits)
# TODO: better name for this function
timeresvec!(s.timeresvec, s.est.model, hits)
#V⁻¹ = inv(V)
Y = view(s.timeresvec, 1:n_final_hits) # only take the relevant part of the buffer
χ² = dot(Y, V \ Y) # V⁻¹ * Y == V \ Y
fit_pos = R \ s.est.model.pos
MuonScanfitCandidate(fit_pos, dir, s.est.model.t, quality(χ², N, NDF), NDF)
end