Conditional Branching of a Parameter in Turing.jl

I’m new to the Julia language and to modeling, so this may be a strange question.

When estimating parameters in Turing.jl, is it possible to make one parameter depend conditionally on the value of another variable?

For example, in the code below, I want the value of the parameter γ to change depending on the value of var1, but I don’t know where or how to write such a condition.

Please let me know if there is a better way. I would appreciate your advice.
Thank you for your help.

using Turing, Distributions, DataFrames, DifferentialEquations, DiffEqSensitivity
using MCMCChains, Plots, StatsPlots
using Random
Random.seed!(12);

function lotka_volterra(du,u,p,t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = (α - β*y)x   # prey
    du[2] = dy = (δ*x - γ)y   # predator
end
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0,1.0]
prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())
plot(sol)

odedata1 = Array(solve(prob,Tsit5(),saveat=0.1))
odedata2 = odedata1 .+ rand()   # note: this adds the same random offset to every entry
odedata3 = odedata1 .+ rand()
odedata = zeros(Float64, 2, 101, 3)
odedata[:,:,1] = odedata1
odedata[:,:,2] = odedata2
odedata[:,:,3] = odedata3

# I don’t know where or how to write the following code
# (γ should depend on the corresponding element of var1):

# var1 = [80, 40, 70]
# if var1[l] > 60
#     γ = 4
# else
#     γ = 1.5
# end

Turing.setadbackend(:forwarddiff)

@model function fitlv(data)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    γ ~ truncated(Normal(3.0,0.5),1,4)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    p = [α,β,γ,δ]   # careful: lotka_volterra unpacks this as α, β, δ, γ, so γ and δ are swapped inside the ODE
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    for k in 1:size(data, 3)   # one slice per dataset (ndims(data) only coincidentally equals 3)
        for i in 1:length(predicted)
            data[:,i,k] ~ MvNormal(predicted[i], σ)
        end
    end
end

model = fitlv(odedata)
chain = sample(model, NUTS(.65),1000)
plot(chain)

Just stick that into the model.
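
For example, here is a minimal sketch of what that could look like, assuming var1 is a single number passed in as a model argument (the threshold 60 and the values 4 and 1.5 are taken from your commented-out snippet):

@model function fitlv(data, var1)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    # γ is set deterministically from var1; it is not sampled
    γ = var1 > 60 ? 4.0 : 1.5

    p = [α,β,δ,γ]   # order matches α, β, δ, γ = p in lotka_volterra
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    for k in 1:size(data, 3)
        for i in 1:length(predicted)
            data[:,i,k] ~ MvNormal(predicted[i], σ)
        end
    end
end

model = fitlv(odedata, 80)   # var1 = 80 here picks γ = 4.0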


Thank you so much for your advice.
Your posts have helped me a lot.

I’ve tried many things, but none of them worked.

For example, in the following code, I removed γ from the parameter vector and wrote a conditional branch inside the model, but it doesn’t seem to be working properly.

I think I’ve made a fundamental mistake.
I would appreciate it if you could show me a concrete method.

using Turing, Distributions, DifferentialEquations
using MCMCChains, Plots, StatsPlots
using Random
Random.seed!(12);

function lotka_volterra(du,u,p,t)
    x, y = u
    α, β, δ = p
    γ = 3.0   # γ is fixed inside the ODE function here
    du[1] = dx = (α - β*y)x
    du[2] = dy = (δ*x - γ)y
end
p = [1.5, 1.1, 1.0]
u0 = [1.0,1.0]
prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())
plot(sol)

odedata1 = Array(solve(prob,Tsit5(),saveat=0.1))
odedata2 = odedata1 .+ rand()
odedata3 = odedata1 .+ rand()
odedata = zeros(Float64, 2, 101, 3)
odedata[:,:,1] = odedata1
odedata[:,:,2] = odedata2
odedata[:,:,3] = odedata3

Turing.setadbackend(:forwarddiff)

@model function fitlv(data)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    # γ ~ truncated(Normal(3.0,0.5),1,4)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    var1 = [80, 40, 70]
    γ = zeros(Float64, length(var1))

    for l in 1:length(var1)
        if var1[l] > 60
            γ[l] = 5.0
        else
            γ[l] = 2.0
        end
    end

    p = [α,β,δ]
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    for k in 1:size(data, 3)   # one slice per dataset
        for i in 1:length(predicted)
            data[:,i,k] ~ MvNormal(predicted[i], σ)
        end
    end
end

model = fitlv(odedata)
chain = sample(model, NUTS(.65),1000)
plot(chain)

I think you meant:

    for l in 1:length(var1)
        if var1[l] > 60
            γ = 5.0
        else
            γ = 2.0
        end
    end

    p = [α,β,δ,γ]
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
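
Note that this loop just overwrites γ on every iteration, so you end up with a single value (whichever the last element of var1 selects). If the goal is a different γ for each dataset, one option (an untested sketch, assuming the four-parameter lotka_volterra from the first post) is to solve the ODE once per dataset inside the model:

@model function fitlv(data, var1)
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    for k in 1:size(data, 3)
        γ = var1[k] > 60 ? 5.0 : 2.0   # dataset-specific γ
        p = [α,β,δ,γ]
        prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
        predicted = solve(prob,Tsit5(),saveat=0.1)
        for i in 1:length(predicted)
            data[:,i,k] ~ MvNormal(predicted[i], σ)
        end
    end
end

model = fitlv(odedata, [80, 40, 70])

Each dataset slice then gets its own prediction with its own γ.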

Under the hood of Turing.jl, the parsing of model definitions is done by DynamicPPL.jl.

In practice, starting from your model definition, the compiler looks for any expression of the form LHS ~ RHS and replaces it with a function call that returns the value of your random variable and updates the metadata of your model, sampler, likelihood, etc.

So as long as you are writing valid code, you can write (almost*) anything you want next to the random assignments LHS ~ RHS.

(*: multiple definitions of a random variable with the same symbol overwrite the same field in the metadata; definitions are not position-dependent in the code.)

Example:

julia> using DynamicPPL

julia> @macroexpand @model function test()
           a ~ Normal()
           newfunc() = anyfunc()
           a = fct(a)+b
           if a > 0
               return "out"
           else
               b ~ Normal(a)
           end
       end
quote
    $(Expr(:meta, :doc))
    function test(; )
        var"##evaluator#271" = ((_rng::Random.AbstractRNG, _model::Model, _varinfo::AbstractVarInfo, _sampler::AbstractMCMC.AbstractSampler, _context::DynamicPPL.AbstractContext)->begin
                    begin
                        #= REPL[6]:2 =#
                        begin
                            var"##tmpright#263" = Normal()
                            var"##tmpright#263" isa Union{Distributions.Distribution, AbstractVector{<:Distributions.Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
                            var"##vn#265" = a
                            var"##inds#266" = ()
                            a = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#263", var"##vn#265", var"##inds#266", _varinfo)
                        end
                        #= REPL[6]:4 =#
                        newfunc() = begin
                                #= REPL[6]:4 =#
                                anyfunc()
                            end
                        #= REPL[6]:6 =#
                        a = fct(a) + b
                        #= REPL[6]:7 =#
                        if a > 0
                            #= REPL[6]:8 =#
                            return "out"
                        else
                            var"##tmpright#267" = Normal(a)
                            var"##tmpright#267" isa Union{Distributions.Distribution, AbstractVector{<:Distributions.Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
                            var"##vn#269" = b
                            var"##inds#270" = ()
                            b = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#267", var"##vn#269", var"##inds#270", _varinfo)
                        end
                    end
                end)
        return (Model)(:test, var"##evaluator#271", NamedTuple(), NamedTuple())
    end
end

So, as you can see here, your model definition is transformed into a function that outputs a Model struct whose internal evaluator is your initial model function, in which every random assignment has been replaced with a block of code that:

  • tests whether you are actually assigning a proper Distribution to your random variable,
  • looks for indices in case you are assigning a value to an indexable variable,
  • uses DynamicPPL.tilde_assume, which outputs your random variable and updates the sampler, the context, and the VarInfo (the latter is the metadata of your model).

Thank you for your advice.

I have tried many things, but I may have to use MonteCarloProblem (EnsembleProblem?) in order to reach my goal.

In this example, I want to map odedata to var1, i.e., I want to use a different value of γ for each odedata slice, roughly along the lines of the sketch below.
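
This is an untested sketch of the prob_func interface of EnsembleProblem (assuming the four-parameter lotka_volterra and ODEProblem from my first post, so that prob.p = [α, β, δ, γ]):

var1 = [80, 40, 70]
γs = [v > 60 ? 5.0 : 2.0 for v in var1]   # one γ per dataset

function prob_func(prob, i, repeat)
    # keep α, β, δ from prob.p and swap in the dataset-specific γ
    remake(prob, p = [prob.p[1], prob.p[2], prob.p[3], γs[i]])
end

ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleSerial(),
            trajectories = length(var1), saveat = 0.1)
# sim[k] would then be compared against odedata[:,:,k]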

I’m going to spend some time figuring out how to write this.

Thank you so much.

Thank you for your thoughtful advice.

I’m a beginner, but I think I understand a lot better.

Thank you very much.


I’m not sure what your question is then :man_shrugging:

Sorry.

I guess I don’t fully understand my own problem yet.

I’ll have to rethink it.

Thank you.

Another funny example, which shows that DynamicPPL.jl takes the goal of replacing ALL appearances of LHS ~ RHS seriously:

julia> @model function test()
           :(a~b)
       end
test (generic function with 1 method)

julia> vi=VarInfo();

julia> test()(vi)
quote
    var"##tmpright#268" = b
    var"##tmpright#268" isa Union{Distributions.Distribution, AbstractVector{<:Distributions.Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
    var"##vn#270" = a
    var"##inds#271" = ()
    a = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#268", var"##vn#270", var"##inds#271", _varinfo)
end

I learned a lot, and it’s very interesting.
Thank you.