# Linear model with categorical variable

Hi,
I’m trying to implement a linear model with a categorical variable in Turing.jl but am stuck.
The model has a different μ depending on whether S = 0 or S = 1:
Wᵢ ~ Normal(μᵢ, σ)
μᵢ = α_S[i]
α = [α1, α2]
Priors: αᵢ ~ Normal(60, 10)

My implementation looks like this:

``````
@model function linear_model(X, Y, S)
    # Variance prior
    σ2 ~ Uniform(0.0, 10.0)
    # Intercept priors
    α1 ~ Normal(60, 10)
    α2 ~ Normal(60, 10)
    # Slope prior
    β ~ LogNormal(0.0, 1.0)

    # S == 1 means male
    μ1 = α1 .+ β .* (X[S .== 1] .- mean(X[S .== 1]))
    # S == 0 means female
    μ2 = α2 .+ β .* (X[S .== 0] .- mean(X[S .== 0]))
    μall = zeros(size(S))
    μall[S .== 0] = μ2[:]
    μall[S .== 1] = μ1[:]
    Y ~ MvNormal(μall, sqrt(σ2))
end

model = linear_model(X, Y, S);
sample(model, NUTS(0.8), 4)
``````

But this throws
`ERROR: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing, Float64, 12}`
when setting `μall[S .== 0] = μ2[:]`

Does anyone have a suggestion on how to handle categorical variables?

This allocates an array `μall` with eltype `Float64`, but `μ1` and `μ2` have eltype `ForwardDiff.Dual`, because Turing differentiates the model with ForwardDiff. Assigning a `Dual` into a `Float64` array triggers the failed conversion you see in the error. This would work instead:

``````
    μall = zeros(Base.promote_eltype(μ1, μ2), size(S))
``````
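You can reproduce the problem without Turing or ForwardDiff. Here is a minimal sketch using a `Complex` number as a stand-in for `ForwardDiff.Dual` (both are number types that cannot be converted to a plain `Float64`):

``````julia
# Imagine μ1 holds Duals produced during AD; Complex plays that role here.
μ1 = [1.0 + 2.0im]

μall = zeros(size(μ1))       # eltype Float64
# μall[1] = μ1[1]            # errors: cannot convert Complex to Float64

# Let the output array's eltype follow the inputs instead:
μall = zeros(Base.promote_eltype(μ1), size(μ1))   # eltype Complex{Float64}
μall[1] = μ1[1]              # works
``````

The same pattern with `Base.promote_eltype(μ1, μ2)` makes `μall` wide enough to hold whatever number type AD threads through the model.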

A better approach would be to compute `S .== 0` and `S .== 1` outside of the model evaluation, e.g.:

``````
@model function linear_model(X, Y, S,
                             S0_inds = findall(iszero, S),
                             S1_inds = findall(isone, S),
                             S_perm = invperm([S0_inds; S1_inds]))
    ...
    μ1 = @views α1 .+ β .* (X[S1_inds] .- mean(X[S1_inds]))
    μ2 = @views α2 .+ β .* (X[S0_inds] .- mean(X[S0_inds]))
    μall = vcat(μ2, μ1)[S_perm]
    ...
``````
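To see why `vcat(μ2, μ1)[S_perm]` restores the original ordering, here is a small round-trip check with made-up data (`S` and `x` are purely illustrative):

``````julia
S = [1, 0, 1, 0, 0]
x = [10.0, 20.0, 30.0, 40.0, 50.0]

S0_inds = findall(iszero, S)        # [2, 4, 5]
S1_inds = findall(isone, S)         # [1, 3]
# invperm maps "all S==0 entries, then all S==1 entries" back to original positions
S_perm  = invperm([S0_inds; S1_inds])

grouped = vcat(x[S0_inds], x[S1_inds])  # [20.0, 40.0, 50.0, 10.0, 30.0]
grouped[S_perm] == x                    # true: original order recovered
``````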

There are probably even simpler ways to do this. For example, you could center your `X` values outside the model, since recomputing the group means is wasted work on every gradient evaluation. You should also check that the result is equivalent to your original model.


Thanks, that fixed it. The model output appears correct.
Here is the code for the model with the centering of `X` done beforehand:

``````
@model function linear_model_cat_shift(X, Y, S,
                                       S0_inds = findall(iszero, S),
                                       S1_inds = findall(isone, S),
                                       S_perm = invperm([S0_inds; S1_inds]))
    # Variance prior
    σ2 ~ Uniform(0.0, 100.0)
    # Intercept priors
    α1 ~ Normal(60, 10)
    α2 ~ Normal(60, 10)
    # Slope prior
    β ~ LogNormal(0.0, 1.0)

    μ1 = α1 .+ β .* X[S1_inds]
    μ2 = α2 .+ β .* X[S0_inds]
    μall = vcat(μ2, μ1)[S_perm]

    Y ~ MvNormal(μall, sqrt(σ2))
end

# Center X per category once, outside the model, then restore the original order
S0_inds = findall(iszero, S)
S1_inds = findall(isone, S)
X_shift = [X[S0_inds] .- mean(X[S0_inds]); X[S1_inds] .- mean(X[S1_inds])]
X_shift = X_shift[invperm([S0_inds; S1_inds])]

model_noshift = linear_model(X, Y, S);
@benchmark chain_noshift = sample(model_noshift, NUTS(0.8), 1_000)

model_shift = linear_model_cat_shift(X_shift, Y, S);
@benchmark chain_shift = sample(model_shift, NUTS(0.8), 1_000)
``````

And a benchmark evaluation:

``````
BenchmarkTools.Trial: 11 samples with 1 evaluation.
Range (min … max):  388.787 ms … 615.783 ms  ┊ GC (min … max): 29.16% … 28.47%
Time  (median):     426.087 ms               ┊ GC (median):    32.25%
Time  (mean ± σ):   475.907 ms ±  86.439 ms  ┊ GC (mean ± σ):  30.83% ±  1.92%

█                                    ▃
▇▇▁▁▁▁█▁▁▁▇▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▇ ▁
389 ms           Histogram: frequency by time          616 ms <

Memory estimate: 711.74 MiB, allocs estimate: 903100.
``````

This is virtually identical to the result from the model that centers the `X` values of the respective categories inside the model, though the in-model version runs about 10 ms slower on average on my machine:

``````
BenchmarkTools.Trial: 11 samples with 1 evaluation.
Range (min … max):  403.502 ms … 687.193 ms  ┊ GC (min … max): 30.66% … 32.77%
Time  (median):     426.057 ms               ┊ GC (median):    29.52%
Time  (mean ± σ):   465.636 ms ±  83.637 ms  ┊ GC (mean ± σ):  29.97% ±  1.64%

█ █▁▁     ▁  ▁        ▁ ▁                                   ▁
█▁███▁▁▁▁▁█▁▁█▁▁▁▁▁▁▁▁█▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
404 ms           Histogram: frequency by time          687 ms <

Memory estimate: 732.68 MiB, allocs estimate: 913582.
``````