Flux: training loop logging for custom structure and model

Hello everyone,
I am pretty new to using Julia for ML. I am training a custom model (a Kalman filter) implemented as a mutable struct, and I want to update and log its parameters at each time step (in my case, once per measurement). With my custom log function, I only get the final parameter values after all the training is done.
Here is the training loop:

"""
Run an SGD-type optimisation on the log-likelihood of the noise covariance,
starting from the initial state estimate s0.
"""
function run_linear_estimation(filter::KalmanFilter, opt, s0::State, action_history::AbstractArray,
    measurement_history::AbstractArray)
    history = init_history(filter) # [loss, A, B, Q, H, R]
    @assert length(action_history) == length(measurement_history)
    states = [s0]
    for (u, y) in ProgressBar(zip(action_history, measurement_history))
        sp = prediction(filter, states[end], u)
        sn = correction(filter, sp, y)
        l = loss(filter, states[end], u, y)
        grads = gradient(f -> loss(f, states[end], u, y), filter)[1][]
        update!(opt, filter.R, grads[:R])
        history = log_kf_history(history, filter, l)
        push!(states, sn)
    end
    return history, states
end

Here is the filter structure:

mutable struct KalmanFilter{a<:AbstractMatrix, b<:AbstractMatrix, q<:Symmetric, h<:AbstractMatrix, r<:Symmetric} <: AbstractFilter
    A::a # process matrix
    B::b # control matrix
    Q::q # process zero mean noise covariance
    H::h # measurement matrix
    R::r # measurement zero mean noise covariance
end

Here is the history-update function, where hist is a dictionary initialised with empty arrays:

function log_kf_history(hist::Dict, filter::KalmanFilter, l::Float64)
    push!(hist["loss"], l)
    push!(hist["A"], filter.A)
    push!(hist["B"], filter.B)
    push!(hist["Q"], filter.Q)
    push!(hist["H"], filter.H)
    push!(hist["R"], filter.R)
    return hist
end
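For completeness, init_history just builds that dictionary of empty logs; the exact implementation is not important here, but it looks something like this:

```julia
# Sketch of init_history: one empty log per tracked quantity.
function init_history(filter)
    return Dict{String,Vector{Any}}(k => Any[] for k in ("loss", "A", "B", "Q", "H", "R"))
end
```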

Here is the final result after training is done. As you can see, the values of “R” are all identical and correspond to the last value from training. I cannot seem to push the updated values of my struct into my history dict at each time step.

Dict{String,Array{Any,1}} with 6 entries:
  "B"    => Any[[0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.0; 0.0 0.0], [0.0 0.…
  "A"    => Any[[1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0 1.0], [1.0 0.001; 0.0…
  "Q"    => Any[[0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 0.01], [0.01 0.0; 0.0 …
  "R"    => Any[[12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413], [12.8413]  …  [12.8413], [12.8413], [12.8413], […
  "loss" => Any[-27.6231, -46.0838, -1.62708, -127.982, -17.8552, -137.215, -5.49402, -12.6443, -6.14942, -1.08021  …  -8.9027, -11.1033, -205.549, -61.0984, -24.8…
  "H"    => Any[[1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0], [1.0 0.0]  …  [1.0 0.0], [1.0 0.0], [1.0 0.0], […

That is odd. Can you confirm that logging/printing the values does show them changing over time? It would be good to rule out that part as a culprit.

Yep, if I put println(filter.R) just after the update! call, I get:

[7.01894543038738]
[7.0283066308607465]
[7.036718801841262]
[7.045336552948723]
[7.0528250166202815]
[7.059521789006454]
[7.065886640274229]
[7.07334579385866]
[7.0812537988419315]
[7.089320721090611]
[7.096952025977223]
[7.105350526347954]
[7.113762331842448]
[7.122675380687822]
[7.131498085451385]
[7.1395030806216795]
[7.147459659652285]
[7.156184940987942]
[7.165269126788367]
[7.173688989022769]
[7.181474817524692]
[7.188590245443808]
[7.196412140541837]
[7.203959483467504]
[7.210946289678026]
[7.217375869174501]

So the values are changing… that’s not the problem here.

Any chance that R is an Array or a Ref?

If so, you are just pushing multiple references to the same object into the log history. Try pushing a copy instead.

R is an array (a 1×1 Matrix, to be exact), but it's being updated at each step. Pushing a copy changes nothing.

But in the code in the OP, hist contains the exact same instance of R as the one being updated.

Example:

julia> R = [0.1]
1-element Vector{Float64}:
 0.1

julia> R_hist = []
Any[]

julia> push!(R_hist, R)
1-element Vector{Any}:
 [0.1]

julia> push!(R_hist, R)
2-element Vector{Any}:
 [0.1]
 [0.1]

julia> push!(R_hist, R)
3-element Vector{Any}:
 [0.1]
 [0.1]
 [0.1]

julia> R .= 666
1-element Vector{Float64}:
 666.0

julia> R_hist
3-element Vector{Any}:
 [666.0]
 [666.0]
 [666.0]

julia> push!(R_hist, copy(R))
4-element Vector{Any}:
 [666.0]
 [666.0]
 [666.0]
 [666.0]

julia> R .= 0.1
1-element Vector{Float64}:
 0.1

julia> R_hist
4-element Vector{Any}:
 [0.1]
 [0.1]
 [0.1]
 [666.0]

julia> map(rh -> rh === R, R_hist)
4-element Vector{Bool}:
 1
 1
 1
 0

What output do you get if you run those last checks with the R in the filter and the Rs in the history after training completes?


Another check would be to use deepcopy instead of copy. If that also fails (which I’m betting it will), there’s probably something going on with the logic instead. Or maybe there was just a typo and copy does work.
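For reference, copy is shallow: it duplicates the outer container but not anything nested inside it, while deepcopy recurses all the way down. For a flat 1×1 matrix the two behave the same, but for nested containers they differ:

```julia
a = [[1.0], [2.0]]   # a vector of vectors
shallow = copy(a)    # new outer vector, same inner vectors
deep = deepcopy(a)   # everything duplicated recursively
a[1][1] = 99.0
shallow[1][1]        # 99.0: the inner vector is shared with a
deep[1][1]           # 1.0: fully independent
```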


map(rh -> rh === filter.R, history["R"]) indeed returns an array of ones: Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...]

Putting copy() in the log_kf_history function seems to do the trick:

function log_kf_history(hist::Dict, filter::KalmanFilter, l::Float64)
    push!(hist["loss"], l)
    push!(hist["A"], copy(filter.A))
    push!(hist["B"], copy(filter.B))
    push!(hist["Q"], copy(filter.Q))
    push!(hist["H"], copy(filter.H))
    push!(hist["R"], copy(filter.R))
    return hist
end

Thanks for your help
