New to Julia - questions about optimizing my Turing code

Hi everyone, nice to meet you

I’m trying to write a custom multivariate model in Julia using banded matrix operations, starting with the simple example of a Gaussian random walk with IID innovations and IID observation noise.

I tried to optimize my code as much as I could, and reached the point where I can’t think of anything else, but I’m sure there’s a ton of things I got wrong, or don’t know about, etc.

Here’s my code below — I’d love to hear your suggestions about anything I’m doing wrong, or things I didn’t do and should, etc.!

# INITIALIZATION

using Pkg
dependencies = ["LinearAlgebra", "BandedMatrices", "Turing", "Distributions", "Random", "StatsPlots","LazyArrays","DynamicPPL","InteractiveUtils","Bijectors", "Enzyme", "ChainRulesCore","ReverseDiff", "Memoization", "Zygote","ChainRulesTestUtils","PreallocationTools"]
# Pkg.dependencies() returns a Dict keyed by UUID, not by package name, so
# `haskey(Pkg.dependencies(), name)` is always false and would re-add every
# package on every run. Build a set of installed package *names* instead.
installed = Set(info.name for (_, info) in Pkg.dependencies())
for pkg in dependencies
    if !(pkg in installed)
        Pkg.add(pkg)
    end
end

using LinearAlgebra, BandedMatrices
using Turing, Distributions, Random, StatsPlots, LazyArrays
using Bijectors, ChainRulesCore#, Enzyme
#using DynamicPPL, InteractiveUtils
using ReverseDiff, Memoization
#using ChainRulesTestUtils
using PreallocationTools
#using Zygote

Random.seed!(1234)

# DEFINING CUSTOM DISTRIBUTION STRUCT
"""
    jointGRW

Custom multivariate distribution for a Gaussian random walk with IID
innovations.

# Fields
- `n`: number of data points in the Gaussian random walk
- `σ`: standard deviation of the IID innovations
- `L`: banded transformation matrix from states to innovations
- `cch`: preallocated cache for storing the innovations vector

The matrix field is fully parametrized (`TL<:AbstractMatrix`) instead of
`BandedMatrix{T}`: `BandedMatrix` has additional type parameters, so
`BandedMatrix{T}` alone is an abstract (UnionAll) field type that prevents
field-access specialization.
"""
struct jointGRW{Tσ<:Real,TL<:AbstractMatrix,C} <: ContinuousMultivariateDistribution
    n::Int
    σ::Tσ
    L::TL
    cch::C
end

"""
    getL(n::Int)

Build the constant `n×n` bidiagonal transformation matrix mapping a states
vector to its innovations (1 on the diagonal, -1 on the first subdiagonal).
Memoized so repeated calls with the same `n` reuse a single instance without
resorting to global variables.
"""
@memoize function getL(n::Int)
    main_diag = ones(n)
    sub_diag = -ones(n - 1)
    return BandedMatrix{Float64}(-1 => sub_diag, 0 => main_diag)
end

"""
    getCache(n::Int)

Create a length-`n` `DiffCache` used to hold the innovations vector during
log-density evaluation. Memoized so every `jointGRW(n, σ)` with the same `n`
shares one cache instead of allocating a fresh buffer.
"""
@memoize function getCache(n::Int)
    buffer = zeros(n)
    return DiffCache(buffer)
end

"""
    jointGRW(n::Int, σ::Real)

Outer constructor for [`jointGRW`](@ref): fetches the memoized transformation
matrix and cache for size `n`.

# Throws
- `ArgumentError` if `n` is not positive or `σ` is negative.
"""
function jointGRW(n::Int, σ::Real)
    # Validate at the API boundary with real errors rather than @assert
    # (assertions may be disabled at higher optimization levels).
    n > 0 || throw(ArgumentError("n must be positive, got $n"))
    σ ≥ 0 || throw(ArgumentError("σ must be non-negative, got $σ"))
    return jointGRW(n, σ, getL(n), getCache(n))
end

Base.length(d::jointGRW) = d.n # length of states/innovations vector

"""
    bvmul!(d::jointGRW, x::AbstractVector)

Compute `d.L * x` in place, writing the result into the struct's cache to
avoid per-call allocations. Wrapped as a named function so a custom
`ChainRulesCore.rrule` can target it.

Returns the cache buffer holding the product.
"""
function bvmul!(d::jointGRW, x::AbstractVector)
    # get_tmp re-types the cache buffer to match x (plain Float64 vs. tracked/dual)
    buf = get_tmp(d.cch, x)
    mul!(buf, d.L, x)
    return buf
end

"""
    ChainRulesCore.rrule(::typeof(bvmul!), d::jointGRW, x::AbstractVector)

Custom reverse rule for `bvmul!`. Only the cotangent with respect to the
states vector `x` is needed; the function itself and the distribution struct
receive `NoTangent()`.
"""
function ChainRulesCore.rrule(::typeof(bvmul!), d::jointGRW, x::AbstractVector)
    y = bvmul!(d, x)
    function bvmul_pullback(Δy)
        # Pullback of y = L*x is Δx = L'Δy; thunked so it is only evaluated
        # if the cotangent is actually requested.
        return NoTangent(), NoTangent(), @thunk(d.L' * Δy)
    end
    return y, bvmul_pullback
end

# Registering the rrule with ReverseDiff: makes ReverseDiff route tracked-array
# calls of bvmul! through the ChainRulesCore rrule defined above.
ReverseDiff.@grad_from_chainrules bvmul!(d::jointGRW, x::ReverseDiff.TrackedArray)

"""
    Distributions._logpdf(d::jointGRW, x::AbstractVector)

Joint log-density of the states vector `x`. The states are linearly
transformed to their innovations via the banded product `bvmul!`; since the
transformation has unit determinant, the density is that of `n` IID
`Normal(0, σ)` innovations.
"""
function Distributions._logpdf(d::jointGRW, x::AbstractVector)
    # innovations = L * x, computed in place into the shared cache
    ε = bvmul!(d, x)
    ssq = sum(abs2, ε)
    # Gaussian log-density: -n log σ - ||ε||² / (2σ²) - (n/2) log 2π
    return -d.n * log(d.σ) - 0.5 * ssq / (d.σ^2) - (d.n / 2) * log(2π)
end




Distributions.insupport(d::jointGRW, x::AbstractVector) = length(x) == d.n

"""
    Distributions.rand(rng::AbstractRNG, d::jointGRW)

Draw one sample path from the Gaussian random walk: generate `n` IID
`Normal(0, σ)` innovations and accumulate them.
"""
function Distributions.rand(rng::AbstractRNG, d::jointGRW)
    innovations = randn(rng, d.n) .* d.σ
    return cumsum(innovations)
end



# Defining bijectors - distribution already unconstrained, so the link between
# the model space and the unconstrained sampling space is the identity.
# NOTE(review): Bijectors.VectorBijectors is a non-released/branch API —
# confirm it exists in the installed Bijectors version.
Bijectors.VectorBijectors.linked_vec_length(d::jointGRW) = d.n

Bijectors.VectorBijectors.to_linked_vec(::jointGRW) = Bijectors.VectorBijectors.TypedIdentity()
Bijectors.VectorBijectors.from_linked_vec(::jointGRW) = Bijectors.VectorBijectors.TypedIdentity()


"""
    jointGRWmodel(y, n)

Turing model: Exponential priors on the innovation and observation noise
standard deviations, latent states distributed according to the Gaussian
random walk defined by `jointGRW`, and observations equal to the states plus
IID Gaussian noise.
"""
@model function jointGRWmodel(y, n)
    σ_s ~ Exponential(1)        # prior: innovation noise std
    σ_o ~ Exponential(1)        # prior: observation noise std
    u ~ jointGRW(n, σ_s)        # latent random-walk states
    # Isotropic observation noise. MvNormal(μ, σ::Real) (mean + scalar std)
    # is deprecated in Distributions; pass the covariance σ_o² I explicitly.
    y ~ MvNormal(u, σ_o^2 * I)
end

"""
    simulate_grw(n::Int, σ_o, σ_s; rng=Random.default_rng())

Simulate `n` noisy observations of a Gaussian random walk: cumulative sum of
IID `Normal(0, σ_s)` innovations plus IID `Normal(0, σ_o)` observation noise.

# Keywords
- `rng`: random number generator, for reproducible simulations (backward
  compatible — defaults to the global RNG).
"""
function simulate_grw(n::Int, σ_o, σ_s; rng::AbstractRNG=Random.default_rng())
    states = cumsum(randn(rng, n) .* σ_s)
    return states .+ randn(rng, n) .* σ_o
end

# --- Run: simulate data, condition the model, and sample ---
n = 1000
y = simulate_grw(n,1.0,1.0)
model = jointGRWmodel(y,n)


# NUTS with 1000 adaptation steps, target acceptance 0.9, compiled
# ReverseDiff tape for gradients; 4 chains of 10000 draws across threads.
@time chain = sample(model, NUTS(1000, 0.9, adtype=AutoReverseDiff(compile=true)), MCMCThreads(), 10000, 4, progress=true)

# Show the noise-scale posteriors; readline() keeps the GUI window open.
gui(plot(chain[["σ_s","σ_o"]]))
readline()