Hi everyone, nice to meet you
I’m trying to write a custom multivariate model in Julia using banded matrix operations, starting with the simple example of a Gaussian random walk with IID innovations and IID observation noise.
I tried to optimize my code as much as I could, and I’ve reached the point where I can’t think of anything else to improve — but I’m sure there are plenty of things I got wrong or don’t know about.
Here’s my code below, I’d love to hear your suggestions about things I’m doing wrong, or things I didn’t do and should, etc!
# INITIALIZATION
using Pkg
dependencies = ["LinearAlgebra", "BandedMatrices", "Turing", "Distributions", "Random", "StatsPlots","LazyArrays","DynamicPPL","InteractiveUtils","Bijectors", "Enzyme", "ChainRulesCore","ReverseDiff", "Memoization", "Zygote","ChainRulesTestUtils","PreallocationTools"]
# Install any dependency missing from the active project.
# BUG FIX: `Pkg.dependencies()` returns a Dict keyed by UUID, not by package
# name, so `haskey(Pkg.dependencies(), name)` was always false and re-added
# every package on every run. `Pkg.project().dependencies` maps name => UUID.
installed = Pkg.project().dependencies
for pkg in dependencies
    if !haskey(installed, pkg)
        Pkg.add(pkg)
    end
end
using LinearAlgebra, BandedMatrices
using Turing, Distributions, Random, StatsPlots, LazyArrays
using Bijectors, ChainRulesCore#, Enzyme
#using DynamicPPL, InteractiveUtils
using ReverseDiff, Memoization
#using ChainRulesTestUtils
using PreallocationTools
#using Zygote
# Fix the global RNG seed so data simulation and sampling are reproducible.
Random.seed!(1234)
# DEFINING CUSTOM DISTRIBUTION STRUCT
"""
    jointGRW

Joint distribution of a Gaussian random walk with IID Gaussian innovations.

# Fields
- `n`: number of data points in the random walk
- `σ`: standard deviation of the IID innovations
- `L`: banded transformation matrix mapping the state vector to its innovations
- `cch`: preallocated cache (`DiffCache`) used to store the innovations vector

Note: the docstring previously sat *inside* the struct body, where Julia
attaches it to the first field rather than the type. The matrix field is now
fully parametric (`BandedMatrix{TL}` alone is an abstract type, which would
force dynamic dispatch on every access to `L`).
"""
struct jointGRW{Tσ<:Real, TL<:Real, TM<:AbstractMatrix{TL}, C} <: ContinuousMultivariateDistribution
    n::Int
    σ::Tσ
    L::TM
    cch::C
end
"""
    getL(n::Int)

Return the `n×n` banded transformation matrix from states to innovations
(1 on the diagonal, -1 on the first subdiagonal, i.e. a first-difference
operator). Memoized so the matrix is constructed only once per `n`, avoiding
global variables.

Note: the description string previously sat inside the function body, where it
is a discarded runtime literal rather than a docstring.
"""
@memoize function getL(n::Int)
    return BandedMatrix{Float64}(-1 => -ones(n - 1), 0 => ones(n))
end
"""
    getCache(n::Int)

Return a length-`n` `DiffCache` used to store the innovations vector in-place.
Memoized so a single cache is shared for each `n`, avoiding global variables.

Note: the description string previously sat inside the function body, where it
is a discarded runtime literal rather than a docstring.
"""
@memoize function getCache(n::Int)
    return DiffCache(zeros(n))
end
"""
    jointGRW(n::Int, σ::Real)

Outer constructor for [`jointGRW`](@ref): builds the distribution for a walk of
length `n` with innovation standard deviation `σ`, reusing the memoized
transformation matrix and cache.

Throws `ArgumentError` if `n` is not positive or `σ` is negative (the original
validation was a commented-out `@assert`; `@assert` can be compiled away, so an
explicit throw is used instead).
"""
function jointGRW(n::Int, σ::Real)
    n > 0 || throw(ArgumentError("n must be positive, got $n"))
    σ ≥ 0 || throw(ArgumentError("σ must be non-negative, got $σ"))
    return jointGRW(n, σ, getL(n), getCache(n))
end
# Length of the state/innovation vector described by the distribution.
function Base.length(d::jointGRW)
    return d.n
end
"""
    bvmul!(d::jointGRW, x::AbstractVector)

In-place banded matrix–vector product `d.L * x`, written into the cache stored
in `d` to reduce allocations. Wrapped as its own function so a custom `rrule`
can be attached. Returns the cache buffer itself (aliased — callers must
consume the result before the next call with the same `d`).

Note: the description string previously sat inside the function body, where it
is a discarded runtime literal rather than a docstring.
"""
function bvmul!(d::jointGRW, x::AbstractVector)
    # Fetch a buffer whose element type matches x (plain Float64, or a
    # dual/tracked type during AD) from the PreallocationTools cache.
    tmp = get_tmp(d.cch, x)
    # In-place banded matrix multiplication.
    return mul!(tmp, d.L, x)
end
"""
    ChainRulesCore.rrule(::typeof(bvmul!), d::jointGRW, x::AbstractVector)

Custom reverse rule for `bvmul!`. The primal is the linear map `L * x`, so the
only non-trivial cotangent is `L' * Δy` with respect to `x`; the function and
the distribution get `NoTangent()`.
"""
function ChainRulesCore.rrule(::typeof(bvmul!), d::jointGRW, x::AbstractVector)
    y = bvmul!(d, x)
    function bvmul_pullback(Δy)
        # Thunked so the transpose product is only evaluated if the cotangent
        # is actually requested.
        return NoTangent(), NoTangent(), @thunk(d.L' * Δy)
    end
    return y, bvmul_pullback
end
# Registering the rrule with ReverseDiff:
# ReverseDiff does not pick up ChainRules rrules automatically; this macro
# generates the TrackedArray method of bvmul! that forwards to the rrule above.
ReverseDiff.@grad_from_chainrules bvmul!(d::jointGRW, x::ReverseDiff.TrackedArray)
"""
    Distributions._logpdf(d::jointGRW, x::AbstractVector)

Joint log-density of the state vector `x`: transform `x` into its innovations
via the banded product `bvmul!`, then score the innovations as IID
`Normal(0, σ)` variables:

    -n·log σ − ‖Lx‖² / (2σ²) − (n/2)·log 2π

Note: the description string previously sat inside the function body, where it
is a discarded runtime literal rather than a docstring.
"""
function Distributions._logpdf(d::jointGRW, x::AbstractVector)
    # Squared norm of the innovations, computed without materializing abs2.(·).
    z2 = sum(abs2, bvmul!(d, x))
    return -d.n * log(d.σ) - 0.5 * z2 / (d.σ^2) - (d.n / 2) * log(2π)
end
# A vector is in the support iff its length matches the walk length; the
# density itself is supported on all of R^n.
function Distributions.insupport(d::jointGRW, x::AbstractVector)
    return length(x) == d.n
end
"""
    Distributions.rand(rng::AbstractRNG, d::jointGRW)

Draw one sample path from the Gaussian random walk: IID `Normal(0, σ)`
innovations accumulated with `cumsum`.

Note: the description string previously sat inside the function body, where it
is a discarded runtime literal rather than a docstring.
"""
function Distributions.rand(rng::AbstractRNG, d::jointGRW)
    return cumsum(randn(rng, d.n) .* d.σ)
end
# Bijector hookup: the distribution already lives on an unconstrained space,
# so the map between model space and linked (unconstrained) space is the
# identity and the linked vector has the same length as the distribution.
# NOTE(review): `Bijectors.VectorBijectors` looks like an internal/experimental
# API — confirm it exists in the pinned Bijectors version.
function Bijectors.VectorBijectors.linked_vec_length(d::jointGRW)
    return d.n
end
function Bijectors.VectorBijectors.to_linked_vec(::jointGRW)
    return Bijectors.VectorBijectors.TypedIdentity()
end
function Bijectors.VectorBijectors.from_linked_vec(::jointGRW)
    return Bijectors.VectorBijectors.TypedIdentity()
end
"""
    jointGRWmodel(y, n)

Turing model: Exponential(1) priors on the innovation (`σ_s`) and observation
(`σ_o`) noise standard deviations; latent states `u` follow the `jointGRW`
prior; observed data `y` are the states plus IID Gaussian noise.

Note: the description string previously sat inside the `@model` body, where it
would be re-evaluated on every model execution rather than act as a docstring.
"""
@model function jointGRWmodel(y, n)
    σ_s ~ Exponential(1)
    σ_o ~ Exponential(1)
    u ~ jointGRW(n, σ_s)
    # `MvNormal(μ, σ::Real)` is the deprecated scalar-std constructor in
    # Distributions; spell out the isotropic covariance σ_o²·I instead
    # (`I` comes from LinearAlgebra, already loaded at the top of the file).
    y ~ MvNormal(u, σ_o^2 * I)
end
"""
    simulate_grw(n::Int, σ_o, σ_s; rng=Random.default_rng())

Simulate `n` observations of a Gaussian random walk: innovations with standard
deviation `σ_s` accumulated by `cumsum`, observed under IID Gaussian noise
with standard deviation `σ_o`. Pass an explicit `rng` for reproducible draws
(backward-compatible addition; the default is the task-global RNG, matching
the original behavior).
"""
function simulate_grw(n::Int, σ_o, σ_s; rng::AbstractRNG = Random.default_rng())
    states = cumsum(randn(rng, n) .* σ_s)
    return states .+ randn(rng, n) .* σ_o
end
"""
    run_grw_demo(; n=1000, σ_o=1.0, σ_s=1.0, nsamples=10000, nchains=4)

Simulate `n` observations of the Gaussian random walk, run NUTS over
`nchains` threads with a compiled ReverseDiff tape, plot the posteriors of the
two noise scales, and return the chain. Wrapping the driver in a function
keeps the hot code out of non-const global scope and turns the previously
hard-coded sizes into parameters.
"""
function run_grw_demo(; n::Int = 1000, σ_o = 1.0, σ_s = 1.0,
                      nsamples::Int = 10000, nchains::Int = 4)
    y = simulate_grw(n, σ_o, σ_s)
    model = jointGRWmodel(y, n)
    # NUTS(adaptation steps, target acceptance rate); a compiled ReverseDiff
    # tape is valid here because the model structure is fixed across calls.
    @time chain = sample(model, NUTS(1000, 0.9, adtype = AutoReverseDiff(compile = true)),
                         MCMCThreads(), nsamples, nchains, progress = true)
    gui(plot(chain[["σ_s", "σ_o"]]))
    return chain
end
chain = run_grw_demo()
readline()  # keep the plot window open until the user presses Enter