New to Julia - questions about optimizing my Turing code

Hi everyone, nice to meet you

I’m trying to write a custom multivariate model in Julia using banded matrix operations, starting with the simple example of a Gaussian random walk with IID innovations and IID observation noise.
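For reference, the model described above can be written out as follows. The symbols $x_t$, $\varepsilon_t$, $\eta_t$ and the convention $x_0 = 0$ are notation assumed here for illustration; $L$ is the lower-bidiagonal matrix with ones on the diagonal and $-1$ on the first subdiagonal, so $\det L = 1$ and the change of variables adds no Jacobian term.

```latex
\begin{aligned}
x_t &= x_{t-1} + \varepsilon_t, & \varepsilon_t &\sim \mathcal{N}(0, \sigma_s^2), \quad t = 1, \dots, n, \quad x_0 = 0, \\
y_t &= x_t + \eta_t, & \eta_t &\sim \mathcal{N}(0, \sigma_o^2), \\
\log p(x \mid \sigma_s) &= -n \log \sigma_s - \frac{1}{2\sigma_s^2} \lVert L x \rVert^2 - \frac{n}{2} \log 2\pi .
\end{aligned}
```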

I tried to optimize my code as much as I could, and reached the point where I can’t think of anything else, but I’m sure there’s a ton of things I got wrong, or don’t know about, etc.

Here’s my code below, I’d love to hear your suggestions about things I’m doing wrong, or things I didn’t do and should, etc!

# INITIALIZATION

using Pkg
dependencies = ["LinearAlgebra", "BandedMatrices", "Turing", "Distributions", "Random", "StatsPlots","LazyArrays","DynamicPPL","InteractiveUtils","Bijectors", "Enzyme", "ChainRulesCore","ReverseDiff", "Memoization", "Zygote","ChainRulesTestUtils","PreallocationTools"]
for pkg in dependencies
    if !haskey(Pkg.dependencies(), pkg)
        Pkg.add(pkg)
    end
end

using LinearAlgebra, BandedMatrices
using Turing, Distributions, Random, StatsPlots, LazyArrays
using Bijectors, ChainRulesCore#, Enzyme
#using DynamicPPL, InteractiveUtils
using ReverseDiff, Memoization
#using ChainRulesTestUtils
using PreallocationTools
#using Zygote

Random.seed!(1234)

# DEFINING CUSTOM DISTRIBUTION STRUCT
struct jointGRW{Tσ<:Real, TL<:Real,C} <: ContinuousMultivariateDistribution
    """
    Struct containing:
    n - number of data points in the Gaussian random walk
    σ - standard deviation of IID innovations
    L - transformation matrix from states to innovations (banded)
    cch - cache for storing the innovations vector
    """
    n::Int
    σ::Tσ
    L::BandedMatrix{TL}
    cch::C
end

@memoize function getL(n::Int)
    """
    Compiled function to get the constant transformation matrix jointGRW.L without use of global variables
    """
    return BandedMatrix{Float64}(-1 => -ones(n-1), 0 => ones(n))
end

@memoize function getCache(n::Int)
    """
    Compiled function to get the innovations-storing cache variable jointGRW.cch without use of global variables
    """
    return DiffCache(zeros(n))
end

function jointGRW(n::Int, σ::T) where T<:Real
    """
    Constructor for the jointGRW struct
    """
    #@assert σ ≥ 0 && n > 0 "Parameters must be positive"
    return jointGRW(n, σ ,getL(n),getCache(n))
end

Base.length(d::jointGRW) = d.n # length of states/innovations vector

function bvmul!(d::jointGRW,x::AbstractVector)
    """
    In-place multiplication of a banded matrix by a vector

    Using the cache of a jointGRW struct to reduce allocations
    
    Wrapped for purposes of defining appropriate rrule
    """
    # Get a type-appropriate version of the cache (Float64 or Tracked)
    tmp = get_tmp(d.cch, x)
    # In-place banded matrix multiplication
    return mul!(tmp, d.L, x)
end

function ChainRulesCore.rrule(::typeof(bvmul!),d::jointGRW,x::AbstractVector)
    """
    rrule for bvmul! operation
    
    only need to get tangent in respect to the states vector x
    """
    y = bvmul!(d,x)
    function bvmul_pullback(Δy)
        #Δx = @thunk(d.L' * Δy)
        #Δx = get_tmp(d.cch, Δy)
        #mul!(Δx, d.L', Δy)

        return NoTangent(), NoTangent(), @thunk(d.L' * Δy)
    end
    return y, bvmul_pullback
end

# Registering the rrule with ReverseDiff:
ReverseDiff.@grad_from_chainrules bvmul!(d::jointGRW, x::ReverseDiff.TrackedArray)

function Distributions._logpdf(d::jointGRW, x::AbstractVector)
    """
    The joint log-probability of the states vector, calculated via linear transformation of the states vector to its corresponding innovations vector - banded matrix-vector product implemented via bvmul!

Transformed vector is then treated as a vector of IID 1D Gaussians
    """
    return -d.n * log(d.σ) - 0.5 * sum(abs2, bvmul!(d,x)) / (d.σ^2) - (d.n/2) * log(2π)
end




Distributions.insupport(d::jointGRW, x::AbstractVector) = length(x) == d.n

function Distributions.rand(rng::AbstractRNG, d::jointGRW)
    """
    Obtain samples from the GRW with known parameters
    """
    return cumsum(randn(rng, d.n) .* d.σ)
end



# Defining bijectors - distribution already unconstrained
Bijectors.VectorBijectors.linked_vec_length(d::jointGRW) = d.n

Bijectors.VectorBijectors.to_linked_vec(::jointGRW) = Bijectors.VectorBijectors.TypedIdentity()
Bijectors.VectorBijectors.from_linked_vec(::jointGRW) = Bijectors.VectorBijectors.TypedIdentity()


@model function jointGRWmodel(y,n)
    """
    The Turing model
    Exponential priors on innovation and observation noise standard deviations, states vector distributed according to the Gaussian random walk defined by jointGRW, observed data = states + IID Gaussian noise
    """
    σ_s ~ Exponential(1)
    σ_o ~ Exponential(1)
    u ~ jointGRW(n,σ_s)
    y ~ MvNormal(u,σ_o)
end

function simulate_grw(n::Int, σ_o,σ_s)
    """
    Simulating data
    """
    return cumsum(randn(n)*σ_s) + randn(n)*σ_o
end

n = 1000
y = simulate_grw(n,1.0,1.0)
model = jointGRWmodel(y,n)


@time chain = sample(model, NUTS(1000, 0.9, adtype=AutoReverseDiff(compile=true)), MCMCThreads(), 10000, 4, progress=true)

gui(plot(chain[["σ_s","σ_o"]]))
readline()

Before even scrolling through the code: you may want to save Project.toml and Manifest.toml files instead of running that Pkg.add loop. I can’t even imagine where you got that loop; it relies on Pkg.dependencies(), which is experimental and unstable, and a list of 17 bare package names (no versions, for one) is not precise enough for practical reproducibility.
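A minimal sketch of the environment-based setup, assuming the script lives in its own folder. A temporary directory stands in for that folder here, and the package list in the comment is only illustrative:

```julia
using Pkg

# Stand-in for the folder that holds the script; in a real project this
# would typically be `@__DIR__`.
project_dir = mktempdir()

# Activate an environment local to that folder instead of the global one.
Pkg.activate(project_dir)

# One-time setup; the resulting Project.toml and Manifest.toml are then kept
# next to the script (commented out here because it needs network access):
# Pkg.add(["Turing", "BandedMatrices", "ReverseDiff"])

# Later, or on another machine, this installs exactly the versions
# recorded in Manifest.toml:
# Pkg.instantiate()

# The active project file is now the one inside `project_dir`.
project_file = Base.active_project()
```

With the two files saved alongside the script, the script itself only needs the `Pkg.activate` and `Pkg.instantiate` calls, with no loop over package names.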

I can’t comment on the technical stuff (not my area of expertise) but here are some general tips:

  • In Julia, docstrings go before the function definition. You put them as the first statement inside the function body, where they are just unused string literals and do not document anything.
  • I am sceptical of your two memoizations. Caching the BandedMatrix is probably fine but not really necessary, whereas caching the DiffCache is potentially dangerous: if you have many of these objects and work on them in parallel threads, they would all write to the same cache, likely producing nonsense, and it would be very hard to find out why. If I were you, I would not memoize anything here. It is likely not worth it, and may even be harmful.
  • Struct names usually start with an uppercase character.
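As a small illustration of the docstring and naming conventions, here is a sketch using only Base Julia. The names `JointGRW` and `innovations` are hypothetical, and the struct is reduced to two fields to keep it short:

```julia
"""
    JointGRW(n, σ)

Gaussian random walk with `n` states and IID innovations of standard deviation `σ`.
"""
struct JointGRW{T<:Real}
    n::Int
    σ::T
end

"""
    innovations(x)

Map a state vector `x` to its innovations: `x[1]` followed by the first differences.
"""
innovations(x::AbstractVector) = vcat(x[1], diff(x))

d = JointGRW(3, 0.5)
ε = innovations([1.0, 3.0, 6.0])
```

Written this way, the strings are attached to `JointGRW` and `innovations` and show up in the REPL help mode.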

For learning about environments, I can highly recommend ‘Modern Julia Workflows’

Probably the next step is to then split the definitions out into your own local package (see the next section in the link above). That gives you greater reusability and lets you benefit more from precompilation.

Thank you, yeah I’ll look into that.
My background’s mostly mathematical, not computational, so it’s a bit hard doing things from scratch in a language I never tried until last week. Since posting this I picked up at least one more embarrassing mistake I made in that code besides what you just pointed out, and I probably still missed 20 others.

I was mostly reading bits and pieces of tutorials and documentation and trying to implement stuff, just to get something actually done and a lay of the land on what I need to learn more deeply.

Thank you for your advice!

Thank you for your advice!

Yeah, I went for memoization because when I used allocating operations (* in this case), the sampling was very slow and had a ton of allocations. Essentially, the calculations in bvmul! and its corresponding rrule pullback need to happen a lot of times, as fast as possible, so I do want to avoid allocations as much as I can, especially given that the size of the arrays involved grows with the size of the data, n.

Since the banded matrix doesn’t change during sampling, and the array sizes of the calculation result and relevant gradient don’t change, I wanted to cache them. Also because I don’t want to use global variables for this…
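The difference between the allocating product and the in-place one can be sketched with the standard library alone. `Bidiagonal` stands in for the BandedMatrix here, and the variable names are illustrative:

```julia
using LinearAlgebra

n = 5
# Lower-bidiagonal difference operator: ones on the diagonal, -1 below it.
L = Bidiagonal(ones(n), -ones(n - 1), :L)

x = cumsum(ones(n))   # states 1, 2, 3, 4, 5
buf = zeros(n)        # preallocated output, reused across calls

# Allocating version: builds a new vector on every call.
ε_alloc = L * x

# In-place version: writes the result into `buf` and returns it.
ε_inplace = mul!(buf, L, x)
```

Both give the same innovations; only the second reuses memory, which is also why the question of who owns `buf` matters once several chains run at once.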

With the memoization it’s running faster somehow, but still a lot of allocations, and still only half as fast as the same model implemented in PyMC. Plus a bunch of lock conflicts, although the inference result itself looks OK.

Right now I’m thinking of still caching things, but going through the documentation more deeply first and learning how to do it in a thread-safe way. Maybe create a different struct entirely for everything that holds the “cached” stuff…
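One possible shape for such a struct, as a rough sketch only: the read-only matrix is shared and every task gets its own scratch buffer. `GRWWorkspace` and `innovations!` are hypothetical names, `Bidiagonal` stands in for the BandedMatrix, and the sketch does not cover how to wire this into Turing or how to handle the tracked number types an AD backend passes in:

```julia
using LinearAlgebra

# Hypothetical container: the read-only matrix is shared, the buffer is not.
struct GRWWorkspace{M<:AbstractMatrix}
    L::M
    buf::Vector{Float64}
end

# Each call allocates a fresh buffer, so every workspace owns its scratch space.
GRWWorkspace(L::AbstractMatrix) = GRWWorkspace(L, zeros(size(L, 1)))

# In-place product into the workspace's own buffer.
innovations!(ws::GRWWorkspace, x::AbstractVector) = mul!(ws.buf, ws.L, x)

n = 4
L = Bidiagonal(ones(n), -ones(n - 1), :L)
inputs = [fill(Float64(k), n) for k in 1:4]

# One workspace per task: no task ever writes into another task's buffer.
tasks = map(inputs) do x
    Threads.@spawn begin
        ws = GRWWorkspace(L)
        sum(abs2, innovations!(ws, x))
    end
end
results = fetch.(tasks)
```

The allocation cost is paid once per task rather than once per evaluation, so the buffer is still reused within a chain.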

Could you try disabling the MCMC progress bar and see if that helps?

The only other immediate thing I’d suggest would be to try with a different AD backend (seems like from the commented out bits of code you’ve already looked into this…?). Compiled ReverseDiff is usually pretty good but not always the best. And for NUTS, the speed of gradient evaluation is pretty much going to be the most important factor in determining performance.

I was thinking of using Enzyme eventually? I read that it’s very fast. But Enzyme seems to require a whole bunch of coding “best practices” in order to work, which I still need to learn a lot about. So, I settled on first getting things working as well as I can with ReverseDiff, and only then see about migrating to Enzyme.

Also, the lock conflicts still happen with the progress bar disabled. I think it’s just because I memoized the cache in the most cursed thread-unsafe way imaginable.
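The hazard with a memoized buffer can be reproduced even without threads. In this sketch a plain `Dict` lookup stands in for `@memoize`, and all names are made up for the example:

```julia
# Stand-in for a memoized constructor: one buffer per length, shared by all callers.
const BUFFERS = Dict{Int,Vector{Float64}}()
shared_buffer(n::Int) = get!(() -> zeros(n), BUFFERS, n)

# Writes the first differences of `x` into the shared buffer and returns that buffer.
function first_differences!(x::AbstractVector)
    buf = shared_buffer(length(x))
    buf[1] = x[1]
    for i in 2:length(x)
        buf[i] = x[i] - x[i-1]
    end
    return buf
end

a = first_differences!([1.0, 2.0, 4.0])
b = first_differences!([10.0, 10.0, 10.0])

same_array = a === b   # both "results" are the very same array
a_after = copy(a)      # `a` now holds the second result, not the first
```

Here the second call silently overwrites the first result. With several threads the same overwrite can happen in the middle of a computation.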

Thank you for your advice, please tell me if you catch any other mistakes or possible improvements!

Alright. I think I agree with @abraemer: I would not bother with the caching here. A global variable is not all that bad, and you can also declare it as const if it is indeed immutable.
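A minimal sketch of the const-global variant, assuming the problem size is known up front. `N_STATES`, `L_GRW` and `grw_logpdf` are hypothetical names, and `Bidiagonal` from the standard library stands in for the BandedMatrix:

```julia
using LinearAlgebra

# Hypothetical fixed problem size, known before the model is defined.
const N_STATES = 5

# `const` fixes the type of the binding, so functions that read it stay type-stable.
const L_GRW = Bidiagonal(ones(N_STATES), -ones(N_STATES - 1), :L)

# Log-density of a Gaussian random walk with innovation standard deviation σ.
function grw_logpdf(σ::Real, x::AbstractVector)
    n = length(x)
    ε = L_GRW * x   # allocating product keeps the function free of shared mutable state
    return -n * log(σ) - 0.5 * sum(abs2, ε) / σ^2 - (n / 2) * log(2π)
end

lp = grw_logpdf(1.0, zeros(N_STATES))
```

The matrix is only ever read, so sharing it between threads is harmless. It is the writable buffer that causes trouble when shared.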

Alternatively, you can pass it as an argument into the model function and then thread it through the constructors, so that within the model body it’s a local variable.
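Roughly, that could look like the sketch below. A plain function stands in for the `@model` body so that the example needs no packages, and `GRW`, `logdensity` and `log_joint` are hypothetical names:

```julia
using LinearAlgebra

# Hypothetical distribution-like struct that receives `L` instead of looking it up.
struct GRW{T<:Real,M<:AbstractMatrix}
    σ::T
    L::M
end

# Log-density of the states `x` under the random walk described by `d`.
function logdensity(d::GRW, x::AbstractVector)
    n = length(x)
    ε = d.L * x
    return -n * log(d.σ) - 0.5 * sum(abs2, ε) / d.σ^2 - (n / 2) * log(2π)
end

# Plain-function stand-in for the model body: `L` arrives as an argument,
# so it is a local variable here and is simply handed to the constructor.
function log_joint(σ_s::Real, u::AbstractVector, L::AbstractMatrix)
    return logdensity(GRW(σ_s, L), u)
end

n = 4
L = Bidiagonal(ones(n), -ones(n - 1), :L)   # built once, outside the model
lp = log_joint(2.0, zeros(n), L)
```

Constructing `GRW(σ_s, L)` on every evaluation is then just wrapping two existing values, with no lookup for the AD backend to trace through.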

Each time you evaluate the model it has to call jointGRW(...) and that means reading from the cache, and the AD backend has to differentiate through all of that memoisation code as well.