Optimizing matrix multiplication code while staying compatible with autodiff

I am trying to optimize some matrix multiplications inside a Turing model (a Kalman filter), but since I depend on autodiff I can’t pre-allocate matrices and mutate them in place. How does one go about optimizing this type of code?

Below is a working snippet of the (unoptimized) model code and the code used to generate flame graphs. I have also commented the lines where the profiler collects the most samples; they reveal a bottleneck in kalman_update.


```julia
using Turing, Random, LinearAlgebra
D = 35

x0 = randn(D)           # init state
P0 = diagm(ones(D))     # init state covariance

F = diagm(ones(D))      # transition
Q = 0.1*diagm(ones(D))  # transition covariance

H = ones(1, D)          # observation
R = 0.1*diagm(ones(1))  # observation covariance

xs = Vector{Vector{Float64}}(undef, 10)
xs[1] = x0
for i in 2:length(xs)
    xs[i] = F*xs[i-1]
end

# 1×N matrix of noisy observations (noise std = sqrt of the variance R)
ys = reshape([only(H*x) + sqrt(only(R))*randn() for x in xs], 1, :)

function kalman_predict(x, P, H, F, Q, R)
    x = F*x
    P = F*P*F' + Q # 9% samples here
    y = H*x
    S = H*P*H' + R
    x, P, y, S
end

function kalman_update(x, P, r, S, H, R)
    K = P*H'/S
    x = x + K*r
    I_KH = I - K*H
    P = I_KH*P*I_KH' + K*R*K' # 88% samples here
    y = H*x
    S = H*P*H' + R
    x, P, y, S
end

@model function kalman_model(ys, H, F, Q)
    obs_dim, N = size(ys)
    latent_dim = size(H, 2)
    x₀ ~ MvNormal(zeros(latent_dim), I)
    ϵ ~ filldist(Gamma(1, 1), latent_dim)
    σ ~ filldist(Gamma(1, 1), obs_dim)

    P = diagm(ϵ.^2)
    R = diagm(σ.^2)
    x = x₀
    for t in 1:N
        x, P, y, S = kalman_predict(x, P, H, F, Q, R) # 9% samples in kalman_predict
        r = ys[:,t] - y
        # The innovation log-likelihood (up to a constant) uses the *predicted* S,
        # so it is added before kalman_update overwrites S.
        Turing.@addlogprob!(-0.5 * (logdet(S) + r'*inv(S)*r))
        x, P, y, S = kalman_update(x, P, r, S, H, R) # 90% samples in kalman_update
    end
end

Random.seed!(12345)
model = kalman_model(ys, H, F, Q)

function benchmark_model(model, n)
    for i in 1:n
        sample(model, NUTS(), 1)
    end
end
# @profview is provided by ProfileView.jl or the Julia VS Code extension
@profview benchmark_model(model, 1)   # first run includes compilation
@profview benchmark_model(model, 10)
```
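One mutation-free change that targets the 88% line, sketched under the assumption of a single scalar observation (H is 1×D, so S and R are 1×1): with the optimal gain K = P H' / S, the Joseph form (I - KH) P (I - KH)' + K R K' is algebraically equal to P - K S K'. Since K is then a single column, that is a rank-1 update costing O(D²) instead of the O(D³) D×D products. The names `update_joseph` and `update_rank1` are illustrative, not from any package. The Joseph form is usually chosen for numerical robustness, so this trades some stability for speed.

```julia
using LinearAlgebra

# Joseph-form update from the question: D×D matrix products, O(D^3).
function update_joseph(x, P, r, S, H, R)
    K = P * H' / S
    I_KH = I - K * H
    return x + K * r, I_KH * P * I_KH' + K * R * K'
end

# Equivalent update for a scalar observation (H is 1×D, S and R are 1×1).
# One matrix-vector product and one outer product: O(D^2), no in-place mutation.
# R is unused here (it is already inside S) and kept only for a matching signature.
function update_rank1(x, P, r, S, H, R)
    s = only(S)              # innovation variance as a scalar
    k = (P * vec(H)) / s     # Kalman gain as a D-vector
    x_new = x + k * only(r)
    P_new = P - s * (k * k') # equals the Joseph form when k is the optimal gain
    return x_new, P_new
end
```

Note that F*P*F' in kalman_predict stays O(D³) for a general F; this only removes the cost of the update step.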

Just bumping this to say I’m very interested in the answer!

You can mutate with ReverseDiff, no? I’ve always found ReverseDiff faster than Zygote when used with Turing, at least with NUTS. But I’m sure that’s very problem-specific.

From what I read here, it seems ReverseDiff only allows mutation of arrays with respect to which you don’t need to differentiate.

If you use AD long enough, you eventually have to learn to write your own vector–Jacobian products (a.k.a. rrule in ChainRulesCore.jl) to handle pieces of code that AD either cannot handle (because it does mutation, calls external libraries, or …) or performs poorly on (typically because you have some approximate iterative solver, but AD doesn’t know the underlying problem you are solving).

The good news is that you typically only need to do this for isolated components of your code, and then the AD system can put everything else together for you.
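To make the rrule idea concrete, here is a sketch of a hand-written pullback (vector–Jacobian product) for the prediction step F*P*F' + Q from the question. For an output cotangent Ȳ, matrix calculus gives F̄ = Ȳ F P' + Ȳ' F P, P̄ = F' Ȳ F and Q̄ = Ȳ. In ChainRulesCore.jl this closure would be returned from an `rrule` method (with a leading `NoTangent()` for the function itself); here it is kept as a plain function, hypothetically named `predict_cov_with_pullback`, so the algebra can be checked against finite differences without extra packages.

```julia
# Forward computation: covariance prediction from the question.
predict_cov(F, P, Q) = F * P * F' + Q

# Returns the output Y and a pullback that maps the output cotangent Ybar
# to the cotangents (Fbar, Pbar, Qbar) of the inputs.
function predict_cov_with_pullback(F, P, Q)
    FP = F * P
    Y = FP * F' + Q
    function pullback(Ybar)
        Fbar = Ybar * F * P' + Ybar' * FP  # from dY = dF*P*F' + F*P*dF'
        Pbar = F' * Ybar * F               # from dY = F*dP*F'
        Qbar = Ybar                        # Q enters additively
        return Fbar, Pbar, Qbar
    end
    return Y, pullback
end
```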


I see. My experience with Turing models has been that I can allocate an Array{T} and update its elements each iteration with ReverseDiff, but the same model fails with Zygote.
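A minimal sketch of that Array{T} pattern, assuming the buffer only holds values derived from the inputs. The element type is taken from the inputs rather than hard-coded as Float64, so numbers coming from an AD backend (ForwardDiff duals, ReverseDiff tracked reals) can be stored in it. The function `propagate_states` is hypothetical; it mutates its own buffer, which is exactly what works with ReverseDiff but not with Zygote in the experience described above.

```julia
# Hypothetical example: run the mean recursion x_t = F * x_{t-1} for n steps,
# writing every state into a preallocated buffer that is mutated in place.
# T comes from the inputs, so the buffer is not fixed to Float64.
function propagate_states(F, x0, n)
    T = promote_type(eltype(F), eltype(x0))
    X = Matrix{T}(undef, length(x0), n)  # Array{T}, not Array{Float64}
    X[:, 1] = x0
    for t in 2:n
        X[:, t] = F * X[:, t-1]          # in-place write into the buffer
    end
    return X
end
```

Called with Float64 inputs it returns a Matrix{Float64}; called with BigFloat inputs it returns a Matrix{BigFloat}, which is the same mechanism that lets AD number types flow through.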

@stevengj I suppose that is inevitable, yeah.

This thread got linked in the Julia Slack, and since that discussion will disappear into the Slack hole soon enough, I am copying the main points here.

There is a more efficient, AD-compatible implementation of Kalman filtering (compared to the one in the OP) in kalman_likelihood.jl in the SciML/DifferenceEquations.jl repository on GitHub. It treats Kalman filtering as solving a linear state-space problem and produces all states in one go. In other words, it is less granular than the implementation in the OP, as it doesn’t let the user run the predict and filter steps independently (please correct me if I got this wrong).

Additionally, there is mschauer/Kalman.jl on GitHub (“Flexible filtering and smoothing in Julia”). Its implementation is very similar to the one in the OP, is already AD compatible, and is about as efficient.

A paper on differentiable linear algebra operators, written for the MXNet framework, was also linked: “Auto-Differentiating Linear Algebra” (arXiv:1710.08717). It contains derivative rules that are relevant for Kalman filters.
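For reference, two identities that such rules build on are d log det S = tr(S⁻¹ dS) and d(S⁻¹) = -S⁻¹ dS S⁻¹. Applied to the innovation term that the OP’s model adds with @addlogprob!, and assuming S is symmetric positive definite, they give the gradients below (standard matrix calculus, not quoted from the paper).

```latex
\ell(r, S) = -\tfrac{1}{2}\left(\log\det S + r^\top S^{-1} r\right)

d\ell = -\tfrac{1}{2}\operatorname{tr}\!\left[\left(S^{-1} - S^{-1} r r^\top S^{-1}\right) dS\right] - r^\top S^{-1}\, dr

\frac{\partial \ell}{\partial S} = -\tfrac{1}{2}\left(S^{-1} - S^{-1} r r^\top S^{-1}\right),
\qquad
\frac{\partial \ell}{\partial r} = -S^{-1} r
```

With a scalar observation (S = s), these reduce to ∂ℓ/∂s = -(1/s - r²/s²)/2 and ∂ℓ/∂r = -r/s.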

Generally, people in the thread agree that it would be nice to have an efficient, AD-compatible Kalman implementation for use cases like these, but it is not clear where it should live. One way forward is to implement a more efficient (mutating) Kalman filter together with the adjoints needed to make it AD compatible, and contribute both to Kalman.jl.
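A rough sketch of what the mutating part could look like, assuming a scalar observation with observation row vector h and noise variance R. The workspace type `ScalarKalmanWorkspace` and the function `kalman_update!` are hypothetical, not Kalman.jl’s API. Buffers are preallocated and x and P are updated in place with `mul!` and fused broadcasting, so after setup it should allocate little or nothing. Because it mutates its arguments, Zygote could not differentiate it as-is; it would still need the hand-written adjoint discussed above, which is not included here.

```julia
using LinearAlgebra

# Hypothetical preallocated buffers for a scalar-observation update.
struct ScalarKalmanWorkspace{T}
    Ph::Vector{T}  # holds P*h
    k::Vector{T}   # Kalman gain
end
ScalarKalmanWorkspace{T}(D::Integer) where {T} =
    ScalarKalmanWorkspace{T}(Vector{T}(undef, D), Vector{T}(undef, D))

# Updates x and P in place for observation y = h'x + noise with variance R.
# Returns the innovation r and its variance s (from the prior x and P).
function kalman_update!(ws, x, P, h, y, R)
    mul!(ws.Ph, P, h)          # P*h without allocating
    s = dot(h, ws.Ph) + R      # innovation variance h'P h + R
    r = y - dot(h, x)          # innovation, computed before x changes
    ws.k .= ws.Ph ./ s         # gain k = P h / s
    x .+= ws.k .* r            # x <- x + k r
    P .-= s .* ws.k .* ws.k'   # P <- P - s k k' (rank-1, fused broadcast)
    return r, s
end
```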
