Turing model runs with one sampler but not with another

I tried the following code and it gives an error:

@model tr(l) = begin
    m = length(l)
    x = Array{Union{Float64,Missing},1}(missing,m)
    for i =1:m
        x[i] ~ Normal(0,i)
    end
    return x
end

m = tr([1,2,3])

sample(m,NUTS(0.65),100)
MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})

If I change the sampler to PG, it works:

sample(m,PG(20),100)

Can someone help me figure out what is going on? Why can’t I use NUTS or HMC here?


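This happens because you create a concretely typed array x inside your model definition, which isn't necessary here, btw. NUTS and HMC compute gradients with automatic differentiation, so during sampling the values assigned to x[i] are tracked numbers (here Tracker.TrackedReal{Float64}). These cannot be converted to Float64, so they cannot be stored in an Array{Union{Float64,Missing},1}. PG does not use gradients and only ever stores plain Float64 values, which is why it runs fine. The failure comes from automatic differentiation, not from Turing itself.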
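To make the failure mode concrete without Turing or Tracker, here is a minimal plain-Julia sketch. FakeTracked is a hypothetical stand-in for an AD number type such as Tracker.TrackedReal: it is a Real but defines no conversion back to Float64, so storing it in an array with the same element type as x in the question fails with the same kind of MethodError.

```julia
# Hypothetical stand-in for an AD number type such as Tracker.TrackedReal:
# it is a Real, but no method converts it back to Float64.
struct FakeTracked <: Real
    value::Float64
end

# Same element type as `x` in the question's model.
x = Vector{Union{Float64,Missing}}(missing, 3)

x[1] = 0.5  # a plain Float64 is stored without problems

# Storing the AD-style number needs convert(Float64, ...), which calls
# Float64(::FakeTracked); no such method exists, so a MethodError is thrown.
err = try
    x[2] = FakeTracked(0.5)
    nothing
catch e
    e
end

println(err isa MethodError)
```

Because setindex! fails before anything is written, x[2] stays missing, just as the model in the question never gets past its first assignment under NUTS.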

If you just want to sample from the prior (which is what you are doing here), you can simply do the following:

@model function test(x)
    for i in eachindex(x)
        x[i] ~ Normal(0, i)
    end
    return x
end

N = 5
samples = test(Array{Union{Float64,Missing},1}(missing, N))()

If you really do need to define an array inside a model, you currently have to use a special syntax. It ensures that the array gets the correct element type during automatic differentiation and that the model stays type-stable at all times.

Here is an example:

@model function test2(x, ::Type{T}=Vector{Float64}) where {T}
    s = T(undef, length(x))
    for i in eachindex(x)
        s[i] ~ truncated(Normal(), 0, Inf)
        x[i] ~ Normal(0, s[i])
    end
    return x
end

This syntax is described further in the Turing docs, in the Performance Tips section.
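To see why the extra ::Type{T} argument helps, here is a package-free sketch of the same pattern outside Turing. The names fill_scales and FakeTracked are hypothetical and only for illustration. The container type is a parameter that defaults to Vector{Float64}; the idea is that an AD-compatible type can be substituted for T, which this sketch imitates by passing Vector{FakeTracked} by hand.

```julia
# Hypothetical stand-in for an AD number type (not Tracker's actual type).
struct FakeTracked <: Real
    value::Float64
end

# Hypothetical helper mimicking the model pattern: the container type T
# defaults to Vector{Float64} but can be replaced by the caller.
function fill_scales(f, n, ::Type{T}=Vector{Float64}) where {T}
    s = T(undef, n)          # allocate a container of the requested type
    for i in 1:n
        s[i] = f(i)          # no conversion to Float64 is forced here
    end
    return s
end

a = fill_scales(i -> float(i), 3)                              # default: Vector{Float64}
b = fill_scales(i -> FakeTracked(i), 3, Vector{FakeTracked})   # swapped container type
```

With the default, the function behaves exactly like a hard-coded Float64 array; with a different T, the same code stores values that a fixed Vector{Float64} would reject.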

That said, neither test nor test2 above actually needs an array defined inside the model. test2 can simply be written as follows:

@model function test2(x)
    # truncated at 0 so that every s[i] is a valid (positive) standard deviation
    s ~ filldist(truncated(Normal(), 0, Inf), length(x))
    for i in eachindex(x)
        x[i] ~ Normal(0, s[i])
    end
    return x
end

or, using broadcasting:

@model function test2(x)
    s ~ filldist(truncated(Normal(), 0, Inf), length(x))
    x .~ Normal.(0, s)
end

Thanks Martin! This solved my problem!

Yes, this example doesn't need an array, but the code I'm actually writing does. I tried to make a minimal example that reproduces the error, as requested.
