Selecting a specific cycle to perform operation during a callback

Hi all,

Below is a simple example in which I apply a Kalman filter over 3 periodic cycles of a model. I generate data for all 3 cycles and run the filter with a forced (fixed) time step, using the integrator interface from inside a callback.

My question is: if I wanted to adapt this code and callback so that the Kalman filter is only applied on, say, the cycle (1.0, 2.0), how would I change the code below?

My main motivation is models that exhibit transient behaviour: I need to run them to steady state before extracting one cycle, and it is only on that cycle that I want to perform the KF (assuming I also have data for the corresponding cycle).

I know how to extract 1 cycle for everything outside of the callback :slight_smile:

Thanks in advance :slight_smile:

## Implementation using the callback to implement the UKF
using Distributions, Plots, DifferentialEquations, LinearAlgebra, OffsetArrays
# Number of observation times on the uniform grid covering [0, 3].
N = 30001

# Observation grid: N equally spaced points from 0.0 to 3.0.
tspan = range(0.0, 3.0; length = N)

# Resistances of the known (reference) system.
R1_ = 1.0
R2_ = 1.0
R3_ = 2.0

# Known system solution: every signal is a scaled copy of sin(2πt).
P1_(t) = sin(2π * t)            # driving pressure
Pb_ref(t) = (2 / 5) * sin(2π * t)
Q1_ref(t) = (3 / 5) * sin(2π * t)
Q2_ref(t) = (2 / 5) * sin(2π * t)
Q3_ref(t) = (1 / 5) * sin(2π * t)

"""
    splitter!(du, u, p, t)

In-place right-hand side of the three-resistor splitter model.

`p` holds the resistances `(R1, R2, R3)`. The derivatives written to `du` are,
in order, those of `Pb`, `Q1`, `Q2` and `Q3`. They depend only on `p` and `t`;
the state `u` is accepted to satisfy the ODE-function signature but is never
read. Returns `nothing`.
"""
function splitter!(du, u, p, t)
    R1, R2, R3 = p
    # All four components share the same forcing and denominator, so they are
    # evaluated once instead of once per component.
    denom = R1 * R3 + R1 * R2 + R2 * R3
    drive = 2π * cos(2π * t) / denom
    du[1] = drive * R2 * R3      # Pb
    du[2] = drive * (R2 + R3)    # Q1
    du[3] = drive * R3           # Q2
    du[4] = drive * R2           # Q3
    return nothing
end

# Baseline (unfiltered) solve: start from a zero state and integrate the
# splitter model over (0, 3) with the resistances 1, 1, 2 and a fixed step of
# 1e-4 (adaptive stepping is switched off).
p = [1.0, 1.0, 2.0]
u0 = fill(0.0, 4)
tspan = (0.0, 3.0)

# Uniform time grid with N points spanning the same interval.
t = LinRange(0.0, 3.0, N)

prob = ODEProblem(splitter!, u0, tspan, p)
sol = solve(prob, Tsit5(); adaptive = false, dt = 0.0001, reltol = 1e-12, abstol = 1e-12)

### UKF implementation ###


# Create the observations with errors.
#
# The observed signals are Pb, Q2 and Q3 sampled on the N-point grid. Each
# sample is perturbed multiplicatively: Nobs = Obs * (1 + ε), with one
# independent draw ε from Normal(0.0, 0.05) per sample.
tspan = range(0.0, 3.0, N)
Obs = [Pb_ref.(tspan) Q2_ref.(tspan) Q3_ref.(tspan)]   # N × 3
# Alternative: take the clean observations from the numerical solution instead.
#Obs = [sol[1,:] sol[3,:] sol[4,:]]

ϵ = Normal(0.0, 0.05)
# Draw the whole N × 3 noise matrix in one call rather than allocating a
# one-element vector for every single sample.
noise = rand(ϵ, N, 3)

# Stored as 3 × N so that column j is the measurement vector at grid point j.
Nobs = permutedims(Obs .* (1 .+ noise))

## Constants needed for the unscented transform
global const L = 7        # dimension of the augmented state vector
global const m = 3        # number of measurements
global const k = 0.0      # secondary scaling parameter
global const α = 0.1      # determines the spread of the sigma points
global const β = 2.0      # enters the first covariance weight
# Composite scaling parameter built from α, k and L.
global const λ = α^2 * (L + k) - L
# Measurement-noise covariance: σ² on the diagonal of a 3 × 3 matrix.
σ = 1.0
R = σ^2 * Matrix{Float64}(I, 3, 3)

# Augmented state: 4 states (Pb, Q1, Q2, Q3) followed by 3 parameters
# (R1, R2, R3), i.e. L entries per sigma point and 2L + 1 sigma points.

# Sigma points; the third axis is re-indexed to run over 0:N instead of 1:N+1.
χ = OffsetArray(zeros(L, 2L + 1, N + 1), 1:L, 1:2L+1, 0:N)
# Propagated sigma points.
χf = zeros(L, 2L + 1, N)

# Mean and covariance weights, one per sigma point.
Wm = zeros(2L + 1)
Wc = zeros(2L + 1)

# State-space quantities, one column / slice per grid point.
Xμ = zeros(L, N)        # predicted mean
Xa = zeros(L, N)        # assimilated state
PX = zeros(L, L, N)     # predicted covariance
PXA = zeros(L, L, N)    # assimilated covariance

# Measurement-space quantities.
Y = zeros(m, 2L + 1, N) # sigma points mapped to measurements
Yμ = zeros(m, N)        # predicted measurement mean
PY = zeros(m, m, N)     # measurement covariance
PχY = zeros(L, m, N)    # state/measurement cross-covariance
K = zeros(L, m, N)      # gain

# Initialise the first column / slice of the filter arrays.

# Initial covariance of the augmented state [Pb, Q1, Q2, Q3, R1, R2, R3].
# NOTE(review): the parameter entries 0.01, 0.3, 0.6 are also used just below
# as the *standard deviations* of the initial parameter draws, whereas a
# covariance diagonal normally holds variances — confirm which is intended.
PXA[:, :, 1] = diagm([0.1, 0.1, 0.1, 0.1, 0.01, 0.3, 0.6])

# Initial augmented state: the first saved state of `sol` followed by one random
# draw per resistance. `sol.u[1]` is used because linear indexing of a solution
# object (`sol[1]`) is deprecated, and `rand(d)` returns a scalar directly
# instead of allocating a one-element vector.
Xa[:, 1] = [
    sol.u[1]
    rand(Normal(1.0, 0.01))
    rand(Normal(1.0, 0.3))
    rand(Normal(2.0, 0.6))
]

PX[:, :, 1] = PXA[:, :, 1]
PY[:, :, 1] = diagm([0.1, 0.1, 0.1])

# Weights: the central sigma point has its own mean/covariance weight, all
# remaining 2L points share 0.5 / (L + λ).
Wm[1] = λ / (L + λ)
Wc[1] = λ / (L + λ) + (1 - α^2 + β)
Wm[2:end] .= 0.5 / (L + λ)
Wc[2:end] .= 0.5 / (L + λ)

# Sigma points at step 0: the initial state itself, then the state shifted by
# ±sqrt(L + λ) times each column of the lower Cholesky factor of the initial
# covariance (a 1e-10 diagonal jitter is added before factorising).
Lmatrix = cholesky(Hermitian(PXA[:, :, 1] + 1e-10 * Matrix{Float64}(I, L, L))).L
spread = sqrt(L + λ)

χ[:, 1, 0] = Xa[:, 1]
for j in 1:L
    χ[:, 1 + j, 0] = Xa[:, 1] .+ spread * Lmatrix[:, j]
    χ[:, 1 + L + j, 0] = Xa[:, 1] .- spread * Lmatrix[:, j]
end

# In-place right-hand side of the three-resistor splitter model.
#   du : output, derivatives of (Pb, Q1, Q2, Q3)
#   u  : current state (required by the ODE-function signature, never read)
#   p  : resistances (R1, R2, R3)
#   t  : time
# Returns `nothing`.
function splitter!(du, u, p, t)
    R1, R2, R3 = p
    # The forcing 2π·cos(2πt) and the denominator are common to every
    # component, so evaluate them once.
    common = 2π * cos(2π * t) / (R1 * R3 + R1 * R2 + R2 * R3)
    du[1] = common * R2 * R3     # Pb
    du[2] = common * (R2 + R3)   # Q1
    du[3] = common * R3          # Q2
    du[4] = common * R2          # Q3
    return nothing
end

# Set-up for the filtered run: zero initial state, and the resistances taken
# from rows 5–7 of the first column of Xa as the starting parameters.
u0 = fill(0.0, 4)
tspan = (0.0, 2.0)
p = Xa[5:7, 1]

# Callback trigger: true whenever the iteration counter is a multiple of 1,
# i.e. after every step.
condition(u, t, integrator) = iszero(rem(integrator.iter, 1))

"""
    affect!(integrator)

Run one unscented-Kalman-filter predict/update cycle and write the assimilated
state and parameters back into `integrator`.

The function works on the global filter arrays. With `i = integrator.iter`:

- column/slice `i` of `Xa` / `PXA` is read as the assimilated mean and
  covariance at the previous step (slice 1 holds the initial values);
- `χ[:, :, i - 1]` receives the sigma points generated from that mean and
  covariance (the third axis of `χ` starts at 0);
- column/slice `i + 1` of `χf`, `Xμ`, `PX`, `Y`, `Yμ`, `PY`, `PχY`, `K`, `Xa`
  and `PXA` receives the results for the current step, and `Nobs[:, i + 1]` is
  the measurement that is assimilated.

Each sigma point is propagated from `integrator.tprev` to `integrator.t` with
its own parameter rows (5:7), and those same parameters are carried into the
propagated point so the state/parameter correlation is preserved.

Returns `nothing`. If there is no observation column `i + 1`, nothing is done.
"""
function affect!(integrator)
    i = integrator.iter

    # Guard against a step past the last observation column (e.g. an extra
    # round-off step at the end of the time span).
    i + 1 <= size(Nobs, 2) || return nothing

    nsigma = 2L + 1
    spread = sqrt(L + λ)

    # --- Sigma points around the previous assimilated estimate -------------
    # Xa[:, i] is left untouched: it is the stored estimate for the previous
    # step and must not be replaced by the integrator's current state.
    xprev = Xa[:, i]
    Lmatrix = cholesky(Hermitian(PXA[:, :, i] + 1e-10 * Matrix{Float64}(I, L, L))).L

    χ[:, 1, i - 1] = xprev
    for j in 1:L
        χ[:, 1 + j, i - 1] = xprev .+ spread * Lmatrix[:, j]
        χ[:, 1 + L + j, i - 1] = xprev .- spread * Lmatrix[:, j]
    end

    # --- Propagate every sigma point over the step just taken --------------
    # tprev is the exact start of the step, and using the step length as dt
    # makes the inner fixed-step solve take that step in one go.
    tstart = integrator.tprev
    tstop = integrator.t
    h = tstop - tstart
    for j in 1:nsigma
        θ = χ[5:7, j, i - 1]
        sigmaprob = ODEProblem(splitter!, χ[1:4, j, i - 1], (tstart, tstop), θ)
        res = DifferentialEquations.solve(
            sigmaprob, Tsit5();
            adaptive = false, dt = h, reltol = 1e-12, abstol = 1e-12, saveat = [tstop],
        )
        # Pair the propagated state with the parameters it was propagated with.
        χf[:, j, i + 1] = [res.u[1]; θ]
    end

    # --- Predicted mean and covariance --------------------------------------
    Xf = χf[:, :, i + 1]                 # L × (2L + 1)
    xpred = Xf * Wm                      # weighted mean of the columns
    Xμ[:, i + 1] = xpred
    dX = Xf .- xpred
    PX[:, :, i + 1] = dX * Diagonal(Wc) * transpose(dX)

    # --- Measurement prediction: rows 1, 3, 4 (Pb, Q2, Q3) are observed -----
    Yf = Xf[[1, 3, 4], :]
    Y[:, :, i + 1] = Yf
    ypred = Yf * Wm
    Yμ[:, i + 1] = ypred
    dY = Yf .- ypred
    PY[:, :, i + 1] = dY * Diagonal(Wc) * transpose(dY) + R
    PχY[:, :, i + 1] = dX * Diagonal(Wc) * transpose(dY)

    # --- Assimilate the measurement -----------------------------------------
    # Right division solves the linear system instead of forming inv(PY).
    gain = PχY[:, :, i + 1] / PY[:, :, i + 1]
    K[:, :, i + 1] = gain

    Xa[:, i + 1] = xpred + gain * (Nobs[:, i + 1] - ypred)
    PXA[:, :, i + 1] = PX[:, :, i + 1] - gain * PY[:, :, i + 1] * transpose(gain)

    # Continue the outer integration from the assimilated state and parameters.
    integrator.u = Xa[1:4, i + 1]
    integrator.p = Xa[5:7, i + 1]

    return nothing
end
# Filtered run: same fixed-step solve over (0, 3), now with the callback
# attached so that `affect!` is applied whenever `condition` is true.
tspan = (0.0, 3.0)

# Both save flags are off for the callback.
save_positions = (false, false)
cb = DiscreteCallback(condition, affect!; save_positions = save_positions)

prob = ODEProblem(splitter!, u0, tspan, p)
sol = solve(
    prob, Tsit5();
    adaptive = false, dt = 0.0001, reltol = 1e-12, abstol = 1e-12, callback = cb,
)

Could you just store a boolean flag in the parameters and switch it from true to false to turn the filtering on and off?