Syntax for type stability in Turing.jl

Turing.jl recommends replacing

```julia
@model tmodel(x, y) = begin
    p, n = size(x)
    # Vector{Real}: can hold any Real number type, but the element type is abstract
    params = Vector{Real}(undef, n)
    for i = 1:n
        params[i] ~ truncated(Normal(), 0, Inf)
    end

    a = x * params
    y ~ MvNormal(a, 1.0)
end
```

with

```julia
@model tmodel(x, y, ::Type{T}=Vector{Float64}) where {T} = begin
    ...
    # T defaults to Vector{Float64}; another vector type can be passed instead
    params = T(undef, n)
    ...
end
```

This looks overly complicated to me… Why not simply:

```julia
@model tmodel(x, y) = begin
    ...
    # element type fixed to Float64
    params = Vector{Float64}(undef, n)
    ...
end
```

This doesn’t allow specifying a non-default type, but neither does the original version?

Hi!

The simple version works too, but you cannot use automatic differentiation (AD) with that model, so you cannot use HMC or NUTS: AD packages run the model with their own number types, and a Vector{Float64} cannot store those. The “type stability” syntax lets your model be type stable while staying generic, in the sense that we can choose T to be an appropriate type for the AD package in use. Vector{Real} is also generic, because the element type is always going to be a subtype of Real, e.g. ForwardDiff.Dual, Tracker.TrackedReal or ReverseDiff.TrackedReal. But Vector{Real} is not type stable, because Real is an abstract type.
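To see the difference without depending on an AD package, here is a minimal sketch in plain Julia. It assumes a hypothetical `ToyDual` type standing in for `ForwardDiff.Dual`, and hypothetical functions `fill_fixed` and `fill_generic` that mimic the two model bodies: one hard-codes `Vector{Float64}`, the other takes the buffer type as a `::Type{T}` argument with a `Vector{Float64}` default.

```julia
# Toy dual number standing in for ForwardDiff.Dual (hypothetical, illustration only)
struct ToyDual <: Real
    val::Float64
    der::Float64
end

# Buffer element type hard-coded to Float64 (like the "simple" model)
function fill_fixed(vals)
    buf = Vector{Float64}(undef, length(vals))
    for i in eachindex(vals)
        buf[i] = vals[i]   # converting a ToyDual to Float64 fails with a MethodError
    end
    return buf
end

# Buffer type passed in as a parameter, defaulting to Vector{Float64}
function fill_generic(vals, ::Type{T}=Vector{Float64}) where {T}
    buf = T(undef, length(vals))
    for i in eachindex(vals)
        buf[i] = vals[i]
    end
    return buf
end

vals = [ToyDual(1.0, 1.0), ToyDual(2.0, 0.0)]
fill_generic([1.0, 2.0])              # default Vector{Float64}
fill_generic(vals, Vector{ToyDual})   # concrete element type: type stable
fill_generic(vals, Vector{Real})      # also works, but Real is abstract
isconcretetype(Real)                  # false
```

The `Vector{Real}` call succeeds, but since `Real` is not a concrete type, every element access has to be resolved at run time, which is the type instability mentioned above.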


To make this type stable, you’ll want to use a dual cache, so that it stays stable for alternative number types as well. https://docs.sciml.ai/latest/basics/faq/#I-get-Dual-number-errors-when-I-solve-my-ODE-with-Rosenbrock-or-SDIRK-methods-1 explains a type-stable way to handle this.
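The idea behind a dual cache is to preallocate one buffer for plain Float64 values and one for dual numbers, then select the matching buffer from the element type of the input, so each code path stays concretely typed. The sketch below is a simplified illustration only: `ToyDual`, `ToyDualCache`, `get_buffer` and `copy_into_cache!` are hypothetical names, and the implementation the linked FAQ describes (built, as far as I know, around `dualcache` and `get_tmp`) differs in detail.

```julia
# Toy dual number standing in for ForwardDiff.Dual (hypothetical, illustration only)
struct ToyDual <: Real
    val::Float64
    der::Float64
end

# Hypothetical cache: one preallocated, concretely typed buffer per number type
struct ToyDualCache
    plain::Vector{Float64}
    dual::Vector{ToyDual}
end
ToyDualCache(n::Integer) = ToyDualCache(Vector{Float64}(undef, n), Vector{ToyDual}(undef, n))

# Dispatch on the input's element type picks the matching buffer;
# each method returns a concrete type, so callers stay type stable
get_buffer(c::ToyDualCache, ::AbstractVector{Float64}) = c.plain
get_buffer(c::ToyDualCache, ::AbstractVector{ToyDual}) = c.dual

# Write the input into whichever preallocated buffer matches it
function copy_into_cache!(c::ToyDualCache, u::AbstractVector)
    buf = get_buffer(c, u)
    for i in eachindex(u)
        buf[i] = u[i]
    end
    return buf
end
```

Calling `copy_into_cache!(c, [1.0, 2.0])` writes into `c.plain`, while passing a `Vector{ToyDual}` writes into `c.dual`, so no `Vector{Real}` is needed.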


@mohamed82008 thanks! It’s quite clear now.

@ChrisRackauckas is this also useful when declaring the model as recommended (with where {T})?