MAP Optimization and RAM usage related to autodiff

Hi there probabilistic programming folks - I’ve enjoyed reading the posts here and now have a question of my own. I am trying to estimate the MAP of a multivariate linear model with a mixture model component using Turing (the structure should be clear from the example code below). Optimization with LBFGS / ConjugateGradient works (in the sense of returning a reasonable result), but uses a tremendous amount of RAM (~5 GB for the small example below). If I rerun in the same session, RAM usage continues to climb. If I use an optimization method that does not require gradients (e.g., NelderMead()), the heavy RAM usage isn’t an issue, so I presume this is related to autodiff.
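
For concreteness, the gradient-free comparison is the same MAP call with the optimizer swapped out - a minimal sketch, reusing the model and (trimmed) options from run_model below:

# derivative-free optimizer: no AD tape is built, and with this RAM usage stays modest
map_estimate = optimize(model, MAP(), NelderMead(), Optim.Options(iterations = 50, show_trace = true))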

Would be grateful for any thoughts on what is going on. I suspect much of it is due to the line that increments the joint model log probability for the mixture model component (the Turing.@addlogprob! call). Running @code_warntype on the model does produce some red (type-unstable) output, though I am not entirely sure how to modify the code to reduce the type instability.
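
For context, the marginal mixture term that the @addlogprob! line adds for each row of β is just the logpdf of a Distributions MixtureModel - a minimal sketch of the same quantity (using the dists and w defined in the model below; not an alternative I have benchmarked, just to make the intent of the line clear):

# equivalent to logpMGaussianMixture(x, dists, w):
# log( sum_k w[k] * pdf(dists[k], x) ), with the logsumexp handled inside Distributions
mix = MixtureModel(dists, w)
logpdf(mix, x)

Here is the full example: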

using Turing
using Distributions
using ReverseDiff
using Memoization
using LinearAlgebra
using StatsFuns
using Optim

function simulate()
    N_SAMPLES = 100
    N_GENES = 100
    N_DONORS = 3
    N_INTERVENTIONS = 10
    N_GENE_MASK = 5

    α = .5
    G = rand(Bernoulli(.3), N_GENES, N_GENE_MASK)
    dists = [MvNormal(zeros(N_GENES), α .* G[:, i]) for i in 1:N_GENE_MASK]

    β = zeros(N_INTERVENTIONS, N_GENES)
    for i in 1:N_INTERVENTIONS
        z = rand(Categorical(N_GENE_MASK))
        β[i, :] = rand(dists[z])
    end

    D = rand(Bernoulli(.2), N_SAMPLES, N_DONORS)
    ψ = rand(Normal(0, .3), N_DONORS)
    X = rand(Bernoulli(.1), N_SAMPLES, N_INTERVENTIONS)
    ϵ = rand(Normal(0, 0.1), N_SAMPLES, N_GENES)
    Y = X * β .+ D * ψ .+ ϵ    # noise already included via ϵ

    return Y, Int8.(D), Int8.(X), Int8.(G)
end

# marginal log-likelihood of a single vector x under a mixture of `dists` with weights `w`
function logpMGaussianMixture(x, dists, w::AbstractVector)
    logw = log.(w)
    K = length(logw)
    logsumexp(logw[k] + logpdf(dists[k], x) for k in 1:K)
end

@model function marginalized_bipartite_model(
    Y, 
    D, 
    X, 
    G, 
    ::Type{T} = Float64) where {T}

    N_GENES::Int64 = size(Y, 2)
    N_INTERVENTIONS::Int64 = size(X, 2)
    N_DONORS::Int64 = size(D, 2)
    N_SAMPLES::Int64 = size(Y, 1)
    N_MASKS::Int64 = size(G, 2)


    w ~ Dirichlet(N_MASKS, 1.0)
    α ~ Normal(-1, 1)
    # per-gene standard deviations for each mask component
    stds = [sqrt(.1) .+ exp(α) .* G[:, z] for z in 1:N_MASKS]
    μ = zeros(N_GENES)
    dists = [MvNormal(μ, stds[k]) for k in 1:N_MASKS]

    β ~ filldist(Normal(0, 1.0), N_INTERVENTIONS, N_GENES)

    Turing.@addlogprob! sum([logpMGaussianMixture(view(β, i, :), dists, w) for i in 1:N_INTERVENTIONS])

    ψ ~ filldist(Normal(2, 3), N_DONORS)
    σ ~ Exponential(0.2)

    d = D * ψ
    for j in 1:N_GENES
        Y[:, j] ~ MvNormal(view(X * β, :, j) .+ d, σ)
    end

end



function run_model(counts_mat::Matrix{Float64}, donors_mat::Matrix{Int8}, interventions_mat::Matrix{Int8}, gene_mat::Matrix{Int8})

    # ReverseDiff with tape caching (Memoization is loaded for setrdcache)
    Turing.emptyrdcache()
    Turing.setadbackend(:reversediff)
    Turing.setrdcache(true)

    model = marginalized_bipartite_model(
        view(counts_mat, :, 1:100),
        view(donors_mat, :, :),
        view(interventions_mat, :, :),
        view(gene_mat, 1:100, 1:5)
    )

    map_estimate = optimize(
                    model,
                    MAP(),       
                    ConjugateGradient(),
                    # LBFGS(;m = 3),
                    # SimulatedAnnealing(),
                    Optim.Options(
                         f_tol = 1e-3,
                         g_tol = 1e-2,
                         iterations = 50,
                         store_trace = false,
                         show_trace = true,
                         show_every = 3
                    ) 
                )

    return map_estimate, model
end


sims = simulate()
@time result, model = run_model(sims...)

Thanks very much,
Josh