Kalman Filter with Enzyme

Consider the simple problem of maximum likelihood estimation for a linear Gaussian state space model using a Kalman filter. With a well-designed filter, computing the gradient of the log-likelihood should be a trivial task for automatic differentiation; and for most backends, it is.
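
For context, the model is the standard linear Gaussian state space model and the objective is its log-likelihood; the symbols A, Q, B, R below match the names used in the filter code further down:

x_t = A x_{t-1} + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, Q)
y_t = B x_t + \eta_t, \qquad \eta_t \sim \mathcal{N}(0, R)

\log\ell(\theta) = \sum_{t=1}^{T} \log p(y_t \mid y_{1:t-1}, \theta)

where each predictive density is Gaussian with mean B\hat{x}_{t|t-1} and covariance S_t = B P_{t|t-1} B^\top + R, which is exactly the (residual, S) pair the filter loop accumulates.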

Although my code seems type stable, Enzyme.jl refuses to compute gradients with strict aliasing turned on. The problem is compounded by BLAS warnings about further performance degradation. Meanwhile, ForwardDiff.jl and Zygote.jl produce gradients efficiently and without a hitch.
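
For reference, the other backends are called roughly like this (a sketch; it assumes the logℓ and data defined in the MWE further down):

using ForwardDiff, Zygote

# forward mode via dual numbers: works without complaint
ForwardDiff.gradient(θ -> logℓ(θ, data), [1.0])

# reverse mode: also works, returns a one-element tuple of gradients
Zygote.gradient(θ -> logℓ(θ, data), [1.0])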

Without setting Enzyme.API.strictAliasing!(false), I get the following error:

ERROR: Enzyme compilation failed due to illegal type analysis.
 This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].
 Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Caused by:
Stacktrace:
 [1] triu!
   @ ~/.julia/juliaup/julia-1.11.0+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:441
 [2] getproperty
   @ ~/.julia/juliaup/julia-1.11.0+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/qr.jl:497

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/errors.jl:319
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/errors.jl:210
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/DL6fx/src/api.jl:268
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:1690
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:4534
  [6] codegen
    @ ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:3337 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5386
  [8] _thunk
    @ ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5386 [inlined]
  [9] cached_compilation
    @ ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5438 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5549
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5734
 [12] autodiff
    @ ~/.julia/packages/Enzyme/DL6fx/src/Enzyme.jl:485 [inlined]
 [13] autodiff
    @ ~/.julia/packages/Enzyme/DL6fx/src/Enzyme.jl:524 [inlined]
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/DL6fx/src/sugar.jl:275 [inlined]
 [15] gradient(rm::ReverseMode{…}, f::typeof(logℓ), x::Vector{…}, args::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/DL6fx/src/sugar.jl:263
 [16] top-level scope
    @ ~/code/state_space_models/SSMProblems-AD/examples/data-generating-process/mwe.jl:97
Some type information was truncated. Use `show(err)` to see complete types.

When strict aliasing is turned off, I get correct gradients, but at a snail's pace. This may be exacerbated by the following BLAS warning, which appears in both cases:

┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/GnbhK/src/utils.jl:59

That said, this warning seems harmless, since I'm not running anything in parallel.
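
Concretely, the slow-but-working path is just toggling the flag before the first autodiff call (a sketch of what I ran):

# must be set before Enzyme compiles anything in this session
Enzyme.API.strictAliasing!(false)

# the gradient now goes through, just very slowly
gradient(Reverse, logℓ, [1.0], Const(data))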

For replication, my (not-so-)minimal working example is as follows:

using LinearAlgebra
using Enzyme
using GaussianDistributions
using Random

## MODEL DEFINITION ###########################################################

struct LinearGaussianProcess{
        T<:Real,
        ΦT<:AbstractMatrix{T},
        ΣT<:AbstractMatrix{T}
    }
    ϕ::ΦT
    Σ::ΣT
    function LinearGaussianProcess(ϕ::ΦT, Σ::ΣT) where {
            T<:Real,
            ΦT<:AbstractMatrix{T},
            ΣT<:AbstractMatrix{T}
        }
        @assert size(ϕ,1) == size(Σ,1) == size(Σ,2)
        return new{T, ΦT, ΣT}(ϕ, Σ)
    end
end

struct LinearGaussianModel{
        ΘT<:Real,
        TT<:LinearGaussianProcess{ΘT},
        OT<:LinearGaussianProcess{ΘT}
    }
    transition::TT
    observation::OT
    dims::Tuple{Int, Int}
end

## KALMAN FILTER ##############################################################

function kalman_filter(
        model::LinearGaussianModel,
        init_state::Gaussian,
        observations::Vector{T}
    ) where {T<:Real}
    log_evidence = zero(T)
    particle = init_state

    A = model.transition.ϕ
    Q = model.transition.Σ

    B = model.observation.ϕ
    R = model.observation.Σ

    for obs in observations
        particle = let μ = particle.μ, Σ = particle.Σ
            Gaussian(A*μ, A*Σ*A' + Q)
        end

        particle, residual, S = GaussianDistributions.correct(
            particle,
            Gaussian([obs], R), B
        )

        log_evidence += GaussianDistributions.logpdf(
            Gaussian(zero(residual), Symmetric(S)),
            residual
        )
    end

    return log_evidence
end

## DEMONSTRATION ##############################################################

# θ should be a single element vector for this demonstration
function build_model(θ::Vector{T}) where {T}
    trans = LinearGaussianProcess(T[1;;], Diagonal(θ))
    obs   = LinearGaussianProcess(T[1;;], Diagonal(T[1]))
    return LinearGaussianModel(trans, obs, (1,1))
end

# log likelihood function
function logℓ(θ::Vector{T}, data) where {T<:Real}
    model = build_model(θ)
    init_state = Gaussian(zeros(T, 1), Diagonal(ones(T, 1)))
    return kalman_filter(model, init_state, data)
end

# data generation (with unit covariance)
rng  = MersenneTwister(1234)
data = cumsum(randn(rng, 100)) .+ randn(rng, 100)

# ensure that log likelihood looks stable
logℓ([1.0], data)

# calculate gradients
gradient(Reverse, logℓ, [1.0], Const(data))

I use GaussianDistributions.jl for the sake of brevity, but the same issues are present when I do the algebra myself.

Any help is appreciated. I am baffled.

So I dived into this, and your code is indeed type unstable.

In particular, consider:


    init_state = Gaussian(zeros(T, 1), Diagonal(ones(T, 1))) # Gaussian{Vector{Float64}, Diagonal{...}}

    particle = init_state  # Gaussian{Vector{Float64}, Diagonal{...}}
    for obs in observations

        particle = let μ = particle.μ, Σ = particle.Σ
            Gaussian(A*μ, A*Σ*A' + Q)
        end
        # Gaussian{Vector{Float64}, Matrix{Float64}}

        particle, residual, S = GaussianDistributions.correct(
            particle,
            Gaussian([obs], R), B
        )

        log_evidence += GaussianDistributions.logpdf(
            Gaussian(zero(residual), Symmetric(S)),
            residual
        )
    end

In other words, the particle variable is type unstable: on the first iteration it is a Gaussian with a Diagonal covariance, while on all subsequent iterations it is a Gaussian with a dense Matrix covariance.

So the error message from Enzyme is warranted (and fixing the instability will likely speed up your code).
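
For anyone following along, a quick way to confirm this without Enzyme is to run type inference on the filter directly (a sketch, assuming the MWE above has been evaluated); the particle variable shows up as non-concrete in the output:

using InteractiveUtils  # for @code_warntype

model = build_model([1.0])
init  = Gaussian(zeros(1), Diagonal(ones(1)))

# the loop body is flagged because `particle` changes type across iterations
@code_warntype kalman_filter(model, init, data)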

Good spot with the type instability; unfortunately, it hasn't solved the issue. I updated the initialization as follows:

function logℓ(θ::Vector{T}, data) where {T<:Real}
    model = build_model(θ)
    init_state = Gaussian(zeros(T, 1), Matrix(Diagonal(ones(T, 1))))
    return kalman_filter(model, init_state, data)
end

This fixes the type instability according to @code_warntype (I neglected to check this earlier); however, the same error message appears.

I think I’ve figured it out, although I’m not entirely positive I understand the solution.

After digging further into type inference with Cthulhu.jl, I noticed an instability in the following block:

particle, residual, S = GaussianDistributions.correct(
    particle, Gaussian([obs], R), B
)

I gather this makes a type-unstable call to Base.indexed_iterate, which is what raises the error mentioned above, although I cannot say I fully understand why the stack trace points to LinearAlgebra.triu!(...) when differentiating via Enzyme.jl.
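
For anyone curious where Base.indexed_iterate comes from: destructuring a tuple return lowers to something roughly like the following, so an abstractly typed return from correct(...) makes each of these calls unstable as well:

# roughly what `particle, residual, S = GaussianDistributions.correct(...)` lowers to
ret = GaussianDistributions.correct(particle, Gaussian([obs], R), B)
particle, st = Base.indexed_iterate(ret, 1)
residual, st = Base.indexed_iterate(ret, 2, st)
S, _         = Base.indexed_iterate(ret, 3, st)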

Additionally, the program raises another error stemming from GaussianDistributions.jl's log-likelihood computation. This can be rectified by using Distributions.jl instead and converting the Gaussian to an MvNormal.

My updated filter is as follows:

# NOTE: requires `using Distributions` for MvNormal and logpdf below
function kalman_filter(
        model::LinearGaussianModel,
        init_state::Gaussian,
        observations::Vector{T}
    ) where {T<:Real}
    log_evidence = zero(T)
    state = init_state

    A = model.transition.ϕ
    Q = model.transition.Σ

    B = model.observation.ϕ
    R = model.observation.Σ

    for obs in observations
        proposed = Gaussian(A*state.μ, A*state.Σ*A' + Q)

        # internals of GaussianDistributions.correct(...)
        residual = [obs] - B*proposed.μ
        S = Symmetric(B*proposed.Σ*B' + R)
        gain = proposed.Σ*B' / S

        # apply correction
        state = Gaussian(
            proposed.μ + gain*residual,
            (I - gain*B)*proposed.Σ*(I - gain*B)' + gain*R*gain'
        )

        # use a different calculation for the logpdf
        log_evidence += Distributions.logpdf(
            MvNormal(zero(residual), S),
            residual
        )
    end

    return log_evidence
end

When I did the calculations myself previously, I computed the Kalman gain using the Cholesky-decomposed innovation covariance. While this is a more numerically stable computation, Enzyme.jl struggled to differentiate the right division by a Cholesky factorization. @wsmoses, if this is useful, I would be more than happy to put together an MWE and open an issue.
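
For reference, the gain computation that tripped up Enzyme looked roughly like this (a sketch from memory, not my exact code):

# factorize the innovation covariance once ...
S_chol = cholesky(Symmetric(B*proposed.Σ*B' + R))

# ... and reuse it for the gain via right division, which Enzyme struggled to differentiate
gain = (proposed.Σ*B') / S_chol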
