Turing.jl Memory Issues

Hi All,

I designed a relatively simple hierarchical model for my research. To fit the model I initially wanted to use MCMC (NUTS). The other researchers in my lab tend to use R, so I started with rstan. It worked perfectly fine, but implementing in-chain parallelization for MCMC was challenging, so I switched to Turing.jl.

Unfortunately, Turing.jl has posed a number of surprising challenges. For instance, while I was able to run the model using Stan's 'experimental' VI function with only 128 GB of RAM, ADVI in Turing.jl exhausted the memory of the virtual machine I was using, even when I threw 1 TB of RAM at it. I am using ADVI with `AutoReverseDiff()` as the AD backend. I've tried things like moving to Float32, but it doesn't seem to be well supported.

The model has 50*N parameters, so with my full dataset of approximately N = 5000 there are about 250k parameters. That's a lot, but not uncommon for this kind of modelling.
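For reference, here is a stripped-down sketch of the kind of setup I'm running. This is a toy one-parameter-per-group model rather than my actual one, and the `adtype` keyword on `ADVI` assumes a recent Turing/AdvancedVI version (older versions set the backend globally via `Turing.setadbackend(:reversediff)` instead):

```julia
using Turing  # recent Turing re-exports AutoReverseDiff from ADTypes

# Toy stand-in for my model, just to show the setup
# (my real model has ~50 parameters per unit, not 1)
@model function demo(y, group, n_groups)
    mu ~ Normal(0, 1)
    sigma ~ truncated(Normal(0, 1); lower=0)
    theta ~ filldist(Normal(mu, sigma), n_groups)  # one parameter per group
    for i in eachindex(y)
        y[i] ~ Normal(theta[group[i]], 1)
    end
end

n_groups = 10
group = rand(1:n_groups, 100)
y = randn(100)
model = demo(y, group, n_groups)

# ADVI with ReverseDiff: 10 Monte Carlo samples per gradient step,
# 1000 optimization steps
q = vi(model, ADVI(10, 1000; adtype = AutoReverseDiff()))
```

My real model follows the same pattern, just with ~50 parameters per unit and N ≈ 5000 units, and the `vi` call is where memory blows up.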

Is there any way to improve Turing.jl's memory performance here? If not, which other PPLs would you recommend for fitting this model?

Thank you very much in advance!