Get list of parameters from Turing model

How can I get the list of parameters in a Turing model?

I’m trying to programmatically generate a list of initial values to pass as init_params, so I need to know the parameter symbols in the correct order.

For example, given the following model, I’d like to get the tuple (σ, μ):

@model function mymodel(x)
    σ ~ Exponential()
    μ ~ Normal(2.0, 1.0)
    x .~ Normal(μ, σ)
    
    return μ, σ
end

m = mymodel([1.0, 2.0, 3.0])

ch = sample(m, NUTS(), 1000, discard_adapt=false, init_params=[0.5, 5])

I don’t know about a direct way to query the model, but here’s an inefficient way via the Chains object returned by sample:

julia> ch = sample(m,Prior(),1); ch.name_map.parameters
2-element Vector{Symbol}:
 :μ
 :σ

This is probably a good approach for now, though at some point this should be made easier.

FWIW in Soss you can do

julia> mymodel = @model n begin
       σ ~ Exponential()
       μ ~ Normal(2.0, 1.0)
       x ~ Normal(μ,σ) |> iid(n)
       end;

julia> arguments(mymodel)
1-element Vector{Symbol}:
 :n

julia> sampled(mymodel)
3-element Vector{Symbol}:
 :σ
 :μ
 :x 

And if you have some observations,

julia> observed(mymodel() | (x = randn(10),))
(:x,)

Thanks for all the answers. Another option I found is DynamicPPL.syms(DynamicPPL.VarInfo(m)) (but I’m not sure how advisable it is to use these unexported DynamicPPL methods).


I think it’s fine – as far as I’m aware those two methods are likely to continue to be around.


Nice, this looks so much less hacky than going through sampling (as I proposed earlier)!

Now, curiously, the two approaches give different parameter orders.

julia> DynamicPPL.syms(DynamicPPL.VarInfo(m))
(:σ, :μ)

julia> ch = sample(m,Prior(),1); ch.name_map.parameters
2-element Vector{Symbol}:
 :μ
 :σ

Which of the two orders is relevant for the parameter initialization in your initial post?

Uhh… good question! It seems to be working as expected with the order from DynamicPPL.syms(DynamicPPL.VarInfo(m)). It’s also the order of the data in the chain’s underlying AxisArrays (see ch.value).

I wonder what the ch.name_map.parameters order is about… Just displaying the chain (display(ch)) shows this order in the header, and the other order in the summary stats:

Chains MCMC chain (1000×14×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 10.55 seconds
Compute duration  = 10.55 seconds
parameters        = μ, σ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   e ⋯
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64     ⋯

           σ    1.1765    0.5388     0.0170    0.0246   411.3738    1.0006     ⋯
           μ    1.9466    0.6001     0.0190    0.0182   295.0682    0.9990     ⋯

I had filed a related issue here but the most relevant is probably this one.


The “correct” one if you want to work with init_params is DynamicPPL.syms(DynamicPPL.VarInfo(m)), since this works very closely with the model block.

ch.name_map.parameters is constructed later in MCMCChains and isn’t necessarily the best indicator of model parameter order.
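
As a sketch of how that ordering could be used to build init_params programmatically: the snippet below reads the variable names from a VarInfo and looks up starting values by name. It uses keys(...) with DynamicPPL.getsym instead of syms, assumes every parameter is a scalar, and the starts table is a made-up example. These are DynamicPPL internals, so the calls may behave differently across versions.

```julia
using Turing, DynamicPPL

@model function mymodel(x)
    σ ~ Exponential()
    μ ~ Normal(2.0, 1.0)
    x .~ Normal(μ, σ)
end

m = mymodel([1.0, 2.0, 3.0])

# Variable names in the order the VarInfo stores them
# (the thread reports σ first, then μ, for this model).
vns = keys(DynamicPPL.VarInfo(m))
syms = [DynamicPPL.getsym(vn) for vn in vns]

# Hypothetical lookup table of starting values, keyed by symbol.
starts = Dict(:μ => 5.0, :σ => 0.5)

# Initial values arranged in the VarInfo order.
init = [starts[s] for s in syms]

# `init` can then be passed to sample through the init_params keyword used
# earlier in this thread (later Turing releases renamed it to initial_params).
```

Looking values up by name means the vector stays correct even if the model block is reordered later.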


There really needs to be a public API for this. A lot of examples have 1 to 10 parameters and rely on people entering things by hand, but I easily end up with models that have 2000 parameters (say, one parameter for each person in a survey). If I want to do initialization or describe a complex Gibbs update (i.e., update all the discrete parameters with one kind of update, some of the continuous parameters with a diffusive update, and others with NUTS), then I need to do things programmatically.

Of course, the provided solution works, but it should be documented somewhere, and if it’s not intended to be the public API, why not create one?

For example, why not params(modelObject)?


Or coefnames() as in R…

So yeah, years later, is there an “approved” way to get the names of the parameters?

Specifically, I’m using a Turing model to construct a custom “tempered model” struct, which uses a LogDensityFunction struct to evaluate the log density of the Turing model and divides it by a given constant temperature. I then slice-sample this model and get a vector of SliceSampling.Transition objects. I want to manually construct a Chains object, and I need to know which name to use for each dimension…

Any updated info?

It’s been a while since I messed with it, but IIRC there are some methods in the MCMCChains package for getting parameter names from output, e.g., names(chains::Chains) or the like.

Yeah, except I’m trying to construct a chain, so I don’t have one to look at.

It is beginning to look like the easiest way for me, though, is to simply draw a single sample from the underlying Turing model and then use the names there to construct the chain from the slice sampling transitions.

I see that Turing does some stuff in _params_to_array inside src/mcmc/inference.jl to get these names, and it’s nontrivial voodoo.
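
For the manual construction step itself, MCMCChains accepts a raw array plus a vector of names. A minimal sketch, assuming the names have already been obtained in the same order as the columns of the draws; the random matrix is only a stand-in for values extracted from the SliceSampling transitions, which is not shown here.

```julia
using MCMCChains

# Parameter names, assumed to match the column order of the draws.
param_names = [:σ, :μ]

# Stand-in for the sampled values: one row per iteration, one column per parameter.
draws = rand(100, length(param_names))

# Chains expects an iterations × parameters × chains array.
vals = reshape(draws, size(draws, 1), size(draws, 2), 1)
chn = Chains(vals, param_names)

# The names attached to the chain can be read back like this.
names(chn, :parameters)
```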

I don’t know if it preserves the order you want for initialization (it seems to), but you can also call keys(rand(model)), which at least doesn’t have the overhead of constructing the Chains object, so it feels less annoying.
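
A short sketch of that suggestion, reusing the model from the top of the thread. It assumes rand(model) returns a NamedTuple keyed by parameter symbol, which is the default I know of in DynamicPPL but may not hold in every version, and it makes no claim about the order of the keys.

```julia
using Turing

@model function mymodel(x)
    σ ~ Exponential()
    μ ~ Normal(2.0, 1.0)
    x .~ Normal(μ, σ)
end

m = mymodel([1.0, 2.0, 3.0])

# One draw from the prior; no sampler or Chains object is involved.
draw = rand(m)

# The keys of the draw are the names of the model's random variables.
param_names = collect(keys(draw))
```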

Purely because I was curious myself, and because I sometimes would also like a way to easily extract coefficients without needing to create mock data or actually call sample, here’s a horribly hacky (and probably fragile) function that returns the parameters, given a model and a tuple listing the data variables in the model.

It instantiates the model with data as missing. A rand call creates a Dict with keys that tell us all the variables in the model. Then filter is used to remove the data variables, leaving only the parameters.

# 'model' is a Turing model before data has been passed
# 'data_args' is a tuple of the names of the data
#     (i.e., non-parameter) variables in your model

function parameters(model, data_args)

    # Number of data (non-parameter) arguments the model takes
    n_data_args = length(data_args)

    # Instantiate the model with every data argument set to `missing`,
    # so that a draw from the prior contains every variable in the model
    draw = rand(model(fill(missing, n_data_args)...))

    # Filter out the variables that represent data
    filter(a -> a ∉ data_args, keys(draw))
end