Get list of parameters from Turing model

How can I get the list of parameters in a Turing model?

I’m trying to programmatically generate a list of initial values to pass with init_params, so I need to know the parameter symbols in the correct order.

For example, given the following model, I’d like to get the tuple (σ, μ):

@model function mymodel(x)
    σ ~ Exponential()
    μ ~ Normal(2.0, 1.0)
    x .~ Normal(μ, σ)
    
    return μ, σ
end

m = mymodel([1.0, 2.0, 3.0])

ch = sample(m, NUTS(), 1000, discard_adapt=false, init_params=[0.5, 5])

I don’t know of a direct way to query the model, but here’s an inefficient workaround that goes through the Chains object returned by sample:

julia> ch = sample(m,Prior(),1); ch.name_map.parameters
2-element Vector{Symbol}:
 :μ
 :σ

This is probably a good approach for now, though at some point this should be made easier.

FWIW, in Soss you can do:

julia> mymodel = @model n begin
       σ ~ Exponential()
       μ ~ Normal(2.0, 1.0)
       x ~ Normal(μ,σ) |> iid(n)
       end;

julia> arguments(mymodel)
1-element Vector{Symbol}:
 :n

julia> sampled(mymodel)
3-element Vector{Symbol}:
 :σ
 :μ
 :x 

And if you have some observations,

julia> observed(mymodel() | (x = randn(10),))
(:x,)

Thanks for all the answers. Another option I found is DynamicPPL.syms(DynamicPPL.VarInfo(m)) (but I’m not sure how advisable it is to use these unexported DynamicPPL methods).


I think it’s fine: as far as I’m aware, both of those methods are likely to stick around.


Nice, this looks so much less hacky than going through sampling (as I proposed earlier)!

Now, curiously, the two approaches give different parameter orders.

julia> DynamicPPL.syms(DynamicPPL.VarInfo(m))
(:σ, :μ)

julia> ch = sample(m,Prior(),1); ch.name_map.parameters
2-element Vector{Symbol}:
 :μ
 :σ

Which of the two orders is relevant for the parameter initialization in your initial post?

Uhh… good question! It seems to work as expected with the order from DynamicPPL.syms(DynamicPPL.VarInfo(m)). That is also the order of the data in the chain’s underlying AxisArray (see ch.value).

I wonder what the ch.name_map.parameters order is about… Simply displaying the chain (display(ch)) shows that order in the header and the other order in the summary statistics:

Chains MCMC chain (1000×14×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 10.55 seconds
Compute duration  = 10.55 seconds
parameters        = μ, σ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   e ⋯
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64     ⋯

           σ    1.1765    0.5388     0.0170    0.0246   411.3738    1.0006     ⋯
           μ    1.9466    0.6001     0.0190    0.0182   295.0682    0.9990     ⋯

I had filed a related issue here, but the most relevant one is probably this one.


The “correct” one if you want to work with init_params is DynamicPPL.syms(DynamicPPL.VarInfo(m)), since VarInfo is built directly from the model: it records the variables as the model runs its ~ statements, which is why it gives (:σ, :μ) here.

ch.name_map.parameters is constructed later, in MCMCChains, so it isn’t necessarily a reliable indicator of the model’s parameter order.
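
Once the order is known, the initial-value vector can be assembled from named values instead of typed by hand. The sketch below uses a hypothetical helper, ordered_inits (not part of Turing or DynamicPPL), that takes the symbol tuple and a NamedTuple of chosen values. It assumes every parameter is a scalar or a vector, and that init_params expects vector-valued parameters flattened in element order, as in the Turing version discussed in this thread.

```julia
# Hypothetical helper (not part of Turing/DynamicPPL): arrange user-chosen
# initial values in the order given by `syms`, flattening vector-valued
# parameters so the result is one flat Vector.
# Assumption: every value is a scalar or a vector (matrices are not handled).
function ordered_inits(syms, values::NamedTuple)
    # `values[s]` looks up each parameter by name; splatting into `vcat`
    # concatenates scalars and vectors into a single vector.
    return vcat((values[s] for s in syms)...)
end

# Using the symbol order reported for mymodel in this thread, (:σ, :μ):
inits = ordered_inits((:σ, :μ), (μ = 5.0, σ = 0.5))   # [0.5, 5.0]
```

With a real model, the first argument would be DynamicPPL.syms(DynamicPPL.VarInfo(m)) and the result would be passed as init_params, reproducing the [0.5, 5] from the original question without relying on remembering the order. Newer Turing releases may name or accept initial values differently, so check the documentation for your version.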


There really needs to be a public API for this. Many examples have between 1 and 10 parameters and rely on people entering values by hand, but I routinely work with models that have 2000 parameters (say, one parameter for each person in a survey). If I want to set initial values or describe a complex Gibbs update (i.e., update all the discrete parameters with one kind of update, some of the continuous parameters with a diffusive update, and others with NUTS), I need to do it programmatically.

Of course, the solution above works, but it should be documented somewhere, and if it isn’t intended to be the public API, why not create one?

For example, why not provide something like params(modelObject)?
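
Given a symbol list, the grouping needed for a mixed Gibbs scheme is ordinary collection code. The sketch below uses a hypothetical helper, split_params, and a hard-coded, made-up symbol list standing in for the output of DynamicPPL.syms(DynamicPPL.VarInfo(m)); which parameters count as discrete is likewise an assumption. It deliberately stops short of constructing a Gibbs sampler, since that constructor’s syntax differs between Turing versions.

```julia
# Hypothetical helper: split parameter symbols into two groups using a
# predicate, preserving the original order within each group.
function split_params(in_group, syms)
    group = Symbol[s for s in syms if in_group(s)]
    rest  = Symbol[s for s in syms if !in_group(s)]
    return group, rest
end

# Illustrative, made-up symbol list; in practice it could come from
# DynamicPPL.syms(DynamicPPL.VarInfo(m)).
model_syms = (:σ, :μ, :z, :k)
discrete_names = Set([:z, :k])   # assumption: these parameters are discrete
discrete, continuous = split_params(s -> s in discrete_names, model_syms)
# discrete == [:z, :k], continuous == [:σ, :μ]
```

Each group can then be handed to whichever sampler should update it, so a model with thousands of parameters needs no hand-written symbol lists.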
