Turing.jl for Causal Inference model

I am trying to translate the model from McElreath’s Causal Inference workshop (youtube, github) to Turing.jl. The results seem to be the same, but Turing is much slower (about 110 sec. vs. 3 sec. with Stan), so I was wondering whether I am doing something wrong. I appreciate any suggestions!

using Distributions,
    DynamicHMC,
    GLM,
    Memoization,    
    Random,
    ReverseDiff,
    RCall,
    StatsBase,
    StatsPlots,
    Turing


## Rethinking Version
R"""
set.seed(1908)
N <- 200 # number of pairs
U <- rnorm(N) # simulate confounds
# birth order and family sizes
B1 <- rbinom(N,size=1,prob=0.5) # 50% first borns
M <- rnorm( N , 2*B1 + U )
B2 <- rbinom(N,size=1,prob=0.5)
D <- rnorm( N , 2*B2 + U + 0*M ) # change the 0 to turn on causal influence of mom
library(rethinking)
library(cmdstanr)
dat <- list(N=N,M=M,D=D,B1=B1,B2=B2)
set.seed(1908)
flbi <- ulam(
    alist(
        # mom model
            M ~ normal( mu , sigma ),
            mu <- a1 + b*B1 + k*U[i],
        # daughter model
            D ~ normal( nu , tau ),
            nu <- a2 + b*B2 + m*M + k*U[i],
        # B1 and B2
            B1 ~ bernoulli(p),
            B2 ~ bernoulli(p),
        # unmeasured confound
            vector[N]:U ~ normal(0,1),
        # priors
            c(a1,a2,b,m) ~ normal( 0 , 0.5 ),
            c(k,sigma,tau) ~ exponential( 1 ),
            p ~ beta(2,2)
    ), data=dat , chains=4 , cores=4 , iter=2000 , cmdstan=TRUE )
posterior <- extract.samples(flbi)
""";
posterior_R = @rget(posterior);
dat_R = @rget(dat);

@model function mom(N, M, D, B1, B2)
    p ~ Beta(2,2)
    k ~ Exponential(1)
    σ ~ Exponential(1)
    τ ~ Exponential(1)
    a1 ~ Normal(0, 0.5)
    a2 ~ Normal(0, 0.5)
    b ~ Normal(0, 0.5)
    m ~ Normal(0, 0.5)
    U ~ filldist(Normal(0,1), N)
    B1 ~ Bernoulli(p)
    B2 ~ Bernoulli(p)

    ν = a2 .+ b * B2 + m * M + k * U
    D .~ Normal.(ν, τ)

    μ = a1 .+ b * B1 + k * U 
    M .~ Normal.(μ, σ)
end


Turing.setrdcache(true)
Turing.setadbackend(:reversediff)

flbi = sample(mom(Int(dat_R[:N]), dat_R[:M], dat_R[:D], dat_R[:B1], dat_R[:B2]), 
    NUTS(1000, 0.65),
    MCMCThreads(),
    2000, 4)

The only thing I could think of here is making the D and M observations MvNormal:

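D ~ MvNormal(ν, τ^2 * I)  # covariance is τ²I, so τ stays a standard deviation as in Normal(ν, τ)

. . .
M ~ MvNormal(μ, σ^2 * I)  # likewise for σ; `I` comes from LinearAlgebra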

though I’m not sure how much of a performance improvement that gives you.
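
To check that the vectorised likelihood is the same model as the broadcast one, here is a minimal sketch using made-up numbers (`ν`, `τ`, and `D` below are illustrative values, not the thread's data). It relies on the fact that `MvNormal` takes a covariance, so matching `Normal.(ν, τ)` needs `τ^2 * I`; passing `τ * I` would give a variance of `τ` instead.

```julia
using Distributions, LinearAlgebra, Random

Random.seed!(1)
n = 5
ν = randn(n)               # made-up means
τ = 0.7                    # made-up standard deviation
D = ν .+ τ .* randn(n)     # made-up observations

# Broadcast version: one univariate Normal per observation, std = τ.
lp_broadcast = sum(logpdf.(Normal.(ν, τ), D))

# Vectorised version: a single MvNormal with isotropic covariance τ²I.
lp_mv = logpdf(MvNormal(ν, τ^2 * I), D)

# For comparison: τ * I is a covariance with variance τ, i.e. std sqrt(τ).
lp_variance_tau = logpdf(MvNormal(ν, τ * I), D)

println((lp_broadcast, lp_mv, lp_variance_tau))
```

The first two values should agree up to floating-point error, while the third differs whenever `τ` is not 1.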

Are you also counting compilation time or only running time?

Can you also post the Stan code? Rethinking generates stan code behind the scenes so maybe that can give some hints to performance differences?

I believe you should call Turing.setrdcache(true) after you select ReverseDiff with Turing.setadbackend(:reversediff). Beyond that, I’m not sure.

I was unable to run the Stan code, but using MvNormal decreased the run time from 31 to 7 seconds.

I finally figured out how to get the Stan code:

data{
    int N;
    vector[200] D;
    vector[200] M;
    int B1[200];
    int B2[200];
}
parameters{
    vector[N] U;
    real m;
    real b;
    real a2;
    real a1;
    real<lower=0> tau;
    real<lower=0> sigma;
    real<lower=0> k;
    real<lower=0,upper=1> p;
}
model{
    vector[200] mu;
    vector[200] nu;
    p ~ beta( 2 , 2 );
    k ~ exponential( 1 );
    sigma ~ exponential( 1 );
    tau ~ exponential( 1 );
    a1 ~ normal( 0 , 0.5 );
    a2 ~ normal( 0 , 0.5 );
    b ~ normal( 0 , 0.5 );
    m ~ normal( 0 , 0.5 );
    U ~ normal( 0 , 1 );
    B2 ~ bernoulli( p );
    B1 ~ bernoulli( p );
    for ( i in 1:200 ) {
        nu[i] = a2 + b * B2[i] + m * M[i] + k * U[i];
    }
    D ~ normal( nu , tau );
    for ( i in 1:200 ) {
        mu[i] = a1 + b * B1[i] + k * U[i];
    }
    M ~ normal( mu , sigma );
}

Thanks for the suggestion! That helped a lot!

Just running time, I hope, unless Turing recompiles on every call?

As long as you keep the Julia process running between calls (via the REPL or Pluto), it won't recompile on the second call.
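
A quick way to see the effect in any session is to time the same call twice. This is only a sketch with a hypothetical stand-in function (`sumsq` is not part of the model); the same pattern applies to timing `sample(...)`.

```julia
# Hypothetical stand-in for an expensive call such as sample(...).
sumsq(xs) = sum(x -> x^2, xs)

xs = rand(1_000)

t_first  = @elapsed sumsq(xs)   # includes JIT compilation for Vector{Float64}
t_second = @elapsed sumsq(xs)   # reuses the already compiled method

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

For benchmarking the sampler, the second timing is the one to compare against Stan.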

In that case the timing should not include compilation. On the second run I now get:

Wall duration     = 21.42 seconds
Compute duration  = 65.88 seconds

What is the Wall duration?

Wall time is the time for all chains to finish and compute duration is the sum of each chain’s time to complete. Did your last benchmark use MvNormal?
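
As a rough sketch of how the two numbers relate, assuming the chains run side by side on separate threads (the per-chain times below are made up, not taken from the thread):

```julia
# Made-up per-chain run times in seconds (illustrative only).
chain_seconds = [5.2, 4.8, 5.0, 5.1]

# Compute duration: total work summed over all chains.
compute_duration = sum(chain_seconds)

# Wall duration with parallel chains: you wait for the slowest chain.
wall_if_parallel = maximum(chain_seconds)

# Wall duration if the chains ran one after another instead.
wall_if_serial = sum(chain_seconds)

# With perfect parallelism this ratio approaches the number of chains.
speedup = compute_duration / wall_if_parallel

println((compute_duration, wall_if_parallel, wall_if_serial, speedup))
```

For the run reported above, 65.88 / 21.42 is roughly 3.1, so the four chains were mostly, but not perfectly, running in parallel.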

Yes but I did more draws.

For this model

using LinearAlgebra  # provides `I`, in case it is not already in scope

@model function mom(N, M, D, B1, B2)
    p ~ Beta(2,2)
    k ~ Exponential(1)
    σ ~ Exponential(1)
    τ ~ Exponential(1)
    a1 ~ Normal(0, 0.5)
    a2 ~ Normal(0, 0.5)
    b ~ Normal(0, 0.5)
    m ~ Normal(0, 0.5)
    U ~ filldist(Normal(0,1), N)
    B1 ~ Bernoulli(p)
    B2 ~ Bernoulli(p)

    ν = a2 .+ b * B2 + m * M + k * U
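    D ~ MvNormal(ν, τ^2 * I)  # covariance is τ²I, so τ is a standard deviation as in Stan

    μ = a1 .+ b * B1 + k * U 
    M ~ MvNormal(μ, σ^2 * I)  # covariance is σ²I, so σ is a standard deviation as in Stan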
end

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

flbi = sample(
    mom(Int(dat_R[:N]), dat_R[:M], dat_R[:D], dat_R[:B1], dat_R[:B2]), 
    NUTS(1000, 0.65),
    MCMCThreads(),
    2_000, 4)

I get

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 10.99 seconds
Compute duration  = 37.53 seconds

Ok. Thanks. I was unable to run the Stan code. How does the new Turing model compare to Stan when the same number of samples are used?

Edit: My mistake. I think that is the comparison with the same number of samples. So it is within a factor of 4?

I actually think that 2000 iterations in Stan means 2000 iterations in total, including warmup, whereas in Turing you get 2000 kept iterations on top of the 1000 burn-in iterations. So with 3000 iterations in Stan I get:

All 4 chains finished successfully.
Mean chain execution time: 3.2 seconds.
Total execution time: 3.8 seconds.
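
The bookkeeping can be sketched as follows. The helper names `stan_kept` and `turing_total` are hypothetical, and the default of using half of `iter` for warmup is an assumption about Stan / ulam that should be checked against their documentation.

```julia
# Draws kept per chain under Stan's convention, where `iter` includes warmup.
# Assumption: warmup defaults to half of `iter` unless set explicitly.
stan_kept(iter; warmup = iter ÷ 2) = iter - warmup

# Transitions per chain for the Turing call used in this thread:
# NUTS(n_adapts, 0.65) adapts for n_adapts steps, then N draws are kept.
turing_total(n_adapts, N) = n_adapts + N

println("Stan, iter=2000:                 kept  = ", stan_kept(2000))
println("Stan, iter=3000, warmup=1000:    kept  = ", stan_kept(3000; warmup = 1000))
println("Turing, NUTS(1000, ...), N=2000: total = ", turing_total(1000, 2000))
```

Under these assumptions, a Stan run matches the Turing call exactly (1000 warmup plus 2000 kept draws per chain) only if `iter` is 3000 and the warmup is set to 1000.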

There is certainly room for improvement, but 11 seconds is reasonable. In my experience (unless something has changed recently), the problem with Julia’s autodiff is that it does not scale as well as Stan’s. I think the other problem is that ReverseDiff does not work well with loops, which are sometimes the easier way to write a model.

I’ll try to test a couple of different AD backends. Is it ok to set the backend in a live session or would it be better to restart?

I think it is fine to switch within the same session. Unfortunately, ReverseDiff might be the best at the moment.
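
For readers on a newer Turing release: as far as I know, the global `Turing.setadbackend` / `Turing.setrdcache` switches were later replaced by an `adtype` keyword on the sampler, so the backend is chosen per sampler rather than per session. The following is a minimal sketch under that assumption; the model `demo` and its data are hypothetical stand-ins, and the exact keyword names should be checked against the installed version.

```julia
using Turing
import ReverseDiff  # must be loaded for the ReverseDiff backend to be available

# Tiny hypothetical model, only to have something to sample from.
@model function demo(y)
    μ ~ Normal(0, 1)
    σ ~ Exponential(1)
    for i in eachindex(y)
        y[i] ~ Normal(μ, σ)
    end
end

model = demo(randn(50))

# Assumption: `compile = true` plays the role of setrdcache(true),
# i.e. it reuses a compiled ReverseDiff tape.
nuts_rd = NUTS(100, 0.65; adtype = AutoReverseDiff(; compile = true))
nuts_fd = NUTS(100, 0.65; adtype = AutoForwardDiff())

chain_rd = sample(model, nuts_rd, 200; progress = false)
chain_fd = sample(model, nuts_fd, 200; progress = false)
```

Because the backend is attached to the sampler, several backends can be compared in one session without restarting.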

Here is my benchmark for two runs of each:

Turing.setadbackend(:forwarddiff)

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 102.14 seconds
Compute duration  = 361.99 seconds

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 95.9 seconds
Compute duration  = 333.32 seconds

Turing.setadbackend(:tracker)

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 99.84 seconds
Compute duration  = 331.42 seconds

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 86.06 seconds
Compute duration  = 298.54 seconds

Turing.setadbackend(:zygote)
# gave up after a couple of minutes; will let it run over the weekend

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 26.13 seconds
Compute duration  = 101.67 seconds

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 5.85 seconds
Compute duration  = 20.05 seconds

Yep, ReverseDiff seems to be the clear winner in my benchmark.
