InferenceObjects.InferenceData, used in ArviZ, takes this approach. In fact, it goes further by not (currently) supporting draws of arbitrary Julia types, since diagnostics, summary statistics, plotting, and long-term serialization generally require flattening Julia types into numeric arrays or tables anyway. There are plans to support interconverting between these compact representations and the original Julia type, but it’s tricky.
julia> using ArviZExampleData, LinearAlgebra, Statistics
julia> idata = load_example_data("centered_eight")
InferenceData with groups:
> posterior
> posterior_predictive
> log_likelihood
> sample_stats
> prior
> prior_predictive
> observed_data
> constant_data
julia> idata.posterior
Dataset with dimensions:
Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points,
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and 3 layers:
:mu Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
:theta Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×500×4)
:tau Float64 dims: Dim{:draw}, Dim{:chain} (500×4)
with metadata Dict{String, Any} with 6 entries:
"created_at" => "2022-10-13T14:37:37.315398"
"inference_library_version" => "4.2.2"
"sampling_time" => 7.48011
"tuning_steps" => 1000
"arviz_version" => "0.13.0.dev0"
"inference_library" => "pymc"
julia> dropdims(mean(idata.posterior.theta; dims=(:draw, :chain)); dims=(:draw, :chain)) # reduce dims
8-element DimArray{Float64,1} theta with dimensions:
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered
and reference dimensions:
Dim{:draw} Sampled{Float64} Float64[249.5] ForwardOrdered Irregular Points,
Dim{:chain} Sampled{Float64} Float64[1.5] ForwardOrdered Irregular Points
"Choate" 6.46006
"Deerfield" 5.02755
"Phillips Andover" 3.93803
"Phillips Exeter" 4.87161
"Hotchkiss" 3.66684
"Lawrenceville" 3.97469
"St. Paul's" 6.58092
"Mt. Hermon" 4.77241
julia> mapslices(normalize, idata.posterior.theta; dims=(:draw, :chain)) # keep dims
8×500×4 DimArray{Float64,3} theta with dimensions:
Dim{:school} Categorical{String} String[Choate, Deerfield, …, St. Paul's, Mt. Hermon] Unordered,
Dim{:draw} Sampled{Int64} Int64[0, 1, …, 498, 499] ForwardOrdered Irregular Points,
Dim{:chain} Sampled{Int64} Int64[0, 1, 2, 3] ForwardOrdered Irregular Points
[:, :, 1]
0 1 2 3 4 … 496 497 498 499
"Choate" 0.0315722 0.0289199 0.0146283 0.0257209 0.0234451 0.0243186 -0.000547943 0.0266568 0.0170699
"Deerfield" 0.0316057 0.0291295 0.0183722 0.0281077 0.0183947 0.0260931 0.00432398 0.0220399 0.0236556
"Phillips Andover" 0.0483347 0.0101484 0.035381 0.0320071 0.022679 0.00721072 0.0225732 -0.0160478 -0.0301376
"Phillips Exeter" 0.0352315 0.0301817 0.0188626 0.0184575 0.0501963 0.00659421 0.011892 0.0100332 0.00861284
"Hotchkiss" 0.0202402 0.0283366 0.03625 0.032876 0.0112359 … 0.00439817 0.0193147 -0.00806875 -0.00182223
"Lawrenceville" 0.0578453 0.00819018 0.02787 0.0238147 0.0411574 0.00631814 0.0238403 -0.00970274 -0.014562
"St. Paul's" 0.0354355 0.0269973 0.020418 0.0276498 0.0303843 0.0116846 0.0132371 0.0144745 0.0203155
"Mt. Hermon" 0.0451373 0.018511 0.0262757 0.0094561 0.0510854 0.0140476 0.00916629 0.0191239 0.0299318
[and 3 more slices...]
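To make the flattening step mentioned above concrete, here is a minimal sketch using only Base Julia and Statistics. It is not how InferenceObjects does anything internally; the names `draws`, `flatten_draw`, and `flat` are made up for illustration, and the draws are toy values rather than real sampler output.

```julia
using Statistics

# Hypothetical draws: each draw is a NamedTuple holding a scalar and a vector.
draws = [(mu = 1.0, theta = [0.5, 1.5]), (mu = 2.0, theta = [2.5, 3.5])]

# Flatten a single draw into a plain Vector{Float64}: [mu, theta[1], theta[2]].
flatten_draw(d) = vcat(d.mu, d.theta)

# Put the flattened draws side by side as columns, then transpose
# so that rows are draws and columns are parameters.
flat = permutedims(reduce(hcat, map(flatten_draw, draws)))

# Once everything is numeric, summary statistics are a one-liner.
param_means = vec(mean(flat; dims=1))
```

Going back from `flat` to the original NamedTuples is the part that needs extra bookkeeping (names and shapes of each field), which is presumably where the trickiness comes in.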
I’m interested to see what you come up with.