# Conditional Branching of Parameter in Turing.jl

I’m new to the Julia language and to modeling, so I may ask a strange question.

When estimating parameters in Turing.jl, is it possible to set one parameter conditionally, based on the value of another variable?

For example, in the code below, I want to branch the value of the parameter γ depending on the value of var1.

In such a case, I don’t know where and how to write the code.

Please let me know if there is a better way.

```julia
using Turing, Distributions, DataFrames, DifferentialEquations, DiffEqSensitivity
using MCMCChains, Plots, StatsPlots
using Random
Random.seed!(12);

function lotka_volterra(du,u,p,t)
    x, y = u
    α, β, γ, δ = p  # same order as p = [α,β,γ,δ] in the model below
    du[1] = dx = (α - β*y)x
    du[2] = dy = (δ*x - γ)y
end
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0,1.0]
prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())
plot(sol)

odedata1 = Array(solve(prob,Tsit5(),saveat=0.1))
odedata2 = odedata1 .+ rand()
odedata3 = odedata1 .+ rand()
odedata = zeros(Float64, 2, 101, 3)
odedata[:,:,1] = odedata1
odedata[:,:,2] = odedata2
odedata[:,:,3] = odedata3

# I don’t know where and how to write the following code.

# var1 = [80, 40, 70]
# if var1 > 60
#     γ = 4
# else
#     γ = 1.5
# end

@model function fitlv(data)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    γ ~ truncated(Normal(3.0,0.5),1,4)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    p = [α,β,γ,δ]
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    # loop over the data sets stored along the third dimension
    for k in 1:size(data, 3)
        for i = 1:length(predicted)
            data[:,i,k] ~ MvNormal(predicted[i], σ)
        end
    end
end

model = fitlv(odedata)
chain = sample(model, NUTS(.65),1000)
plot(chain)
```

Just stick that into the model.
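
One thing to watch for when moving the commented-out branch into the model: `var1` is a vector, so `var1 > 60` compares a whole array with a number and throws a `MethodError` before the `if` is ever evaluated. The branch has to be applied to each element. A minimal sketch in plain Julia, where `pick_γ` and `γs` are names made up for this illustration and the values 4.0 and 1.5 are taken from the commented-out code in the question:

```julia
var1 = [80, 40, 70]

# Hypothetical helper (not from the thread): choose γ for a single entry of var1.
pick_γ(v) = v > 60 ? 4.0 : 1.5

# The dot broadcasts the helper over the vector, giving one γ per entry.
γs = pick_γ.(var1)

println(γs)  # [4.0, 1.5, 4.0]
```

The same elementwise logic can be written as an explicit `for` loop; broadcasting is just the shorter form.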


Your work has helped me a lot.

I’ve tried many things, but it didn’t work.

For example, in the following code, I removed γ from the sampled parameters and wrote a conditional branch inside the model, but it doesn’t seem to be working properly.

I think I’ve made a fundamental mistake.
I would appreciate it if you could show me a concrete method.

```julia
using Turing, Distributions, DifferentialEquations
using MCMCChains, Plots, StatsPlots
using Random
Random.seed!(12);

function lotka_volterra(du,u,p,t)
    x, y = u
    α, β, δ = p
    γ = 3.0
    du[1] = dx = (α - β*y)x
    du[2] = dy = (δ*x - γ)y
end
p = [1.5, 1.1, 1.0]
u0 = [1.0,1.0]
prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())
plot(sol)

odedata1 = Array(solve(prob,Tsit5(),saveat=0.1))
odedata2 = odedata1 .+ rand()
odedata3 = odedata1 .+ rand()
odedata = zeros(Float64, 2, 101, 3)
odedata[:,:,1] = odedata1
odedata[:,:,2] = odedata2
odedata[:,:,3] = odedata3

@model function fitlv(data)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    # γ ~ truncated(Normal(3.0,0.5),1,4)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    var1 = [80, 40, 70]
    γ = zeros(Float64, length(var1))

    for l in 1:length(var1)
        if var1[l] > 60
            γ[l] = 5.0
        else
            γ[l] = 2.0
        end
    end

    p = [α,β,δ]
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    # loop over the data sets stored along the third dimension
    for k in 1:size(data, 3)
        for i = 1:length(predicted)
            data[:,i,k] ~ MvNormal(predicted[i], σ)
        end
    end
end

model = fitlv(odedata)
chain = sample(model, NUTS(.65),1000)
plot(chain)
```

I think you meant:

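```julia
for l in 1:length(var1)
    # γ is a single scalar here and is overwritten on every iteration,
    # so after the loop it holds the value chosen for var1[end]
    if var1[l] > 60
        γ = 5.0
    else
        γ = 2.0
    end
end

# lotka_volterra must then unpack four parameters in this order: α, β, δ, γ = p
p = [α,β,δ,γ]
prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
```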

Under the hood of `Turing.jl`, the parsing of model definitions is done by `DynamicPPL.jl`.

In practice, the model compiler (the `@model` macro) walks through your model definition, looks for every expression of the form `LHS ~ RHS`, and replaces each one with a function call that returns the value of your random variable and updates the metadata of your model, sampler, likelihood, etc.

So as long as you are writing valid Julia code, you can write (almost*) anything you want next to the random assignments `LHS ~ RHS`.

(*: Defining a random variable with the same symbol several times overwrites the same field in the metadata; definitions are not distinguished by their position in the code.)

Example:

```julia
julia> using DynamicPPL

julia> @macroexpand @model function test()
           a ~ Normal()
           newfunc() = anyfunc()
           a = fct(a)+b
           if a > 0
               return "out"
           else
               b ~ Normal(a)
           end
       end
quote
    $(Expr(:meta, :doc))
    function test(; )
        var"##evaluator#271" = ((_rng::Random.AbstractRNG, _model::Model, _varinfo::AbstractVarInfo, _sampler::AbstractMCMC.AbstractSampler, _context::DynamicPPL.AbstractContext)->begin
                begin
                    #= REPL[6]:2 =#
                    begin
                        var"##tmpright#263" = Normal()
                        var"##tmpright#263" isa Union{Distributions.Distribution, AbstractVector{<:Distributions.Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
                        var"##vn#265" = a
                        var"##inds#266" = ()
                        a = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#263", var"##vn#265", var"##inds#266", _varinfo)
                    end
                    #= REPL[6]:4 =#
                    newfunc() = begin
                            #= REPL[6]:4 =#
                            anyfunc()
                        end
                    #= REPL[6]:6 =#
                    a = fct(a) + b
                    #= REPL[6]:7 =#
                    if a > 0
                        #= REPL[6]:8 =#
                        return "out"
                    else
                        var"##tmpright#267" = Normal(a)
                        var"##tmpright#267" isa Union{Distributions.Distribution, AbstractVector{<:Distributions.Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
                        var"##vn#269" = b
                        var"##inds#270" = ()
                        b = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#267", var"##vn#269", var"##inds#270", _varinfo)
                    end
                end
            end)
        return (Model)(:test, var"##evaluator#271", NamedTuple(), NamedTuple())
    end
end
```

So, as you can see here, your model definition (the model function definition) is transformed into a function that returns a `Model` struct. Its internal evaluator is your original model function, in which every random assignment has been replaced with a block of code that

• tests whether you are actually assigning a proper `Distribution` to your random variable
• looks for indices, in case you are assigning a value to an indexable variable
• calls `DynamicPPL.tilde_assume`, which returns your random variable and updates the sampler, the context, and the `VarInfo` (the latter holds the metadata of your model)
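
Because ordinary Julia code passes through this transformation unchanged, a deterministic branch can sit directly between `~` statements. A minimal sketch, assuming a toy model that is not part of the thread: `branch_demo`, `flag`, and `shift` are hypothetical names, with `flag` standing in for one entry of `var1` and `shift` for the branched parameter.

```julia
using Turing

# Hypothetical toy model (not from the thread).
@model function branch_demo(y, flag)
    μ ~ Normal(0, 1)
    # Plain Julia code: evaluated as written each time the model is run.
    shift = flag > 60 ? 5.0 : 2.0
    y ~ Normal(μ + shift, 1)
end

# One observation y = 6.0 with flag = 80, so the branch picks shift = 5.0.
chain = sample(branch_demo(6.0, 80), NUTS(0.65), 500)
```

Only `μ` appears in the resulting chain; `shift` is a fixed quantity derived from the data, not a sampled parameter.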

I have tried many things, but I may have to use MonteCarloProblem (EnsembleProblem?) to reach my goal.

In this example, I want to map each slice of odedata to the corresponding entry of var1
(i.e., I want to change the value of γ for each odedata).
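
For this per-dataset mapping, one possible approach (a sketch, not something proposed in the thread) is to compute one γ per data slice outside the model and solve the ODE once per slice inside it. The names `lotka_volterra!` and `γs`, the extra model arguments `prob` and `γs`, and the noise level 0.1 are assumptions made for this illustration. The sketch uses `remake` to swap the parameters of an existing problem, writes the likelihood as `MvNormal(μ, σ^2 * I)`, and reads states from `predicted.u`; it has not been checked against every package version.

```julia
using Turing, DifferentialEquations, LinearAlgebra, Random
Random.seed!(12)

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, γ, δ = p
    du[1] = (α - β * y) * x
    du[2] = (δ * x - γ) * y
    return nothing
end

var1 = [80, 40, 70]
γs = [v > 60 ? 5.0 : 2.0 for v in var1]   # one fixed γ per data set

u0 = [1.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Synthetic data: data set k is generated with its own γs[k] plus Gaussian noise.
odedata = zeros(2, 101, length(var1))
for k in eachindex(var1)
    sol = solve(remake(prob; p = [1.5, 1.0, γs[k], 1.0]), Tsit5(); saveat = 0.1)
    odedata[:, :, k] = Array(sol) .+ 0.1 .* randn(2, 101)
end

@model function fitlv(data, prob, γs)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5), 0.5, 2.5)
    β ~ truncated(Normal(1.2, 0.5), 0, 2)
    δ ~ truncated(Normal(1.0, 0.5), 0, 2)

    for k in axes(data, 3)
        # γ is not sampled: it is looked up for data set k.
        p = [α, β, γs[k], δ]
        predicted = solve(remake(prob; p = p), Tsit5(); saveat = 0.1)
        for i in eachindex(predicted.u)
            data[:, i, k] ~ MvNormal(predicted.u[i], σ^2 * I)
        end
    end
end

chain = sample(fitlv(odedata, prob, γs), NUTS(0.65), 1000)
```

Solving inside the loop over `k` is what ties each data slice to its own γ. An `EnsembleProblem` could batch the three solves, but a plain loop keeps the sketch close to the code in the question.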

I’m going to spend a while trying to figure out how to write this.

Thank you so much.

I’m a beginner, but I think I understand a lot better.

Thank you very much.


I’m not sure what your question is, then.

Sorry.

I guess I don’t fully understand my problem myself.

I’ll have to rethink it myself.

Thank you.

Another funny example, which shows that `DynamicPPL.jl` takes the goal of replacing ALL appearances of `LHS ~ RHS` seriously:

```julia
julia> @model function test()
           :(a~b)
       end
test (generic function with 1 method)

julia> vi=VarInfo();

julia> test()(vi)
quote
    var"##tmpright#268" = b
    var"##tmpright#268" isa Union{Distributions.Distribution, AbstractVector{<:Distributions.Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
    var"##vn#270" = a
    var"##inds#271" = ()
    a = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#268", var"##vn#270", var"##inds#271", _varinfo)
end
```

I learned a lot and it’s interesting.
Thank you.