Extracting each iteration result in KissABC.jl

Hello everyone,

I was wondering whether it is possible to extract the results of each iteration of an `smc()` run in the KissABC package. I read the documentation but didn't find much about this.

Thank you :slight_smile:

I don't know much about KissABC, but after looking through the code a bit, I don't think this is possible with the current setup. Here is a link to the source code; it is plain Julia, so it should be readable: KissABC.jl/src/smc.jl at master · francescoalemanno/KissABC.jl · GitHub

There are still a few possible solutions:

  1. Add an issue on the GitHub repo where you request this new feature.
  2. Clone the repository and modify the package code. I think what you want is simply to create an empty array at the start of the function and then `push!(results, (P = P, C = Xs, ϵ = ϵ))` in every iteration of the while loop (building `P` from the current particles each time, since it is otherwise only computed after the loop).
  3. Make a PR to the package that adds the functionality (either as an argument or as a new function reusing the existing logic) without changing the behavior or performance of the original function. This is the ultimate solution.

Yes, I already did that :) Should I make a PR and share this code?


"""
    smc_edited(prior, cost, param_name; kwargs...)

Variant of KissABC's `smc` that additionally logs and saves the alive particle
population after every iteration (via `save_param!`), then returns the final
result in the same form as the original.

# Arguments
- `prior::Tprior`: a `Distribution`; initial particles are drawn with
  `rand(rng, prior)` and scored with `logpdf(prior, ...)`.
- `cost`: called on a parameter value (after `push_p`); its result is compared
  against the threshold `ϵ`, so it must be orderable.
- `param_name::Vector{String}`: column names for the `DataFrame` passed to
  `save_param!` at each iteration.

# Keywords
- `rng`: random number generator used for the initial draws and MCMC proposals.
- `nparticles`: population size; must be at least
  `ceil(Int, 3 * length(prior) / min(alpha, min_r_ess))`.
- `alpha`: quantile of the alive costs used as the next threshold `ϵ`.
- `mcmc_retrys`: extra MCMC sweeps allowed per iteration (`1 + mcmc_retrys` in total).
- `mcmc_tol`: sweeps stop early once accepted moves reach `mcmc_tol * nparticles`;
  the whole run stops if an iteration ends with fewer accepted moves than that.
- `epstol`: the run stops once `ϵ <= epstol`.
- `r_epstol`: the run stops once the relative change of `ϵ` between two
  iterations falls below this value.
- `min_r_ess`: the population is resampled when `alpha * ESS <= nparticles * min_r_ess`.
- `max_stretch`: scale of the random stretch factor in the MCMC proposal.
- `verbose`: print `(iteration, ϵ, ESS)` each iteration. The `@info` messages
  and `save_param!` calls happen regardless of this flag.
- `parallel`: evaluate costs using threads (`Threads.@spawn` / `@cthreads`).

# Returns
A named tuple `(P = P, C = Xs, ϵ = ϵ)`:
- `P`: one `Particles` per parameter built from the alive particles only
  (a single `Particles` when there is one parameter).
- `C`: the costs of all `nparticles` particles (not filtered by `alive`).
- `ϵ`: the final threshold.

# Notes
The code keeps no per-particle weight array. Alive particles are treated
uniformly, and resampling repeats the alive indices evenly. The `W` computed
in the MCMC step is a proposal displacement, not an importance weight.
"""
function smc_edited(
    prior::Tprior,
    cost,
    param_name::Vector{String};
    rng::AbstractRNG = Random.GLOBAL_RNG,
    nparticles::Int = 100,
    alpha = 0.95,
    mcmc_retrys::Int = 0,
    mcmc_tol = 0.015,
    epstol = 0.0,
    r_epstol = (1 - alpha)^1.5 / 50,
    min_r_ess = alpha^2,
    max_stretch = 2.0,
    verbose::Bool = false,
    parallel::Bool = false,
) where {Tprior<:Distribution}
    # Validate tuning parameters before doing any work.
    min_r_ess > 0 || error("min_r_ess must be > 0.")
    mcmc_retrys >= 0 || error("mcmc_retrys must be >= 0.")
    alpha > 0 || error("alpha must be > 0.")
    r_epstol >= 0 || error("r_epstol must be >= 0")
    mcmc_tol >= 0 || error("mcmc_tol must be >= 0")
    max_stretch > 1 || error("max_stretch must be > 1")
    # Number of parameters (dimension of the prior).
    Np=length(prior)
    # Lower bound on the population size, scaled by the number of parameters.
    min_nparticles = ceil(
        Int,
        3 * Np / (min(alpha, min_r_ess)),
    )
    nparticles >= min_nparticles || error("nparticles must be >= $min_nparticles.")
    # Initial population: draws from the prior (θs), their costs (Xs)
    # and their prior log-densities (lπs).
    θs = [op(float, Particle(rand(rng, prior))) for i = 1:nparticles]
    Xs = parallel ?
        fetch.([
        Threads.@spawn cost(push_p(prior, θs[$i].x)) for i = 1:nparticles]) :
        [cost(push_p(prior, θs[i].x)) for i = 1:nparticles]

    lπs = [logpdf(prior, push_p(prior, θs[i].x)) for i = 1:nparticles]
    α = alpha
    ϵ = Inf
    # alive[i] marks particles whose cost is within the current threshold.
    alive = fill(true,nparticles)
    iteration = 0
    # Step 1 - adaptive threshold
    while true
        iteration += 1
        # Keep the previous threshold for the convergence test after Step 3.
        ϵv = ϵ
        ϵ = quantile(Xs[alive],α)
        # If the new threshold equals the smallest alive cost, a strict `<`
        # would mark every particle dead, so `<=` is used and `flag` records it.
        flag=false
        if ϵ > minimum(Xs[alive])
            alive = Xs .< ϵ
        else
            alive = Xs .<= ϵ
            flag=true
        end
        # ESS here is simply the number of alive particles.
        ESS = sum(alive)
        verbose && @show iteration, ϵ, ESS
        # Step 2 - Resampling
        # When too few particles survive, rebuild the population by cycling
        # through the alive indices until nparticles slots are filled; every
        # particle is then alive again.
        if α*ESS <= nparticles * min_r_ess
            idxalive = (1:nparticles)[alive]
            idx=repeat(idxalive,ceil(Int,nparticles/length(idxalive)))[1:nparticles]
            θs = θs[idx]
            Xs = Xs[idx]
            lπs = lπs[idx]
            ESS = nparticles
            alive .= true
        end

        # Step 3 - MCMC
        # An atomic counter is used when the update loop may run on threads.
        accepted = parallel ? Threads.Atomic{Int}(0) : 0
        retry_N = 1 + mcmc_retrys

        for r = 1:retry_N
                # All proposals (and their RNG draws) are generated serially here,
                # presumably so the possibly-threaded loop below never touches rng.
                new_p = map(1:nparticles) do i
                    a = b = i
                    alive[i] || return (nothing,nothing,nothing)
                    # Pick two distinct partner particles a, b, both different from i.
                    while a==i; a = rand(rng,1:nparticles); end
                    while b==i || b==a; b = rand(rng,1:nparticles); end
                    # Proposal displacement (a move, not a weight): the difference of
                    # the two partners scaled by the random factor
                    # max_stretch*randn/sqrt(Np).
                    W = op(*, op(-, θs[b], θs[a]), max_stretch*randn(rng)/sqrt(Np))
                    # (log-uniform for the acceptance test, proposed particle,
                    # log-correction term which is always 0.0 here)
                    (log(rand(rng)), op(+, θs[i], W), 0.0)
                end
                @cthreads parallel for i = 1:nparticles # non-ideal parallelism
                    alive[i] || continue
                    lprob, θp, logcorr = new_p[i]
                    isnothing(lprob) && continue
                    lπp = logpdf(prior, push_p(prior, θp.x))
                    # Skip proposals with log-density -Inf (zero prior density).
                    lπp < 0 && (!isfinite(lπp)) && continue
                    # Metropolis-Hastings log acceptance ratio on the prior, capped at 0.
                    lM = min(lπp - lπs[i] + logcorr, 0.0)
                    if lprob < lM 
                        Xp = cost(push_p(prior, θp.x))
                        
                        # Reject moves that leave the current threshold, using the
                        # same strict/non-strict comparison chosen in Step 1.
                        if flag
                            Xp > ϵ && continue
                        else
                            Xp >= ϵ && continue
                        end
                        θs[i] = θp
                        Xs[i] = Xp
                        lπs[i] = lπp
                        if parallel 
                            Threads.atomic_add!(accepted, 1)
                        else
                            accepted += 1
                        end
                    end
                end
            # Enough moves accepted: no more retries this iteration.
            # (`accepted[]` works for both the Atomic and the plain Int.)
            accepted[] >= mcmc_tol * nparticles && break
        end
        # Stop when ϵ has stabilised (relative change below r_epstol), reached
        # epstol, or the MCMC step accepted too few moves.
        if 2*abs(ϵv - ϵ) < r_epstol * (abs(ϵv)+abs(ϵ)) ||
           ϵ <= epstol ||
           accepted[] < mcmc_tol * nparticles
           break
        end

        # Per-iteration extraction added by this variant: build one `Particles`
        # per parameter from the alive particles, then log and save them.
        As = [push_p(prior, θs[i].x) for i = 1:nparticles][alive]

        l = length(prior)
        Q = map(x -> Particles(x), getindex.(As, i) for i = 1:l)
        length(Q)==1 && (Q=first(Q))
    
        @info "Saving Population $(iteration)"
        # NOTE(review): Q holds only alive particles, but Xs holds the costs of
        # all nparticles particles; confirm save_param! expects that.
        save_param!(DataFrame(Array(Q), param_name), Xs, iteration)
        @info "Current Particles info - $(Q)"
        @info "Done"
        
    end
    # Final population: only alive particles, one `Particles` per parameter
    # (unwrapped when there is a single parameter).
    θs = [push_p(prior, θs[i].x) for i = 1:nparticles][alive]

    l = length(prior)
    P = map(x -> Particles(x), getindex.(θs, i) for i = 1:l)
    length(P)==1 && (P=first(P))

    @info "Saving Population $(iteration) - Final"
    # NOTE(review): same as above, Xs is not filtered by `alive` here.
    save_param!(DataFrame(Array(P), param_name), Xs, iteration)
    @info "Final Particles info after $(iteration) - $(P)"
    @info "Process Finished"

    (P = P, C = Xs, ϵ = ϵ)
end

1 Like

Awesome,

Open an issue on the repository proposing the feature, and add a link to this discussion. Then the maintainers can give their input.

I think the functionality could benefit other users.

1 Like

Is there a way to obtain the weights of the accepted particles? I don't think they are provided automatically in the final results. Also, I'm not sure where exactly the weights of accepted particles are calculated. My best guess is `W = op(*, op(-, θs[b], θs[a]), max_stretch*randn(rng)/sqrt(Np))`. Is that the case?