Fastest way to fetch an index

If the values of the vector data.T are integers or β€œcountable” (in which case you first have to find a transformation that makes them integers), you can build such an index once and for all, without having to search for the last value equal to your key every time.

l, h = Int.(extrema(data.T))
idx  = zeros(Int64, h - l + 1)      # one slot per possible integer value in l:h
for (i, v) in enumerate(data.T)
    idx[Int(v) - l + 1] = i         # store the index of the last occurrence of each value
end
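A lookup then reduces to a single array access (a minimal sketch; key is just an illustrative name for a value that actually occurs in data.T):

i = idx[Int(key) - l + 1]   # index of the last element of data.T equal to key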

There are, but the length of those mapping vectors would be huge. Still, maybe that is the way to go; I'll try it tomorrow.

Certainly not critical to the performance of the script, but note that there is a function in Base that does exactly this: diff(grid) == grid[2:end] .- grid[1:end-1]

βˆ‚t  = [(grid[2:end] .- grid[1:end-1])...,0.0]

As for the performance, it is better to avoid the splat and instead use push!() to append the element.
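For example, a minimal sketch of the same computation without splatting, assuming grid holds floating-point values as in the snippet above:

βˆ‚t = diff(grid)    # equivalent to grid[2:end] .- grid[1:end-1]
push!(βˆ‚t, 0.0)     # append the trailing zero without splatting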

Just to evaluate the asymptotic improvement.
We start from this:

julia>     @btime Ξ›(data, table)
  464.245 ms (57 allocations: 1.08 MiB)

and assuming we had a β€œmagic”, essentially free, lambda function (which is where most of the time is spent)

                #Ξ»β‚š          = Ξ»(table, data.age[i] + grid[j], data.year[i] + grid[j], data.sex[i])
                Ξ»β‚š           = table.values[1,1,1]

we could at best approach:

julia>     @btime Ξ›(data, table)
  89.540 ms (57 allocations: 1.08 MiB)

As a way to bypass the lambda function, I propose building an idx_ages index that can be used inside the nested loops, which brings the time from 430 ms down to 290 ms.
Something similar (with a little more patience) could be done for idx_years (perhaps a further decrease of another 100 ms can be obtained).

IDX age

function Ξ›(data, table, prec=1.0)
        # Initialize vectors: 
        grid         = unique(sort([(1:prec:(maximum(data.T)+1)); data.T; [maximum(data.T)+1]]))
        num_excess   = zero(grid)
        num_pop      = zero(grid)
        num_variance = zero(grid)
        den          = zero(grid)

        # Build the age index once: for an integer age a within the table range,
        # idxa[a] is the interval index s with table.ages[s] <= a < table.ages[s+1];
        # ages beyond the last breakpoint keep the default lastindex(table.ages).
        d=Int(maximum(data.age)+maximum(grid))
        idxa=fill(lastindex(table.ages),d)
        s=1
        for i in 1:Int(table.ages[end])-1
            if i >= table.ages[s+1]
                s+=1
            end 
            idxa[i]=s
        end

        #table.years.-table.years[1]
        #data.year.-table.years[1]
                # Loop over individuals
        for i in 1:length(data.age)
            Tα΅’ = searchsortedlast(grid, data.T[i]) # index of the time of event (death or censored) in the grid
            wβ‚š = 1.0
            sΞ›β‚š = 0.0
            for j in 1:Tα΅’
                #Ξ»β‚š           = Ξ»(table, data.age[i] + grid[j], data.year[i] + grid[j], data.sex[i])
                a= idxa[Int(data.age[i] + grid[j])]
                y=searchsortedlast(table.years, data.year[i] + grid[j])
                g = data.sex[i]==:male ? 1 : 2
                Ξ»β‚š           = table.values[a,y,g]
                Ξ›β‚š           = Ξ»β‚š * (grid[j+1]-grid[j]) # Ξ»β‚š * βˆ‚t 
                sΞ›β‚š         += Ξ›β‚š
                wβ‚š           = exp(sΞ›β‚š)
                num_pop[j] += Ξ›β‚š * wβ‚š
                den[j]     += wβ‚š
            end
            num_excess[Tα΅’]   += wβ‚š * data.status[i]
            num_variance[Tα΅’] += wβ‚š^2 * data.status[i]
        end
        βˆ‚Ξ›β‚‘ = (num_excess - num_pop) ./ den
        βˆ‚Οƒβ‚‘ = num_variance ./ den.^2
        return grid, βˆ‚Ξ›β‚‘, βˆ‚Οƒβ‚‘
    end

@rocco_sprmnt21 And the precomputation could be done when the table object is constructed (the dump I used to build the MWE did not include the structures, but I do already have one) and thus be completely offloaded, since the Ξ› function will likely run many times with the same rate table.
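A minimal sketch of what that offloading could look like, assuming the table is a NamedTuple as in the MWE; with_age_index, maxage and the age_idx field are illustrative names, not part of the original code:

function with_age_index(table, maxage)
    # maxage: an upper bound on data.age[i] + grid[j] over the whole data set
    idxa = fill(lastindex(table.ages), Int(maxage))
    s = 1
    for i in 1:min(Int(maxage), Int(table.ages[end]) - 1)
        if i >= table.ages[s+1]
            s += 1
        end
        idxa[i] = s
    end
    return merge(table, (age_idx = idxa,))
end

Ξ› could then read table.age_idx directly instead of rebuilding idxa on every call.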

This is a very good idea. Isn't the length of the resulting index vector very large?

Why do you think more care is needed for the years objects? Won't the same code work directly?

The idea can also be applied to the year vectors, but given the range of values they take, I think it is preferable to translate them so that all the values fall between 1 and max(year) - min(year) + 1.
Another thing to do in preparation, which might help, is to work with integers for the data that are integers (since they then have to be used as indices).
As soon as I have time I'll try to draft a more complete script, in case it's not clear what I'm trying to say.
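For instance, a minimal sketch of the translation I mean (year0 is just an illustrative name; it assumes table.years and data.year are already integers):

year0 = table.years[1] - 1             # offset so that the shifted years start at 1
years_shifted = table.years .- year0   # values now in 1:(maximum(year) - minimum(year) + 1)
data_years_shifted = data.year .- year0

The shifted values can then be used directly as indices into an idx_years vector built in the same way as idxa.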

Integerify
table = (
        ages   = Int.(trunc.(table.ages .* 365.241)),
        years  = Int.(trunc.(table.years .* 365.241)),
        sexes  = table.sexes,
        values = table.values
    )
data = (
        T      = Int.(data.T),
        age    = Int.(data.age),
        sex    = data.sex,
        status = data.status,
        year   = Int.(data.year)
    )


    function Ξ›(data, table, prec=1)
        # Initialize vectors: 
        grid         = unique(sort([(1:prec:(maximum(data.T)+1)); data.T; [maximum(data.T)+1]]))
        num_excess   = zeros(grid[end])
        num_pop      = zeros(grid[end])
        num_variance = zeros(grid[end])
        den          = zeros(grid[end])

        d=Int(maximum(data.age)+maximum(grid))
        idxa=fill(lastindex(table.ages),d)
        s=1
        for i in  1:table.ages[end]-1
            if i >= table.ages[s+1]
                s+=1
            end 
            idxa[i]=s
        end

        #table.years.-table.years[1]
        #data.year.-table.years[1]
                # Loop over individuals
        for i in 1:length(data.age)
            Tα΅’ = searchsortedlast(grid, data.T[i]) # index of the time of event (death or censored) in the grid
            wβ‚š = 1.0
            sΞ›β‚š = 0.0
            for j in 1:Tα΅’
                #Ξ»β‚š           = Ξ»(table, data.age[i] + grid[j], data.year[i] + grid[j], data.sex[i])
                a= idxa[data.age[i] + grid[j]]
                y=searchsortedlast(table.years, data.year[i] + grid[j])
                g = data.sex[i]==:male ? 1 : 2
                Ξ»β‚š           = table.values[a,y,g]
                Ξ›β‚š           = Ξ»β‚š * (grid[j+1]-grid[j]) # Ξ»β‚š * βˆ‚t 
                sΞ›β‚š         += Ξ›β‚š
                wβ‚š           = exp(sΞ›β‚š)
                num_pop[j] += Ξ›β‚š * wβ‚š
                den[j]     += wβ‚š
            end
            num_excess[Tα΅’]   += wβ‚š * data.status[i]
            num_variance[Tα΅’] += wβ‚š^2 * data.status[i]
        end
        βˆ‚Ξ›β‚‘ = (num_excess - num_pop) ./ den
        βˆ‚Οƒβ‚‘ = num_variance ./ den.^2
        return grid, βˆ‚Ξ›β‚‘, βˆ‚Οƒβ‚‘
    end 
julia>     using BenchmarkTools

julia>     @btime Ξ›(data, table)
  191.024 ms (59 allocations: 1.40 MiB)
([1, 2, 3, 4, 5, 6, 7, 8, 9, 10  …  8140, 8141, 8142, 8143, 8144, 8145, 8146, 8147, 8148, 8149], [0.002894868611732328, 0.0018966845282910414, 0.003079845040555203, 0.003427879530253135, 0.002763112499157984, 0.002771085990264049, 0.0015876394324403375, 0.004837336707644467, 0.003830746874088555, 0.0024714042541990793  …  -0.00045476929384445137, -0.00045480172457826204, -0.00045483414608312107, -0.000542010162781309, -0.000542010162781309, -0.000542010162781309, -0.000542010162781309, -0.000542010162781309, -0.000542010162781309, NaN], [5.049982192832294e-7, 3.388282760556044e-7, 5.387125920341493e-7, 5.990895522239727e-7, 4.886001588297959e-7, 4.912342704415929e-7, 2.9113476790888077e-7, 8.468796858477259e-7, 6.779046685495679e-7, 4.4688733619822384e-7  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, NaN])

I love this kind of plot!!!

julia>        lineplot(grid,cumprod(1 .- βˆ‚Ξ›β‚‘))
[UnicodePlots lineplot of cumprod(1 .- βˆ‚Ξ›β‚‘) versus grid; y-axis from 0.3 to 1.2, x-axis from 0 to about 9 000]

Thanks a lot :slight_smile:


Edit: @rocco_sprmnt21 based on your proposal I drafted this struct to avoid cluttering my main function:

struct FastIndexFetchingVector
    vec::Vector{Int64}
    idx::Vector{Int64}
    m::Int64
    function FastIndexFetchingVector(vec)
        @assert eltype(vec)<:Integer
        @assert issorted(vec)
        m, M = vec[1]-1, vec[end] # offset (one below the minimum) and maximum value
        idx = zeros(Int, M-m)
        s = 1
        for i in 1:length(idx)
            if m + i >= vec[s+1]  # advance to the next interval of vec
                s+=1
            end 
            idx[i]=s              # idx[i] == searchsortedlast(vec, m + i)
        end
        return new(vec,idx,m)
    end
end
# extend Base.searchsortedlast so that plain vectors keep using the Base method
function Base.searchsortedlast(v::FastIndexFetchingVector, x)
    i = clamp(Int(trunc(x-v.m)),1,length(v.idx))   # O(1) lookup instead of a binary search
    return v.idx[i]
end
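For example, inside the inner loop of Ξ› (assuming table.ages is a sorted integer vector, as after the β€œIntegerify” step above):

ages_idx = FastIndexFetchingVector(table.ages)          # built once, before the loops
a = searchsortedlast(ages_idx, data.age[i] + grid[j])   # O(1) lookup, same result as the binary search for in-range values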

Final runtime went from 0.6 s on my laptop down to 0.2 s, so about a 3Γ— improvement :slight_smile: Thank you very much :slight_smile:

Now my flame graph looks like this:

You can see that the Ξ» function went from about 2/3 of the runtime to about 1/5 :slight_smile: The computation is now dominated by the single exp() in the inner loop… for which I cannot do much more :frowning:

Thanks all for the wonderful ideas!

How accurate do you need exp to be?

For example, you might approximate the exponential with a truncated power series \exp(x) \approx 1 + x \left( 1 + \frac{x}{2} \left( 1 + \frac{x}{3} \right) \right), which is reasonably accurate for small x (see below).

If I understand your MWE correctly, you are computing \exp( \sum_{i=1}^n x_i) for an incrementally-increasing n, so the arguments of the exponential grow (making the above approximation worse). However, since \exp( \sum_{i=1}^n x_i) = \prod_{i=1}^n \exp(x_i) (where each x_i is small), you can evaluate these products recursively r_n = \begin{cases} 1 & \text{if } n = 0 \\ r_{n-1} \exp(x_n) & \text{otherwise}\end{cases}
(Note that you only need to keep the previous r_{n-1}, not all of them.)

Overall, you might have something like

myexp(x) = 1 + x * (1 + x / 2 * (1 + x / 3))

xs = [0.001, 0.002, 0.003]
r = 1.0
for (n, x) in enumerate(xs)
    r = r * myexp(x)
    # do something with r (r = exp(sum(xs[1:n])))
    @show r exp(sum(xs[1:n]))
end

# Output is
# r = 1.0010005001666666
# exp(sum(xs[1:n])) = 1.0010005001667084

# r = 1.0030045045026676
# exp(sum(xs[1:n])) = 1.003004504503377

# r = 1.0060180360499662
# exp(sum(xs[1:n])) = 1.006018036054065

On my computer, I am getting:

julia> @btime myexp(x) setup=x=rand()
  3.773 ns (0 allocations: 0 bytes)
1.5760759809499558

julia> @btime exp(x) setup=x=rand()
  7.313 ns (0 allocations: 0 bytes)

So you could achieve further speed-up.


@barucden Indeed this is a possibility, but in this case it's not really an option:

Note that this is suboptimal because the compiler will do a floating-point division for x / 3, since this is not exactly the same as x * (1/3) in floating-point arithmetic. You can use myexp(x) = @fastmath 1 + x * (1 + x / 2 * (1 + x / 3)) to tell it that it is allowed to convert x / 3 into x * 0.333... and also to use fused multiply-add instructions, even though this slightly changes the roundoff errors. On my computer this is about a 20% speedup.
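For reference, a minimal sketch of the two variants side by side (myexp_fast is just an illustrative name; timings will of course vary by machine):

using BenchmarkTools

myexp(x)      = 1 + x * (1 + x / 2 * (1 + x / 3))            # plain version: real divisions
myexp_fast(x) = @fastmath 1 + x * (1 + x / 2 * (1 + x / 3))  # allows x * (1/3) and fma, slightly different rounding

@btime myexp(x)      setup = (x = rand())
@btime myexp_fast(x) setup = (x = rand())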
