Hello everyone
I am pretty new to using Julia for ML. I am training a custom model (a Kalman filter) implemented as a mutable struct, and I want to update the parameters and log their values at every time step (in my case, once per measurement). However, my custom log function only ever records the final parameter values, i.e. the values after all the training is done.
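Here is the training loop:

```julia
using ProgressBars: ProgressBar    # progress bar over the measurements
using Zygote: gradient             # automatic differentiation
using Flux.Optimise: update!       # optimiser step (mutates its target in place)

"""
    run_linear_estimation(filter, opt, s0, action_history, measurement_history)

Run an SGD-type optimisation on the log-likelihood of the noise covariance,
starting from the initial state estimate `s0`.
"""
function run_linear_estimation(filter::KalmanFilter, opt, s0::State,
                               action_history::AbstractArray,
                               measurement_history::AbstractArray)
    history = init_history(filter)  # keys: loss, A, B, Q, H, R
    @assert length(action_history) == length(measurement_history)
    states = [s0]
    for (u, y) in ProgressBar(zip(action_history, measurement_history))
        sp = prediction(filter, states[end], u)  # predict step
        sn = correction(filter, sp, y)           # correct with measurement y
        l = loss(filter, states[end], u, y)
        # for a mutable struct the gradient is wrapped in a Ref, hence the trailing []
        grads = gradient(f -> loss(f, states[end], u, y), filter)[1][]
        update!(opt, filter.R, grads[:R])        # updates filter.R in place
        history = log_kf_history(history, filter, l)
        push!(states, sn)
    end
    return history, states
end
```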
Here is the filter structure:

```julia
using LinearAlgebra: Symmetric

mutable struct KalmanFilter{a<:AbstractMatrix, b<:AbstractMatrix, q<:Symmetric,
                            h<:AbstractMatrix, r<:Symmetric} <: AbstractFilter
    A::a  # process matrix
    B::b  # control matrix
    Q::q  # process zero-mean noise covariance
    H::h  # measurement matrix
    R::r  # measurement zero-mean noise covariance
end
```
Here is the history-logging function, where `hist` is a dictionary whose entries are initialised as empty arrays:

```julia
function log_kf_history(hist::Dict, filter::KalmanFilter, l::Float64)
    push!(hist["loss"], l)
    push!(hist["A"], filter.A)
    push!(hist["B"], filter.B)
    push!(hist["Q"], filter.Q)
    push!(hist["H"], filter.H)
    push!(hist["R"], filter.R)
    return hist
end
```
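If `update!` modifies `filter.R` in place (as Flux's optimiser step is generally understood to do), one variant worth trying is a logger that pushes a `copy` of each matrix, so that entries already logged are not affected by later updates. A minimal sketch under those assumptions: it re-declares the struct from above with a stub `AbstractFilter` so it runs on its own, and `init_history_sketch` and `log_kf_history_copy!` are hypothetical names, not part of the original code. The halving of `R` is only a stand-in for an optimiser step.

```julia
using LinearAlgebra

abstract type AbstractFilter end  # stub standing in for the post's own definition

mutable struct KalmanFilter{a<:AbstractMatrix, b<:AbstractMatrix, q<:Symmetric,
                            h<:AbstractMatrix, r<:Symmetric} <: AbstractFilter
    A::a
    B::b
    Q::q
    H::h
    R::r
end

# Hypothetical helper: one empty vector per logged quantity.
init_history_sketch() =
    Dict{String,Vector{Any}}(k => Any[] for k in ("loss", "A", "B", "Q", "H", "R"))

# Like log_kf_history, but stores copies, so later in-place updates of the
# filter's matrices do not change entries that were already logged.
function log_kf_history_copy!(hist::Dict, filter::KalmanFilter, l::Float64)
    push!(hist["loss"], l)
    push!(hist["A"], copy(filter.A))
    push!(hist["B"], copy(filter.B))
    push!(hist["Q"], copy(filter.Q))   # copy of a Symmetric is a new Symmetric
    push!(hist["H"], copy(filter.H))
    push!(hist["R"], copy(filter.R))
    return hist
end

kf = KalmanFilter([1.0 0.001; 0.0 1.0], zeros(2, 2),
                  Symmetric([0.01 0.0; 0.0 0.01]), [1.0 0.0],
                  Symmetric(fill(1.0, 1, 1)))  # hypothetical initial R

hist = init_history_sketch()
Rdata = parent(kf.R)  # the Matrix wrapped by kf.R (same object, not a copy)
for step in 1:3
    Rdata .*= 0.5     # in-place stand-in for an optimiser step on R
    log_kf_history_copy!(hist, kf, -Float64(step))
end
```

With the copies, `hist["R"]` holds 0.5, 0.25 and 0.125 rather than three references to the same final matrix.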
Here is the result after training. As you can see, every logged value of `R` is identical and equals the value from the final training step, while the loss does change from step to step. I cannot seem to push the per-step values of my struct's fields into the history dict.
```
Dict{String,Array{Any,1}} with 6 entries:
  "B"    => Any[[0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.…
  "A"    => Any[[1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0…
  "Q"    => Any[[0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 …
  "R"    => Any[[12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413] … [12.8413], [12.8413], [12.8413], […
  "loss" => Any[-27.6231, -46.0838, -1.62708, -127.982, -17.8552, -137.215, -5.49402, -12.6443, -6.14942, -1.08021 … -8.9027, -11.1033, -205.549, -61.0984, -24.8…
  "H"    => Any[[1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0] … [1.0 0.0], [1.0 0.0], [1.0 0.0], […
```
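This pattern (all `R` entries equal the last value, while the scalar loss varies) is consistent with Julia's reference semantics for mutable arrays: `push!(v, x)` stores the array object `x` itself rather than a snapshot of its contents. A minimal sketch with plain arrays, assuming the optimiser mutates `filter.R` in place; here `R` is only a stand-in for `filter.R`, and halving it stands in for an optimiser step:

```julia
R = fill(12.0, 1, 1)   # mutable 1×1 matrix standing in for filter.R

by_reference = Any[]   # logs the array object itself
by_copy = Any[]        # logs an independent snapshot each step
for step in 1:3
    R .*= 0.5                 # in-place update: R remains the same object
    push!(by_reference, R)    # every entry is the same array as R
    push!(by_copy, copy(R))   # a new array holding the current values
end
```

After the loop every element of `by_reference` is `R` itself (so all show 1.5), whereas `by_copy` holds 6.0, 3.0 and 1.5. A `Float64` loss is an immutable value, which is why the `"loss"` entries are unaffected.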