Linear model with categorical variable

Hi,
I'm trying to implement a linear model with a categorical variable in Turing, but I'm stuck.
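The model has a different intercept depending on whether S = 0 or S = 1, plus a common slope β on X centered within each group:
Wᵢ ~ Normal(μᵢ, σ)
μᵢ = α_S[i] + β (Xᵢ − X̄_S[i]), where X̄_S[i] is the mean of X in group S[i]
α = [α1, α2] (α1 for S = 1, α2 for S = 0)
Priors: α1, α2 ~ Normal(60, 10), β ~ LogNormal(0, 1), σ² ~ Uniform(0, 10)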
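The thread never shows where X, Y and S come from. To try the code below, you could simulate data from the model as implemented (group intercepts plus a shared slope on group-centered X). This is a minimal sketch: the function name `simulate_data`, the X range and the parameter values are all hypothetical, and σ here is a standard deviation (σ = 3 corresponds to σ2 = 9 in the model).

```julia
using Random, Statistics

# Hypothetical simulation from the model as implemented below:
# intercept α[S+1] per group, shared slope β on group-centered X, noise σ (std).
function simulate_data(n; α = (55.0, 65.0), β = 0.5, σ = 3.0, rng = MersenneTwister(1))
    S = rand(rng, 0:1, n)                      # group label, 0 or 1
    X = 150.0 .+ 20.0 .* rand(rng, n)          # predictor, e.g. height
    Xc = similar(X)
    for g in (0, 1)
        idx = findall(==(g), S)
        isempty(idx) && continue
        Xc[idx] = X[idx] .- mean(X[idx])        # center within each group
    end
    μ = [α[s + 1] for s in S] .+ β .* Xc
    Y = μ .+ σ .* randn(rng, n)                 # observed outcome (W in the notation above)
    return X, Y, S
end

X, Y, S = simulate_data(200)
```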

My implementation looks like this:

using Turing, Statistics, LinearAlgebra

@model function linear_model(X, Y, S)
    # Set variance prior
    σ2 ~ Uniform(0.0, 10.0);
    # Set intercept priors
    α1 ~ Normal(60, 10);
    α2 ~ Normal(60, 10);
    # Set slope prior
    β ~ LogNormal(0.0, 1.0);

    # S == 1 means male
    μ1 = α1 .+ β .* (X[S .== 1] .- mean(X[S .== 1])); 
    # S == 0 means female
    μ2 = α2 .+ β .* (X[S .== 0] .- mean(X[S .== 0]));
    μall = zeros(size(S))
    μall[S .== 0] =  μ2[:]
    μall[S .== 1] =  μ1[:]
    # Y holds the observed weights (W above); σ2 is a variance, so the covariance is σ2 * I
    Y ~ MvNormal(μall, σ2 * I);
end

model = linear_model(X, Y, S);
sample(model, NUTS(0.8), 4)

But this throws the following error at the assignment μall[S .== 0] = μ2[:]:
ERROR: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing, Float64, 12}

Does anyone have a suggestion on how to handle categorical variables?

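zeros(size(S)) allocates an array μall with eltype Float64, but μ1 and μ2 have eltype ForwardDiff.Dual because NUTS computes gradients by automatic differentiation with ForwardDiff, and a Dual cannot be stored in a Float64 array. Allocating μall with the promoted element type works: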

    μall = zeros(Base.promote_eltype(μ1, μ2), size(S))
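A small stdlib-only sketch of why this helps: the helper `combine_masked` (a hypothetical name) allocates with the promoted element type, so whatever number type μ1 and μ2 carry ends up in the result. Rational{Int} stands in for ForwardDiff.Dual here; note that with zeros(size(S)) Rationals would be silently converted to Float64, whereas Duals raise the error above.

```julia
# Fill a result vector from two group-wise pieces. Allocating with the promoted
# element type (instead of zeros(size(S)), which is always Float64) lets
# non-Float64 numbers, such as ForwardDiff.Dual during NUTS, be stored.
function combine_masked(μ1, μ2, S)
    T = Base.promote_eltype(μ1, μ2)
    μall = zeros(T, size(S))
    μall[S .== 0] = μ2      # positions with S == 0 take μ2, in order
    μall[S .== 1] = μ1      # positions with S == 1 take μ1, in order
    return μall
end

S = [0, 1, 1, 0]
r = combine_masked([1//2, 3//2], [1//3, 2//3], S)
```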

A better approach would be to compute the group indices for S == 0 and S == 1 once, outside the model evaluation, e.g. as default arguments of the model function:

@model function linear_model(X, Y, S, S0_inds = findall(iszero, S), S1_inds = findall(isone, S), S_perm = invperm([S0_inds; S1_inds]))
    ...
    μ1 = @views α1 .+ β .* (X[S1_inds] .- mean(X[S1_inds]));
    μ2 = @views α2 .+ β .* (X[S0_inds] .- mean(X[S0_inds]));
    # vcat(μ2, μ1) lists group 0 first, then group 1; S_perm restores the original order
    μall = vcat(μ2, μ1)[S_perm]
    ...
end
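To check that the vcat/invperm construction puts every value back at its original position, you can compare it with the masked assignment from the question on a tiny made-up input. This sketch uses plain Float64 values (no Turing) and hypothetical values for X, α1, α2 and β.

```julia
using Statistics

# Precomputed once: indices of each group and the permutation that restores
# the original observation order after concatenating [group 0; group 1].
S = [1, 0, 0, 1, 0]
X = [10.0, 20.0, 30.0, 50.0, 40.0]
S0_inds = findall(iszero, S)
S1_inds = findall(isone, S)
S_perm = invperm([S0_inds; S1_inds])

α1, α2, β = 1.0, 2.0, 0.5
μ1 = @views α1 .+ β .* (X[S1_inds] .- mean(X[S1_inds]))
μ2 = @views α2 .+ β .* (X[S0_inds] .- mean(X[S0_inds]))
μall = vcat(μ2, μ1)[S_perm]

# Reference: the masked assignment from the original question.
ref = zeros(length(S))
ref[S .== 0] = μ2
ref[S .== 1] = μ1
```

Because vcat allocates a fresh array whose element type follows μ1 and μ2, this version also avoids the Float64 preallocation problem entirely.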

There are probably even simpler ways to do this. For example, you could center the X values outside the model, since recomputing the group means on every gradient evaluation is wasted work. Either way, check that this is equivalent to your original model.
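One such simpler variant, sketched here with stdlib Julia only: center X per group once before sampling, then index a length-2 intercept vector with S .+ 1, which is literally μᵢ = α_S[i] with Julia's 1-based indexing. The helper names `center_by_group` and `group_mean` are hypothetical. Inside a Turing model, α would have to be a vector-valued parameter (e.g. declared with Turing's `filldist(Normal(60, 10), 2)`, an assumption not taken from this thread); since broadcasting allocates a new array, its element type follows α and β, so Dual numbers pass through without any preallocated buffer.

```julia
using Statistics

# Center X within each group of S once, before sampling, so the model
# does not recompute group means on every gradient evaluation.
function center_by_group(X, S)
    Xc = similar(X, float(eltype(X)))
    for g in unique(S)
        idx = findall(==(g), S)
        Xc[idx] = X[idx] .- mean(X[idx])
    end
    return Xc
end

# α[1] is the intercept for S == 0 and α[2] for S == 1, mirroring μᵢ = α_S[i].
group_mean(α, β, Xc, S) = α[S .+ 1] .+ β .* Xc

X = [150.0, 160.0, 170.0, 180.0]
S = [0, 1, 0, 1]
Xc = center_by_group(X, S)
μ = group_mean([50.0, 60.0], 0.5, Xc, S)
```

Note that this ordering (α[1] for S == 0) is the reverse of α1/α2 in the code above, so compare posteriors accordingly.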

Thanks, that fixed it. The model output appears correct.
Here is the code for the model that centers X within each group before sampling:

using Turing, Statistics, LinearAlgebra

@model function linear_model_cat_shift(X, Y, S, S0_inds = findall(iszero, S), S1_inds = findall(isone, S), S_perm = invperm([S0_inds; S1_inds]))
    # Set variance prior
    σ2 ~ Uniform(0.0, 100.0);
    # Set intercept priors
    α1 ~ Normal(60, 10);
    α2 ~ Normal(60, 10);
    # Set slope prior
    β ~ LogNormal(0.0, 1.0);

    # X is already centered within each group, so no means are computed here
    μ1 = α1 .+ β .* X[S1_inds];
    μ2 = α2 .+ β .* X[S0_inds];
    μall = vcat(μ2, μ1)[S_perm]

    # σ2 is a variance, so the covariance is σ2 * I
    Y ~ MvNormal(μall, σ2 * I);
end

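using BenchmarkTools, Statistics

# Center X within each group once, outside the model
S0_inds = findall(iszero, S)
S1_inds = findall(isone, S)
X_shift = [X[S0_inds] .- mean(X[S0_inds]); X[S1_inds] .- mean(X[S1_inds])]
# Restore the original observation order
X_shift = X_shift[invperm([S0_inds; S1_inds])]

# linear_model_cat: the version that centers X inside the model (not shown here)
model_noshift = linear_model_cat(X, Y, S);
@benchmark chain_noshift = sample(model_noshift, NUTS(0.8), 1_000)

model_shift = linear_model_cat_shift(X_shift, Y, S);
@benchmark chain_shift = sample(model_shift, NUTS(0.8), 1_000)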

And the benchmark for the pre-centered model (model_shift):

BenchmarkTools.Trial: 11 samples with 1 evaluation.
 Range (min … max):  388.787 ms … 615.783 ms  ┊ GC (min … max): 29.16% … 28.47%
 Time  (median):     426.087 ms               ┊ GC (median):    32.25%
 Time  (mean ± σ):   475.907 ms ±  86.439 ms  ┊ GC (mean ± σ):  30.83% ±  1.92%

        █                                    ▃                   
  ▇▇▁▁▁▁█▁▁▁▇▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▇ ▁
  389 ms           Histogram: frequency by time          616 ms <

 Memory estimate: 711.74 MiB, allocs estimate: 903100.

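This is virtually identical to the model that centers the X values of each category inside the model, shown below: the medians are the same, and the pre-centered model's mean is only about 10 ms higher on my machine, well within the ~85 ms standard deviation of either run.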

BenchmarkTools.Trial: 11 samples with 1 evaluation.
 Range (min … max):  403.502 ms … 687.193 ms  ┊ GC (min … max): 30.66% … 32.77%
 Time  (median):     426.057 ms               ┊ GC (median):    29.52%
 Time  (mean ± σ):   465.636 ms ±  83.637 ms  ┊ GC (mean ± σ):  29.97% ±  1.64%

  █ █▁▁     ▁  ▁        ▁ ▁                                   ▁  
  █▁███▁▁▁▁▁█▁▁█▁▁▁▁▁▁▁▁█▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  404 ms           Histogram: frequency by time          687 ms <

 Memory estimate: 732.68 MiB, allocs estimate: 913582.