I have an example for a simple survival analysis in Turing that works fine when using loop but produces strange errors when I attempt to convert it to vectorized form:
# Import Turing and Distributions.
using Turing, Distributions
using StatsBase
using RDatasets
# Set a seed for reproducibility.
using Random
Random.seed!(0);
# Import the "mastectomy" dataset.
data = RDatasets.dataset("HSAUR", "mastectomy");
transform!(data, :Metastized=>ByRow(x->x=="yes" ? 2 : 1)=>:Metastized,
:Time=>ByRow(log)=>:LogTime)
transform!(data, :LogTime=>(x->standardize(ZScoreTransform,x))=>:LogTimeStd)
# Show the first six rows of our edited dataset.
@info first(data, 6)
# log-normal model (works)
@model function survival_loop(log_time, metastized, event)
# priors
σ ~ Exponential(2)
μ ~ filldist(Normal(),2)
# fitting data
for i in eachindex(log_time)
dist = Normal(μ[metastized[i]], σ)
if event[i] # not-censored
log_time[i] ~ dist
else # censored
1 ~ Bernoulli(ccdf(dist, log_time[i]))
end
end
end;
# log-normal model, vectorized (doesn't work)
@model function survival_vect(log_time, metastized, event)
# priors
σ ~ Exponential(2)
μ ~ filldist(Normal(),2)
log_time_died = log_time[event]
metastized_died = metastized[event]
log_time_cens = log_time[.!event]
metastized_cens = metastized[.!event]
log_time_died .~ Normal.( μ[metastized_died], σ)
censored = ones(Int,length(log_time_cens))
censored .~ Bernoulli.(ccdf.(Logistic.( μ[metastized_cens], σ), log_time_cens))
end
# loop mode (works)
model_loop = survival_loop(data.LogTimeStd, data.Metastized, data.Event) # , 1.0,0.1
chain_loop = sample(model_loop, NUTS(0.80), MCMCSerial(), 1_000, 1)
show(IOContext(stdout, :limit => false), "text/plain",chain_loop)
# vectorized mode (doesn't work)
model_vect = survival_vect(data.LogTimeStd, data.Metastized, data.Event) # , fill(1,sum(.!data.Event))
chain_vect = sample(model_vect, NUTS(0.80), MCMCSerial(), 1_000, 1)
show(IOContext(stdout, :limit => false), "text/plain",chain_vect)
Using R version 1.6.7 and following packages:
[31c24e10] Distributions v0.25.68
[ce6b1742] RDatasets v0.7.7
[2913bbd2] StatsBase v0.33.21
[fce5fe82] Turing v0.21.10
[9a3f8284] Random
And here are the error messages (first few):
ERROR: LoadError: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing, Float64, 12}
Stacktrace:
[1] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}, i1::Int64)
@ Base ./array.jl:843
[2] _unsafe_copyto!(dest::Vector{Float64}, doffs::Int64, src::Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, soffs::Int64, n::Int64)
@ Base ./array.jl:235
[3] unsafe_copyto!
@ ./array.jl:289 [inlined]
[4] _copyto_impl!
@ ./array.jl:313 [inlined]
[5] copyto!
@ ./array.jl:299 [inlined]
[6] copyto!
@ ./array.jl:325 [inlined]
[7] copyto!
@ ./broadcast.jl:977 [inlined]
[8] copyto!
@ ./broadcast.jl:936 [inlined]
[9] materialize!
@ ./broadcast.jl:894 [inlined]
[10] materialize!
@ ./broadcast.jl:891 [inlined]
[11] survival_vect(__model__::DynamicPPL.Model{typeof(survival_vect), (:log_time, :metastized, :event), (), (), Tuple{Vector{Float64}, Vector{Int64}, Vector{Bool}}, Tuple{}, DynamicPPL.DefaultContext}, __varinfo__::DynamicPPL.TypedVarInfo{NamedTuple{(:σ, :μ, :log_time_died, :censored), Tuple{DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:σ, Setfield.IdentityLens}, Int64}, Vector{Exponential{Float64}}, Vector{AbstractPPL.VarName{:σ, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:μ, Setfield.IdentityLens}, Int64}, Vector{DistributionsAD.TuringScalMvNormal{Vector{Float64}, Float64}}, Vector{AbstractPPL.VarName{:μ, Setfield.IdentityLens}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:log_time_died, Setfield.IndexLens{Tuple{Int64}}}, Int64}, Vector{Normal{Float64}}, Vector{AbstractPPL.VarName{:log_time_died, Setfield.IndexLens{Tuple{Int64}}}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}, DynamicPPL.Metadata{Dict{AbstractPPL.VarName{:censored, Setfield.IndexLens{Tuple{Int64}}}, Int64}, Vector{Bernoulli{Float64}}, Vector{AbstractPPL.VarName{:censored, Setfield.IndexLens{Tuple{Int64}}}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, Vector{Set{DynamicPPL.Selector}}}}}, ForwardDiff.Dual{ForwardDiff.Tag{Turing.TuringTag, Float64}, Float64, 12}}, __context__::DynamicPPL.SamplingContext{DynamicPPL.Sampler{NUTS{Turing.Essential.ForwardDiffAD{0, standardtag} where standardtag, (), AdvancedHMC.DiagEuclideanMetric}}, DynamicPPL.DefaultContext, Random._GLOBAL_RNG}, log_time::Vector{Float64}, metastized::Vector{Int64}, event::Vector{Bool})
@ Main ~/src/hdd_survival/example/surv_tutorial_simple.jl:49