Query.jl: collect's return type is not as expected

Hello, here’s something I don’t understand about Query.jl: After running a few query commands and calling collect, I was expecting the result to be of type Array{MyObs{Float64}, 1}, but got Array{MyObs,1} instead. When I checked the type of each element of the array, it’s indeed of type MyObs{Float64}. So why is collect not returning Array{MyObs{Float64}, 1}?

Reproducible example:

using LinearAlgebra, Random, StatsModels, Query, DataFrames
struct MyObs{T <: LinearAlgebra.BlasReal}
    y::Vector{T}
    X::Matrix{T} 
    Z::Matrix{T}
    xty::Vector{T} 
    zty::Vector{T}
end
function MyObs(
    y::Vector{T},
    X::Matrix{T},
    Z::Matrix{T}
    ) where T <: LinearAlgebra.BlasReal
    xty = transpose(X) * y
    zty = transpose(Z) * y
    MyObs{T}(y, X, Z, xty, zty)
end
function myobs(data_obs, feformula::FormulaTerm, reformula::FormulaTerm)
    y, X = StatsModels.modelcols(feformula, data_obs)
    Z = StatsModels.modelmatrix(reformula, data_obs)
    return MyObs(y, X, Z)
end 
Random.seed!(1)
reps = 20; N = 10; p = 5; q = 2
X = Matrix{Float64}(undef, N*reps, p)
randn!(X)
rand_intercept = zeros(N*reps)
for j in 1:N
    rand_intercept[(reps * (j-1) + 1) : reps * j] .= Random.randn(1)
end
y = X * ones(p) + rand_intercept + Random.randn(N*reps) 
id = repeat(1:N, inner = reps)
dat = hcat(rename!(DataFrame(hcat(id)), [:id]), DataFrame(hcat(y, X)))
rename!(dat, Symbol.(["id", "y", "x1", "x2", "x3", "x4", "x5"]))
function test_id(subset_id::Vector{T}, x::T, k::Int) where T
    # true when x <= subset_id[k] (subset_id must be sorted); for the ids 1:k used here this tests membership
    res = searchsortedfirst(subset_id, x) <= k
    return res
end
k=5; subset_id = [1:1:5;]
feformula   = @formula(y ~ 1 + x1 + x2 + x3 + x4 + x5)
reformula   = @formula(y ~ 1)
feformula = apply_schema(feformula, schema(feformula, dat))
reformula = apply_schema(reformula, schema(reformula, dat))

Then running

obsvec = dat |> @groupby(_.id) |> @filter(test_id(subset_id, key(_), k)) |> @map(myobs(_, feformula, reformula)) |> collect
typeof(obsvec)

We get

Array{MyObs,1}

This is an unfortunate outcome of a reliance on type inference in the EnumerableMap type. It predetermines the output eltype by calling:

T = Base._return_type(f, Tuple{TS,})

where f in this case is essentially your myobs function and Tuple{TS,} is the expected type of the previous operations in the chain (specifically, Tuple{Grouping{Int64,NamedTuple{(:id, :y, :x1, :x2, :x3, :x4, :x5),Tuple{Int64,Float64,Float64,Float64,Float64,Float64,Float64}}}}).

My guess is that the call to myobs(_, feformula, reformula) is just sufficiently complex that the compiler can’t guarantee the return type will be MyObs{Float64}, so the best it can do is MyObs. This might be affected by a number of things, like the inferability of the transpose code, StatsModels.modelcols or StatsModels.modelmatrix, FormulaTerm, or just the plain nesting complexity of everything here.
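To make the inference point concrete, here is a minimal sketch using only Base. Box, make_stable and make_unstable are hypothetical stand-ins, not anything from Query.jl or the code above, and I use the public reflection helper Base.return_types in place of the internal Base._return_type. The exact inferred type for the second function may differ between Julia versions, so treat the comments as what one would typically expect rather than guaranteed output.

```julia
# Hypothetical stand-in for MyObs: a parametric struct with one field.
struct Box{T<:Real}
    x::T
end

# Easy to infer: the argument type pins down the type parameter.
make_stable(x::Float64) = Box(x)

# Harder to infer: the value passes through an untyped container,
# so the compiler only sees `Any` going into the constructor.
make_unstable(x::Float64) = Box(Any[x][1])

inferred_stable = first(Base.return_types(make_stable, (Float64,)))
inferred_unstable = first(Base.return_types(make_unstable, (Float64,)))

println(inferred_stable)     # a concrete type is expected here
println(inferred_unstable)   # typically something wider, e.g. Box or Any

# At run time the value is still concretely typed.
println(typeof(make_unstable(1.0)))
```

Whatever inference reports is always a supertype of what the function actually returns, which is the same gap seen between Array{MyObs,1} and the MyObs{Float64} elements.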

In any case, by calling Base._return_type, EnumerableMap commits the output eltype to whatever the compiler can figure out before execution. When the result is materialized (via collect), collect asks whether Base.IteratorEltype is known (in this case, yes) and uses that eltype to allocate the output array.

If instead, EnumerableMap defined:

Base.IteratorEltype(::Type{<:EnumerableMap}) = Base.EltypeUnknown()

then a different collect algorithm is used, where the output array type is “promoted” as elements are iterated. This introduces one step of type instability (i.e. at least the initial call to iterate plus the array allocation), but can lead to a more accurate output type. I tested this locally and it indeed returns 5-element Array{MyObs{Float64},1}.
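The difference between the two collect paths can be reproduced without Query.jl. The sketch below uses two toy iterator wrappers (DeclaredIter and UndeclaredIter are hypothetical names, and Box again stands in for MyObs): one advertises an abstract eltype up front, the other reports Base.EltypeUnknown().

```julia
# Hypothetical stand-in for MyObs.
struct Box{T<:Real}
    x::T
end

# Two toy wrappers (not Query.jl types) that yield the same elements.
struct DeclaredIter          # advertises an abstract eltype up front
    data::Vector{Any}
end
struct UndeclaredIter        # declines to advertise an eltype
    data::Vector{Any}
end

const Wrapped = Union{DeclaredIter, UndeclaredIter}
Base.iterate(it::Wrapped, state = 1) =
    state > length(it.data) ? nothing : (it.data[state], state + 1)
Base.length(it::Wrapped) = length(it.data)

# Mimics an eltype fixed ahead of time from a loose inference result.
Base.eltype(::Type{DeclaredIter}) = Box
# Tells collect to discover the element type while iterating.
Base.IteratorEltype(::Type{UndeclaredIter}) = Base.EltypeUnknown()

items = Any[Box(1.0), Box(2.0), Box(3.0)]
println(eltype(collect(DeclaredIter(items))))     # Box
println(eltype(collect(UndeclaredIter(items))))   # Box{Float64}
```

The first collect trusts the declared eltype and allocates a vector of the abstract type; the second starts from the type of the first element and widens only if a later element does not fit.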

There are trade-offs between both approaches and even the Base.collect algorithm tries to use a hybrid approach between inspecting Base._return_type and just “growing” the output container. It’s actually one of the more interesting “dynamic” problems that Julia has vs. other languages, IMO, and it’s really interesting to see different approaches and the resulting side effects.

Hope that helps?


Thank you very much! This is really comprehensive.
Perhaps it’s also possible to give users control over the output type as in Base.collect?

Yeah, I didn’t think of that, but that’s definitely a work-around here, like:

julia> collect(MyObs{Float64}, obsvec)
5-element Array{MyObs{Float64},1}:
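A self-contained sketch of the same work-around, with a hypothetical Box type standing in for MyObs. The last line, identity.(...), is a separate idiom that was not mentioned above: broadcasting re-derives the element type from the values when the declared one is abstract.

```julia
# Hypothetical stand-in for MyObs.
struct Box{T<:Real}
    x::T
end

# A vector whose declared element type is the abstract `Box`,
# even though every element is a Box{Float64}.
loose = Box[Box(1.0), Box(2.0)]
println(eltype(loose))                 # Box

# collect(T, itr) builds a container with element type T,
# converting each element to T as it is stored.
tight = collect(Box{Float64}, loose)
println(eltype(tight))                 # Box{Float64}

# Alternative idiom: broadcasting identity narrows the eltype from the values.
narrowed = identity.(loose)
println(eltype(narrowed))              # Box{Float64}
```

Note that collect(T, itr) fails with a conversion error if any element cannot be converted to T, so it only helps when you already know the concrete type.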

The other idea I had was that you could write your own collect that essentially ignores eltype and only builds up the container type as it iterates over the elements. I think this is somewhat the idea in BangBang.jl (JuliaFolds/BangBang.jl, “Immutables as mutables, mutables as immutables”), but I haven’t dug into that code very deeply.
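A rough sketch of such a collect, using only Base. collect_widening is a hypothetical helper written for illustration; it is not how Query.jl or BangBang.jl implement this, and returning Any[] for empty input is just one possible choice.

```julia
# Hypothetical stand-in for MyObs.
struct Box{T<:Real}
    x::T
end

# Build the output from the run-time type of each element,
# ignoring whatever eltype the iterator declares.
function collect_widening(itr)
    next = iterate(itr)
    next === nothing && return Any[]          # empty input: no type information
    value, state = next
    out = [value]                             # Vector{typeof(value)}
    next = iterate(itr, state)
    while next !== nothing
        value, state = next
        if !(value isa eltype(out))
            # Widen the container only when an element does not fit.
            out = convert(Vector{typejoin(eltype(out), typeof(value))}, out)
        end
        push!(out, value)
        next = iterate(itr, state)
    end
    return out
end

loose = Box[Box(1.0), Box(2.0)]
println(eltype(collect_widening(loose)))      # Box{Float64}

mixed = Any[Box(1.0), Box(2)]
println(eltype(collect_widening(mixed)))      # Box
```

The cost is the same as described above for the EltypeUnknown path: the container type is not known to the compiler ahead of time, and widening copies the array each time a new element type shows up.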
