DPMMSubClusters.jl
This package provides an easy, fast, and scalable way to perform inference in Dirichlet Process Mixture Models.
Developed from the code of:
Distributed MCMC Inference in Dirichlet Process Mixture Models Using Julia by Dinari et al.,
which is based on the algorithm from:
Parallel Sampling of DP Mixture Models Using Sub-Cluster Splits by Chang and Fisher.
The package currently supports Gaussian and Multinomial priors; adding your own is straightforward, and more will be added in future releases.
To the best of my knowledge, the package is faster than other openly available DPMM implementations in Matlab, Python, and Julia, even when running on a single process.
A few examples:
- 2D Gaussian with plotting
- Image segmentation
- Running-Saving-Loading-Rerunning model
Installation
The package (latest version) has the following dependencies:
- CatViews
- Clustering (0.13.3)
- Distributions
- JLD2
- NPZ
- SpecialFunctions
- StatsBase
- LinearAlgebra
- Distributed
- DistributedArrays
- Random
To install, simply:
] add DPMMSubClusters
Usage:
This package is aimed at distributed parallel computing; working with no additional workers is possible, but adding more workers, distributed across different machines, is encouraged for increased performance.
It is recommended to use BLAS.set_num_threads(1). When working with larger datasets, increase the number of workers instead; BLAS multithreading might interfere with the multiprocessing and result in slower inference.
For all the workers to recognize the package, you must start with @everywhere using DPMMSubClusters. If you need to set the seed (using the seed kwarg), add @everywhere using Random as well.
While the package is very versatile in its settings and configuration, there are two modes you can work with: Basic, which uses a mostly predefined configuration and takes the data as an argument, and Advanced, which allows more configuration, loading data from a file, saving the model, and running from a saved checkpoint.
Basic
In order to run in the basic mode, use the function:
labels, clusters, weights = fit(all_data::AbstractArray{Float32,2}, local_hyper_params::distribution_hyper_params, α_param::Float32;
    iters::Int64 = 100, init_clusters::Int64 = 1, seed = nothing, verbose = true, save_model = false, burnout = 20, gt = nothing)
Or, if opting for the default Gaussian weak prior:
labels, clusters, weights = fit(all_data::AbstractArray{Float32,2}, α_param::Float32;
    iters::Int64 = 100, init_clusters::Int64 = 1, seed = nothing, verbose = true, save_model = false, burnout = 20, gt = nothing)
* Note that while we dispatch on Float32, other number types will work as well and will be cast if needed.
Args and Kwargs:
- all_data - The data, should be DxN.
- local_hyper_params - The prior you plan to use; can be either Multinomial or NIW (see below for an example of how to create one).
- α_param - Concentration parameter.
- iters - Number of iterations.
- seed - Random seed; can also be set separately. Note that if set separately, you must set it on all workers.
- verbose - Print status on every iteration.
- save_model - If true, a checkpoint will be saved every 25 iterations. Note that if you opt for saving, the advanced mode is recommended.
- burnout - Number of iterations before clusters are allowed to split/merge. Reducing this number will result in faster inference, but with higher variance between different runs.
- gt - Ground truth; if supplied, NMI and VI tests will be performed on every iteration.
Return values:
fit will return the following:
labels, cluster_params, weights, iteration_time_history, nmi_score_history, likelihood_history, cluster_count_history
Note that weights does not sum to 1, but to 1 minus the weight of the non-instantiated components.
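A minimal, self-contained sketch of unpacking the full return tuple (the data and hyper-parameter values are illustrative; see also the toy example below):

```julia
using DPMMSubClusters

# Synthetic 2D data from 6 clusters, with its ground-truth labels
x, y, _ = generate_gaussian_data(10_000, 2, 6, 100.0)
hyper_params = DPMMSubClusters.niw_hyperparams(1.0, zeros(2), 5, [1 0; 0 1])

labels, cluster_params, weights, iter_times, nmi_history, likelihood_history, k_history =
    fit(x, hyper_params, 10.0, iters = 100, gt = y)
```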
Advanced
In this mode you are required to supply a params file; an example is the file global_params.jl, which includes all the configurable params. Running it is as simple as:
dp = dp_parallel(model_params::String; verbose = true, save_model = true, burnout = 5, gt = nothing)
It will return:
dp, iteration_time_history, nmi_score_history, likelihood_history, cluster_count_history
The returned value dp is a data structure:
mutable struct dp_parallel_sampling
model_hyperparams::model_hyper_params
group::local_group
end
It contains the local_group, another structure:
mutable struct local_group
model_hyperparams::model_hyper_params
points::AbstractArray{Float64,2}
labels::AbstractArray{Int64,1}
labels_subcluster::AbstractArray{Int64,1}
local_clusters::Vector{local_cluster}
weights::Vector{Float64}
end
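For instance, the inferred assignments and mixture weights can be read directly from the returned structure. A small sketch, assuming dp is the value returned by dp_parallel:

```julia
# Reading results from the returned structure (sketch)
inferred_labels = dp.group.labels    # per-point cluster assignments
mixture_weights = dp.group.weights   # component weights
```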
Note that for data loading the package uses NPZ, which works with Python numpy files. Thus the data files must be numpy (.npy) files of shape NxD.
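As an illustrative end-to-end sketch (the file names here are hypothetical), preparing a data file with NPZ and running the advanced mode might look like:

```julia
using NPZ
using DPMMSubClusters

# The file on disk must be N×D; if your matrix is D×N (as fit expects), transpose it first.
data = rand(Float32, 2, 10_000)              # D×N toy data
npzwrite("mydata.npy", permutedims(data))    # saved as N×D (hypothetical file name)

# global_params.jl should point to the data file and hold the rest of the configuration.
dp, iter_times, nmi_history, likelihood_history, k_history =
    dp_parallel("global_params.jl", verbose = true, save_model = true)
```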
Additional Functions
Additional functions exposed to the user include (a short usage sketch follows the list):
- run_model_from_checkpoint(file_name): Used to restart a saved run. file_name must point to a valid checkpoint file created during a run of the model. Note that the params file used for the initial run must still be available and in the same location; the same applies to the data.
- calculate_posterior(model): Calculates the posterior of a model returned from dp_parallel.
- generate_gaussian_data(N::Int64, D::Int64, K::Int64, MixtureVar::Number): Randomly generates Gaussian data: N points of dimension D from K clusters, with MixtureVar variance between mixture component means. The return value is points, labels, cluster_means, cluster_covariance.
- generate_mnmm_data(N::Int64, D::Int64, K::Int64, trials::Int64): Similar to the above, but for multinomial data. The return value is points, labels, clusters.
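A short usage sketch of these helpers (the checkpoint file name below is hypothetical and depends on your run):

```julia
using DPMMSubClusters

# Synthetic multinomial data: 10k points of dimension 100 from 5 clusters, 50 trials each
points, labels, clusters = generate_mnmm_data(10_000, 100, 5, 50)

# Posterior of a finished model (dp as returned from dp_parallel)
posterior = calculate_posterior(dp)

# Resume a saved run; the original params file and data must still be in place
run_model_from_checkpoint("checkpoint_25.jld2")
```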
Toy Example
using DPMMSubClusters
# Generate 10k samples of 2D Gaussian data, sampled from 6 random Gaussians
x,y,clusters = generate_gaussian_data(10000,2,6,100.0)
#NIW Hyper Params:
# struct niw_hyperparams <: distribution_hyper_params
# κ::Float32
# m::AbstractArray{Float32}
# ν::Float32
# ψ::AbstractArray{Float32}
# end
hyper_params = DPMMSubClusters.niw_hyperparams(1.0,
zeros(2),
5,
[1 0;0 1])
## Run with hyper params
ret_values = fit(x, hyper_params, 10.0, iters = 100)
## Run without hyper params, with faster burnout and gt
ret_values = fit(x, 10.0, iters = 100, burnout = 10, gt = y)
labels = ret_values[1]
Performance Tips
As mentioned above, it is recommended to use BLAS.set_num_threads(1).
The performance increase is not linear in the number of processes, and on small, low-dimensional datasets adding more processes might even reduce performance.
On that note, any optimization contributions are very welcome.
Misc
For any questions: dinari@post.bgu.ac.il
Also available here and on the Julia Slack.
Contributions, feature requests, suggestions, etc. are welcome.
If you use this code for your work, please cite the following:
@inproceedings{Dinari:CCGrid:2019,
title={Distributed {MCMC} Inference in {Dirichlet} Process Mixture Models Using {Julia}},
author={Dinari, Or and Angel, Yu and Freifeld, Oren and Fisher III, John W},
booktitle={International Symposium on Cluster, Cloud and Grid Computing (CCGRID) Workshop on High Performance Machine Learning Workshop},
year={2019}
}