Consider the simple problem of maximum likelihood estimation on a linear Gaussian state space model, where the Kalman filter computes the log-likelihood. With a properly designed implementation, differentiating this log-likelihood should be trivial for automatic differentiation; and for most backends, it is.
Although my code seems type stable, Enzyme.jl refuses to compute the gradient with strict aliasing turned on. This is compounded by BLAS warnings indicating further performance degradation. Meanwhile, ForwardDiff.jl and Zygote.jl produce gradients without a hitch.
Without setting `Enzyme.API.strictAliasing!(false)`, I get the following error:
ERROR: Enzyme compilation failed due to illegal type analysis.
This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].
Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Caused by:
Stacktrace:
[1] triu!
@ ~/.julia/juliaup/julia-1.11.0+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:441
[2] getproperty
@ ~/.julia/juliaup/julia-1.11.0+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/qr.jl:497
Stacktrace:
[1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/errors.jl:319
[2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/errors.jl:210
[3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/DL6fx/src/api.jl:268
[4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:1690
[5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:4534
[6] codegen
@ ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:3337 [inlined]
[7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5386
[8] _thunk
@ ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5386 [inlined]
[9] cached_compilation
@ ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5438 [inlined]
[10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5549
[11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/DL6fx/src/compiler.jl:5734
[12] autodiff
@ ~/.julia/packages/Enzyme/DL6fx/src/Enzyme.jl:485 [inlined]
[13] autodiff
@ ~/.julia/packages/Enzyme/DL6fx/src/Enzyme.jl:524 [inlined]
[14] macro expansion
@ ~/.julia/packages/Enzyme/DL6fx/src/sugar.jl:275 [inlined]
[15] gradient(rm::ReverseMode{…}, f::typeof(logℓ), x::Vector{…}, args::Const{…})
@ Enzyme ~/.julia/packages/Enzyme/DL6fx/src/sugar.jl:263
[16] top-level scope
@ ~/code/state_space_models/SSMProblems-AD/examples/data-generating-process/mwe.jl:97
Some type information was truncated. Use `show(err)` to see complete types.
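Following the error message's suggestion, the only way I can get a gradient out of Enzyme is to disable strict aliasing before the autodiff call (using the logℓ and data defined in the MWE below):

using Enzyme

# workaround suggested by the error message: disable strict
# aliasing globally before any autodiff call
Enzyme.API.strictAliasing!(false)

gradient(Reverse, logℓ, [1.0], Const(data))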
With strict aliasing turned off, I get correct gradients, but at a snail's pace. This may be exacerbated by the following BLAS warning, which appears in both instances:
┌ Warning: Using fallback BLAS replacements for (["dasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/GnbhK/src/utils.jl:59
That said, the warning seems harmless, since I'm not running anything in parallel.
For replication, my (not-so) minimal working example is as follows:
using LinearAlgebra
using Enzyme
using GaussianDistributions
using Random
## MODEL DEFINITION ###########################################################
struct LinearGaussianProcess{
    T<:Real,
    ΦT<:AbstractMatrix{T},
    ΣT<:AbstractMatrix{T}
}
    ϕ::ΦT
    Σ::ΣT
    function LinearGaussianProcess(ϕ::ΦT, Σ::ΣT) where {
        T<:Real,
        ΦT<:AbstractMatrix{T},
        ΣT<:AbstractMatrix{T}
    }
        @assert size(ϕ, 1) == size(Σ, 1) == size(Σ, 2)
        return new{T, ΦT, ΣT}(ϕ, Σ)
    end
end
struct LinearGaussianModel{
    ΘT<:Real,
    TT<:LinearGaussianProcess{ΘT},
    OT<:LinearGaussianProcess{ΘT}
}
    transition::TT
    observation::OT
    dims::Tuple{Int, Int}
end
## KALMAN FILTER ##############################################################
function kalman_filter(
    model::LinearGaussianModel,
    init_state::Gaussian,
    observations::Vector{T}
) where {T<:Real}
    log_evidence = zero(T)
    particle = init_state

    A = model.transition.ϕ
    Q = model.transition.Σ
    B = model.observation.ϕ
    R = model.observation.Σ

    for obs in observations
        # predict: push the state estimate through the transition model
        particle = let μ = particle.μ, Σ = particle.Σ
            Gaussian(A*μ, A*Σ*A' + Q)
        end

        # correct: condition on the new observation
        particle, residual, S = GaussianDistributions.correct(
            particle,
            Gaussian([obs], R), B
        )

        # accumulate the one-step predictive log-density
        log_evidence += GaussianDistributions.logpdf(
            Gaussian(zero(residual), Symmetric(S)),
            residual
        )
    end

    return log_evidence
end
## DEMONSTRATION ##############################################################
# θ should be a single element vector for this demonstration
function build_model(θ::Vector{T}) where {T}
    trans = LinearGaussianProcess(T[1;;], Diagonal(θ))
    obs = LinearGaussianProcess(T[1;;], Diagonal(T[1]))
    return LinearGaussianModel(trans, obs, (1, 1))
end
# log likelihood function
function logℓ(θ::Vector{T}, data) where {T<:Real}
    model = build_model(θ)
    init_state = Gaussian(zeros(T, 1), Diagonal(ones(T, 1)))
    return kalman_filter(model, init_state, data)
end
# data generation (with unit covariance)
rng = MersenneTwister(1234)
data = cumsum(randn(rng, 100)) .+ randn(rng, 100)
# ensure that log likelihood looks stable
logℓ([1.0], data)
# calculate gradients
gradient(Reverse, logℓ, [1.0], Const(data))
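For comparison, this is roughly how I check the other backends (ForwardDiff.jl and Zygote.jl are not loaded in the MWE above, so treat this as a sketch):

using ForwardDiff
using Zygote

# both backends differentiate the same closure without complaint
ForwardDiff.gradient(θ -> logℓ(θ, data), [1.0])
Zygote.gradient(θ -> logℓ(θ, data), [1.0])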
I use GaussianDistributions.jl for the sake of brevity, but the same issues are present when I do the algebra myself.
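For concreteness, "doing the algebra myself" means replacing the correct/logpdf calls with the standard predict/correct recursion, roughly as below (a sketch reusing the types and imports from the MWE above; kalman_filter_manual is not part of the MWE itself):

# hand-rolled variant: standard Kalman predict/correct recursion
# with an explicit log-evidence increment
function kalman_filter_manual(
    model::LinearGaussianModel,
    init_state::Gaussian,
    observations::Vector{T}
) where {T<:Real}
    A = model.transition.ϕ
    Q = model.transition.Σ
    B = model.observation.ϕ
    R = model.observation.Σ

    μ, Σ = init_state.μ, init_state.Σ
    log_evidence = zero(T)

    for obs in observations
        # predict
        μ = A * μ
        Σ = A * Σ * A' + Q

        # innovation and its covariance
        r = [obs] - B * μ
        S = B * Σ * B' + R

        # Kalman gain and correction
        K = (Σ * B') / S
        μ = μ + K * r
        Σ = (I - K * B) * Σ

        # one-step predictive log-density of N(0, S) at r
        log_evidence -= (length(r) * log(2π) + logdet(S) + r' * (S \ r)) / 2
    end

    return log_evidence
end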
Any help is appreciated. I am baffled.