Kalman Filter with Enzyme

I think I’ve figured it out, although I’m not entirely positive I understand the solution.

After using Cthulhu.jl to dig further into type inference, I noticed a type instability in the following block:

```julia
particle, residual, S = GaussianDistributions.correct(
    particle, Gaussian([obs], R), B
)
```

My guess is that destructuring the returned tuple makes a type-unstable call to Base.indexed_iterate, which is what raises the error mentioned above. That said, I don't fully understand why the stacktrace points to LinearAlgebra.triu!(...) when differentiating with Enzyme.jl.
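For context on where Base.indexed_iterate comes from: tuple destructuring such as `a, b = f(x)` is lowered to calls of that function. The toy sketch below confirms the lowering and shows how a callee with a non-concrete return type leaves the destructured values non-concrete, which is the kind of instability Cthulhu.jl flags. The functions `maybe_pair` and `unpack` are hypothetical and unrelated to GaussianDistributions.jl.

```julia
# Tuple destructuring is lowered to Base.indexed_iterate calls.
lowered = Meta.lower(Main, :((a, b) = t))
uses_indexed_iterate = occursin("indexed_iterate", string(lowered))

# Hypothetical callee whose return type depends on a runtime value,
# so inference can only produce a Union of two tuple types.
maybe_pair(flag::Bool) = flag ? (1.0, 2.0) : (1.0, [2.0])

function unpack(flag::Bool)
    x, y = maybe_pair(flag)   # goes through Base.indexed_iterate
    return y
end

# The inferred return type of `unpack` is not a concrete type.
rt = only(Base.return_types(unpack, (Bool,)))
is_concrete = isconcretetype(rt)
```

`Base.return_types` is a quick non-interactive check; Cthulhu.jl's `@descend` shows the same information interactively.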

Additionally, the program raises another error originating in the GaussianDistributions.jl log-likelihood computation. This can be worked around by using Distributions.jl instead and converting the Gaussian to an MvNormal.
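For reference, the per-step term that `Distributions.logpdf(MvNormal(zero(residual), S), residual)` contributes is the zero-mean Gaussian log-density of the residual. Here is a stdlib-only sketch of that quantity, computed through a Cholesky factorization. The helper `gaussian_logpdf` and the numbers are hypothetical, for illustration only.

```julia
using LinearAlgebra

# Zero-mean multivariate normal log-density of residual r with covariance S:
# -0.5 * (k*log(2π) + logdet(S) + r' * S⁻¹ * r)
function gaussian_logpdf(r::AbstractVector, S::AbstractMatrix)
    C = cholesky(Symmetric(S))   # S = U'U
    z = C.U' \ r                 # whitened residual, so z'z == r' * S⁻¹ * r
    k = length(r)
    return -0.5 * (k * log(2π) + logdet(C) + dot(z, z))
end

# Hypothetical innovation covariance and residual.
S = [2.0 0.5; 0.5 1.0]
r = [0.3, -0.2]

lp = gaussian_logpdf(r, S)

# Direct (less numerically stable) evaluation of the same formula.
lp_direct = -0.5 * (length(r) * log(2π) + log(det(S)) + r' * inv(S) * r)
```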

My updated filter is as follows:

```julia
using LinearAlgebra          # Symmetric, I
using Distributions          # MvNormal, logpdf
using GaussianDistributions  # Gaussian

function kalman_filter(
        model::LinearGaussianModel,
        init_state::Gaussian,
        observations::Vector{T}
    ) where {T<:Real}
    log_evidence = zero(T)
    state = init_state

    A = model.transition.ϕ
    Q = model.transition.Σ

    B = model.observation.ϕ
    R = model.observation.Σ

    for obs in observations
        # predict
        proposed = Gaussian(A*state.μ, A*state.Σ*A' + Q)

        # inline the internals of GaussianDistributions.correct(...)
        residual = [obs] - B*proposed.μ
        S = Symmetric(B*proposed.Σ*B' + R)
        gain = proposed.Σ*B' / S

        # apply correction (Joseph-form covariance update)
        state = Gaussian(
            proposed.μ + gain*residual,
            (I - gain*B)*proposed.Σ*(I - gain*B)' + gain*R*gain'
        )

        # compute the log-likelihood with Distributions.jl instead of GaussianDistributions.jl
        log_evidence += Distributions.logpdf(
            MvNormal(zero(residual), S),
            residual
        )
    end

    return log_evidence
end
```
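To sanity-check the correction step in isolation, here is a stdlib-only sketch of one predict/correct iteration on a hypothetical two-state, one-observation system (all matrices are made up). With the optimal gain, the Joseph-form update used in the filter reduces algebraically to (I - K B) P. The Joseph form is generally preferred because it better preserves symmetry and positive semidefiniteness under rounding error.

```julia
using LinearAlgebra

# Hypothetical system; values chosen only for illustration.
A = [1.0 0.1; 0.0 1.0]       # transition matrix
Q = [0.01 0.0; 0.0 0.01]     # process noise covariance
B = [1.0 0.0]                # observation matrix (1×2)
R = fill(0.25, 1, 1)         # observation noise covariance (1×1)

μ = [0.0, 1.0]               # current state mean
P = Matrix(1.0I, 2, 2)       # current state covariance
obs = 0.7

# Predict
μp = A * μ
Pp = A * P * A' + Q

# Correct
residual = [obs] - B * μp
S = Symmetric(B * Pp * B' + R)
K = Pp * B' / S              # Kalman gain via right division

# Joseph-form covariance update vs. the short form (I - K B) Pp
P_joseph = (I - K * B) * Pp * (I - K * B)' + K * R * K'
P_short  = (I - K * B) * Pp

μ_new = μp + K * residual
```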

Previously, I computed the Kalman gain using the Cholesky-decomposed innovation covariance. While this is numerically more stable, Enzyme.jl struggled to differentiate the right division by a Cholesky factorization. @wsmoses, if this would be useful, I'd be happy to put together an MWE and open an issue.
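For comparison, the sketch below computes the same gain both ways on hypothetical matrices: by right division with the `Symmetric` wrapper (as in the filter above) and through a Cholesky factorization. The Cholesky version is written as a left solve, using the fact that S and the predicted covariance are symmetric so K' = S⁻¹ (B Pp). Whether Enzyme.jl handles this left-solve formulation any better than right division by a `Cholesky` is untested here.

```julia
using LinearAlgebra

# Hypothetical predicted covariance, observation matrix and noise (illustration only).
Pp = [1.0 0.2 0.0; 0.2 0.5 0.1; 0.0 0.1 0.3]   # 3×3, symmetric positive definite
B  = [1.0 0.0 0.0; 0.0 1.0 1.0]                # 2×3 observation matrix
R  = [0.1 0.0; 0.0 0.2]

S = Symmetric(B * Pp * B' + R)   # innovation covariance

# Gain via right division by the Symmetric wrapper.
K_div = Pp * B' / S

# Gain via Cholesky, as a left solve: K = (S \ (B * Pp))'.
C = cholesky(S)
K_chol = Matrix((C \ (B * Pp))')
```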
