Survival analysis in Turing: vectorized version leads to strange errors

I have an example of a simple survival analysis in Turing that works fine when using a loop but produces strange errors when I attempt to convert it to a vectorized form:

# Import Turing and Distributions.
using Turing, Distributions
using DataFrames, StatsBase
using RDatasets

# Set a seed for reproducibility.
using Random
Random.seed!(0);

# Import the "mastectomy" dataset.
data = RDatasets.dataset("HSAUR", "mastectomy");

# Recode Metastized to 1/2 (so it can index μ) and add log(Time) plus its standardized version.
transform!(data, :Metastized=>ByRow(x->x=="yes" ? 2 : 1)=>:Metastized, 
                 :Time=>ByRow(log)=>:LogTime)
transform!(data, :LogTime=>(x->standardize(ZScoreTransform,x))=>:LogTimeStd)

# Show the first six rows of our edited dataset.
@info first(data, 6)

# log-normal model (works)
@model function survival_loop(log_time, metastized, event)
    # priors
    σ ~ Exponential(2)
    μ ~ filldist(Normal(),2)

    # fitting data
    for i in eachindex(log_time)
        dist = Normal(μ[metastized[i]], σ)
        if event[i] # not-censored
            log_time[i] ~ dist
        else # censored
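            # right-censoring: observing a 1 from Bernoulli(ccdf(dist, log_time[i])) adds log P(T > t) to the likelihood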
            1 ~ Bernoulli(ccdf(dist, log_time[i]))
        end        
    end
end;

# log-normal model, vectorized (doesn't work)
@model function survival_vect(log_time, metastized, event)
    # priors
    σ ~ Exponential(2)
    μ ~ filldist(Normal(),2)

    log_time_died   = log_time[event]
    metastized_died = metastized[event]

    log_time_cens   = log_time[.!event]
    metastized_cens = metastized[.!event]

    log_time_died  .~ Normal.( μ[metastized_died], σ)

    censored = ones(Int,length(log_time_cens))
    
    censored .~ Bernoulli.(ccdf.(Logistic.( μ[metastized_cens], σ), log_time_cens))
end


# loop mode (works)
model_loop = survival_loop(data.LogTimeStd, data.Metastized, data.Event) # , 1.0,0.1
chain_loop = sample(model_loop, NUTS(0.80), MCMCSerial(), 1_000, 1)
show(IOContext(stdout, :limit => false), "text/plain",chain_loop)

# vectorized mode (doesn't work)
model_vect = survival_vect(data.LogTimeStd, data.Metastized, data.Event) # , fill(1,sum(.!data.Event)) 
chain_vect = sample(model_vect, NUTS(0.80), MCMCSerial(), 1_000, 1)
show(IOContext(stdout, :limit => false), "text/plain",chain_vect)

Using Julia version 1.6.7 and the following packages:

  [31c24e10] Distributions v0.25.68
  [ce6b1742] RDatasets v0.7.7
  [2913bbd2] StatsBase v0.33.21
  [fce5fe82] Turing v0.21.10
  [9a3f8284] Random

And here are the error messages (first few):

ERROR: LoadError: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing, Float64, 12}
Stacktrace:
  [1] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}, i1::Int64)
    @ Base ./array.jl:843
  [2] _unsafe_copyto!(dest::Vector{Float64}, doffs::Int64, src::Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, soffs::Int64, n::Int64)
    @ Base ./array.jl:235
  [3] unsafe_copyto!
    @ ./array.jl:289 [inlined]
  [4] _copyto_impl!
    @ ./array.jl:313 [inlined]
  [5] copyto!
    @ ./array.jl:299 [inlined]
  [6] copyto!
    @ ./array.jl:325 [inlined]
  [7] copyto!
    @ ./broadcast.jl:977 [inlined]
  [8] copyto!
    @ ./broadcast.jl:936 [inlined]
  [9] materialize!
    @ ./broadcast.jl:894 [inlined]
 [10] materialize!
    @ ./broadcast.jl:891 [inlined]
 [11] survival_vect(__model__::DynamicPPL.Model{typeof(survival_vect), (:log_time, :metastized, :event), (), (), Tuple{Vector{Float64}, Vector{Int64}, Vector{Bool}}, Tuple{}, DynamicPPL.DefaultContext}, __varinfo__::DynamicPPL.TypedVarInfo{NamedTuple{(:σ, :μ, :log_time_died, :censored), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Exponential{Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:μ, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:μ, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:log_time_died, Setfield.IndexLens{Tuple{Int64}}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:log_time_died, Setfield.IndexLens{Tuple{Int64}}}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:censored, Setfield.IndexLens{Tuple{Int64}}}, Int64}, Vector{Bernoulli{Float64}}, Vector{AbstractPPL.VarName{:censored, Setfield.IndexLens{Tuple{Int64}}}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}}}, ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, __context__::DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0, standardtag} where standardtag, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}, log_time::Vector{Float64}, metastized::Vector{Int64}, event::Vector{Bool})
    @ Main ~/src/hdd_survival/example/surv_tutorial_simple.jl:49

Heya, whilst playing around with this I stumbled across this issue. I got the above working by an explicit conversion to Real:

@model function survival_vect(log_time, metastized, event)
    # priors
    σ ~ Exponential(2)
    μ ~ filldist(Normal(),2)

    log_time_died   = convert(Vector{Real}, log_time[event])
    metastized_died = convert(Vector{Real}, metastized[event])

    log_time_cens   = convert(Vector{Real}, log_time[.!event])
    metastized_cens = convert(Vector{Real}, metastized[.!event])

    log_time_died  .~ Normal.( μ[metastized_died], σ)

    censored = ones(Real,length(log_time_cens))
    
    censored .~ Bernoulli.(ccdf.(Logistic.( μ[metastized_cens], σ), log_time_cens))
end


Surprisingly, this code runs, but it now treats all the data arrays as unknown parameters to be estimated. For comparison, here is the output of the loop version:

Chains MCMC chain (1000×15×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 16.93 seconds
Compute duration  = 16.93 seconds
parameters        = σ, μ[1], μ[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64 

           σ    1.4617    0.2195     0.0069    0.0097   622.7722    0.9992       36.7830
        μ[1]    1.0566    0.4447     0.0141    0.0161   571.5087    1.0004       33.7552
        μ[2]    0.1397    0.2641     0.0084    0.0088   833.1433    1.0040       49.2082
...

And the output of the vectorized version:

Chains MCMC chain (1000×59×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 22.66 seconds
Compute duration  = 22.66 seconds
parameters        = σ, μ[1], μ[2], log_time_died[1], log_time_died[2], log_time_died[3], log_time_died[4], log_time_died[5], log_time_died[6], log_time_died[7], log_time_died[8], log_time_died[9], log_time_died[10], log_time_died[11], log_time_died[12], log_time_died[13], log_time_died[14], log_time_died[15], log_time_died[16], log_time_died[17], log_time_died[18], log_time_died[19], log_time_died[20], log_time_died[21], log_time_died[22], log_time_died[23], log_time_died[24], log_time_died[25], log_time_died[26], censored[1], censored[2], censored[3], censored[4], censored[5], censored[6], censored[7], censored[8], censored[9], censored[10], censored[11], censored[12], censored[13], censored[14], censored[15], censored[16], censored[17], censored[18]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
         parameters      mean       std   naive_se      mcse       ess      rhat   ess_per_sec 
             Symbol   Float64   Float64    Float64   Float64   Float64   Float64       Float64 

                  σ    0.3740    0.0000     0.0000    0.0000       NaN       NaN           NaN
               μ[1]    0.4496    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
               μ[2]    1.0286    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
   log_time_died[1]   -0.6765    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
   log_time_died[2]    0.0074    0.0000     0.0000    0.0000       NaN       NaN           NaN
   log_time_died[3]   -1.3043    0.0000     0.0000    0.0000       NaN       NaN           NaN
   log_time_died[4]    1.1129    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
   log_time_died[5]    0.5471    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
   log_time_died[6]   -0.8018    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
   log_time_died[7]   -1.3561    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
   log_time_died[8]   -1.5322    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
   log_time_died[9]   -0.9897    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[10]   -1.6721    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[11]   -1.4060    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[12]   -0.9463    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[13]    0.9635    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[14]    0.6761    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[15]   -1.4062    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[16]   -1.6539    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[17]   -1.5469    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[18]    1.5997    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[19]   -1.1834    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[20]    0.0960    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[21]   -1.9447    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[22]   -1.2405    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[23]   -1.9827    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[24]   -0.0774    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
  log_time_died[25]    0.0301    0.0000     0.0000    0.0000       NaN       NaN           NaN
  log_time_died[26]   -1.9126    0.0000     0.0000    0.0000    2.6720    0.9990        0.1179
        censored[1]    1.0000    0.0000     0.0000    0.0000       NaN       NaN           NaN
        censored[2]    1.0000    0.0000     0.0000    0.0000       NaN       NaN           NaN
        censored[3]    0.0000    0.0000     0.0000    0.0000       NaN       NaN           NaN
        censored[4]    0.0000    0.0000     0.0000    0.0000       NaN       NaN           NaN
...
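A quick way to confirm what is being sampled is to list the names in the chain's parameters section (a small check using names(chain, section) from MCMCChains):

# which names the sampler treats as parameters
names(chain_vect, :parameters)
# here: σ, μ[1], μ[2], plus every log_time_died[i] and censored[i]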

P.S. I just noticed that my vectorized version had a bug: the distribution for the censored data points was specified as Logistic instead of Normal. It doesn't change the strange error though.
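For completeness, the censored line with that fixed (everything else unchanged) would be:

# use Normal rather than Logistic for the censored observations
censored .~ Bernoulli.(ccdf.(Normal.(μ[metastized_cens], σ), log_time_cens))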

I played around a bit and fixed that issue. I think it's basically because when the left-hand side of ~ is a new local variable rather than a model argument, it's assumed to be a parameter to fit. So for the first part I solved it by not binding the data to new variable names, and for the second part by passing the vector of ones as an argument to the function (though maybe you'll find this too hacky?). BTW, I checked with @time and the vectorized version seems to take longer on my machine, with way more allocations, maybe because of the additional arguments?
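To illustrate that rule with a hypothetical toy model (not part of the survival example): a left-hand side that is a model argument is treated as an observation, while any other left-hand side is sampled as a parameter.

@model function demo(y)
    μ ~ Normal()
    y ~ Normal(μ, 1)   # `y` is a model argument, so this is treated as data
    x ~ Normal(μ, 1)   # `x` is not an argument, so it is sampled as a parameter
end

chain = sample(demo(1.5), NUTS(), 100)
# `chain` contains μ and x, but not y

With that in mind, here is the adjusted survival model: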

@model function survival_vect(log_time, metastized, event, ones)
    # priors
    σ ~ Exponential(2)
    μ ~ filldist(Normal(),2)

    log_time = convert(Vector{Real}, log_time)
    metastized = convert(Vector{Real}, metastized)

    log_time[event]  .~ Normal.( μ[metastized[event]], σ)
    
    ones .~ Bernoulli.(ccdf.(Normal.( μ[metastized[.!event]], σ), log_time[.!event]))
end

model_vect = survival_vect(data.LogTimeStd, data.Metastized, data.Event, ones(Real, sum(data.Event.==0)))
chain_vect = sample(model_vect, NUTS(0.80), MCMCSerial(), 1_000, 1)
show(IOContext(stdout, :limit => false), "text/plain",chain_vect)

The following code works, and seems to be faster than the loop one:

@model function survival_vect(log_time_died, metastized_died, log_time_censored, metastized_censored, censored_ones)
    # priors
    σ ~ Exponential(2)
    μ ~ filldist(Normal(),2)

    log_time_died .~ Normal.( μ[metastized_died], σ)
    censored_ones .~ Bernoulli.(ccdf.(Normal.( μ[metastized_censored], σ), log_time_censored))
end

# vectorized mode 
model_vect = survival_vect(data.LogTimeStd[data.Event], data.Metastized[data.Event], 
                           data.LogTimeStd[.!data.Event], data.Metastized[.!data.Event], 
                           ones(Int64, sum(.!data.Event)))

chain_vect = sample(model_vect, NUTS(0.80), MCMCSerial(), 1_000, 1)
show(IOContext(stdout, :limit => false), "text/plain",chain_vect)

The run output for loop mode:

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 15.71 seconds
Compute duration  = 15.71 seconds
parameters        = σ, μ[1], μ[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64 

           σ    1.4617    0.2195     0.0069    0.0097   622.7722    0.9992       39.6367
        μ[1]    1.0566    0.4447     0.0141    0.0161   571.5087    1.0004       36.3740
        μ[2]    0.1397    0.2641     0.0084    0.0088   833.1433    1.0040       53.0259

Run output for vectorized mode:

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 4.25 seconds
Compute duration  = 4.25 seconds
parameters        = σ, μ[1], μ[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64 

           σ    1.4872    0.2411     0.0076    0.0074   584.1241    1.0013      137.3763
        μ[1]    1.0578    0.4511     0.0143    0.0166   727.5249    0.9996      171.1018
        μ[2]    0.1375    0.2663     0.0084    0.0095   748.2147    0.9990      175.9677
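If you want to compare the two on your own machine beyond the wall-duration lines above, a rough check is to time both sample calls (a sketch; note that the first call to each model includes compilation time):

@time sample(model_loop, NUTS(0.80), MCMCSerial(), 1_000, 1);
@time sample(model_vect, NUTS(0.80), MCMCSerial(), 1_000, 1);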

Yeah, that's nice: just move everything other than the essentials outside the model function!