Make `cat` (on a fixed dimension) type-stable

I posted on Slack but wasn’t able to get a resolution (my fault, since I posted at 4 AM EST and Slack is so fast-moving).

Anyhow, is there a way to make `cat` type-stable? I see the following:

```julia
A = Array{Float64, 3}(reshape(1:2*2*3, (2,2,3)));

@code_warntype cat(zeros(2,2), A; dims=3)
```

which gives me

```
Body::Any
 1 ── %1  = (Base.getfield)(#temp#, :dims)::Int64                                                                        │╻       getindex
 │    %2  = (Base.slt_int)(0, 1)::Bool                                                                                   ││╻╷╷╷    iterate
 └───       goto #3 if not %2                                                                                            │││┃│      iterate
 2 ──       goto #4                                                                                                      ││││┃       iterate
 3 ──       invoke Base.getindex(()::Tuple{}, 1::Int64)                                                                  │││││
 └───       $(Expr(:unreachable))                                                                                        │││││
 4 ┄─       goto #5                                                                                                      ││││
 5 ──       goto #6                                                                                                      ││╻       iterate
 6 ──       goto #7                                                                                                      ││
 7 ──       nothing                                                                                                      │
 │    %11 = (getfield)(A, 1)::Array{Float64,2}                                                                           │
 │    %12 = (getfield)(A, 2)::Array{Float64,3}                                                                           │
 │          invoke Core.kwfunc(Base.cat_t::Any)                                                                          ││╻       _cat
 │    %14 = (Base.slt_int)(0, 1)::Bool                                                                                   │││╻╷╷╷╷   #cat_t
 └───       goto #9 if not %14                                                                                           ││││┃│││    isempty
 8 ──       goto #10                                                                                                     │││││┃││     iterate
 9 ──       invoke Base.getindex(()::Tuple{}, 1::Int64)                                                                  ││││││┃│      iterate
 └───       $(Expr(:unreachable))                                                                                        │││││││┃       iterate
 10 ┄       goto #11                                                                                                     │││││││
 11 ─       goto #12                                                                                                     │││││╻       iterate
 12 ─       goto #13                                                                                                     │││││
 13 ─ %22 = invoke Base._cat_t(%1::Int64, Float64::Type, %11::Array{Float64,2}, %12::Vararg{Any,N} where N)::Any         ││││╻       #cat_t#99
 └───       goto #14                                                                                                     ││││
 14 ─       goto #15                                                                                                     │││
 15 ─       goto #16
```

Curiously, even though `vcat(x...)` is semantically the same as `cat(x...; dims=1)`, only `vcat` is type-stable; the `cat` call still infers `Body::Any`:

```
julia> @code_warntype cat(zeros(2,2,3), A; dims=1)
Body::Any
 1 ── %1  = (Base.getfield)(#temp#, :dims)::Int64                                                                        │╻       getindex
 │    %2  = (Base.slt_int)(0, 1)::Bool                                                                                   ││╻╷╷╷    iterate
 └───       goto #3 if not %2                                                                                            │││┃│      iterate
 2 ──       goto #4                                                                                                      ││││┃       iterate
 3 ──       invoke Base.getindex(()::Tuple{}, 1::Int64)                                                                  │││││
 └───       $(Expr(:unreachable))                                                                                        │││││
 4 ┄─       goto #5                                                                                                      ││││
 5 ──       goto #6                                                                                                      ││╻       iterate
 6 ──       goto #7                                                                                                      ││
 7 ──       nothing                                                                                                      │
 │    %11 = (getfield)(A, 1)::Array{Float64,3}                                                                           │
 │    %12 = (getfield)(A, 2)::Array{Float64,3}                                                                           │
 │          invoke Core.kwfunc(Base.cat_t::Any)                                                                          ││╻       _cat
 │    %14 = (Base.slt_int)(0, 1)::Bool                                                                                   │││╻╷╷╷╷   #cat_t
 └───       goto #9 if not %14                                                                                           ││││┃│││    isempty
 8 ──       goto #10                                                                                                     │││││┃││     iterate
 9 ──       invoke Base.getindex(()::Tuple{}, 1::Int64)                                                                  ││││││┃│      iterate
 └───       $(Expr(:unreachable))                                                                                        │││││││┃       iterate
 10 ┄       goto #11                                                                                                     │││││││
 11 ─       goto #12                                                                                                     │││││╻       iterate
 12 ─       goto #13                                                                                                     │││││
 13 ─ %22 = invoke Base._cat_t(%1::Int64, Float64::Type, %11::Array{Float64,3}, %12::Vararg{Array{Float64,3},N} where N)::Any╻       #cat_t#99
 └───       goto #14                                                                                                     ││││
 14 ─       goto #15                                                                                                     │││
 15 ─       goto #16                                                                                                     ││
 16 ─       return %22                                                                                                   │
```

versus

```
julia> @code_warntype vcat(zeros(2,2,3), A)
Body::Array{Float64,3}
1486 1 ─       invoke Core.kwfunc(Base.cat::Any)                                                                            │
     │   %2  = (Base.slt_int)(0, 1)::Bool                                                                                   │╻╷╷╷╷  #cat
     └──       goto #3 if not %2                                                                                            ││┃│││   isempty
     2 ─       goto #4                                                                                                      │││┃││    iterate
     3 ─       invoke Base.getindex(()::Tuple{}, 1::Int64)                                                                  ││││┃│     iterate
     └──       $(Expr(:unreachable))                                                                                        │││││┃      iterate
     4 ┄       goto #5                                                                                                      │││││
     5 ─       goto #6                                                                                                      │││╻      iterate
     6 ─       goto #7                                                                                                      │││
     7 ─ %10 = %new(NamedTuple{(:dims,),Tuple{Val{1}}}, $(QuoteNode(Val{1}())))::NamedTuple{(:dims,),Tuple{Val{1}}}         │││╻╷╷╷   _cat
     │   %11 = invoke Core.kwfunc(Base.cat_t::Any)::getfield(Base, Symbol("#kw##cat_t"))                                    ││││
     │   %12 = Base.cat_t::typeof(Base.cat_t)                                                                               ││││
     │   %13 = invoke %11(%10::NamedTuple{(:dims,),Tuple{Val{1}}}, %12::typeof(Base.cat_t), Float64::Type{Float64}, _2::Array{Float64,3}, _3::Vararg{Array{Float64,3},N} where N)::Array{Float64,3}
     └──       goto #8                                                                                                      ││
     8 ─       return %13                                                                                                   │
```

What is going on here?

The `1` in `dims=1` is a runtime value, and (modulo constant propagation) type inference works on types, not values. Since the type `cat` returns depends on the runtime value of `dims`, inference can't assign a concrete return type.
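To make the point concrete, here is a minimal sketch (the helper name `catdim` is hypothetical, not from the thread) that concatenates the same two matrices along different runtime values of `dims`. The element type never changes, but the number of dimensions of the result, and therefore its concrete type, does:

```julia
# Hypothetical helper: `d` is an ordinary runtime Int, not a type-level constant.
catdim(a, b, d::Int) = cat(a, b; dims=d)

a = zeros(2, 2)

r1 = catdim(a, a, 1)   # 4×2, a 2-dimensional Array{Float64,2}
r2 = catdim(a, a, 2)   # 2×4, a 2-dimensional Array{Float64,2}
r3 = catdim(a, a, 3)   # 2×2×2, a 3-dimensional Array{Float64,3}

for r in (r1, r2, r3)
    println(typeof(r), " ", size(r))
end
```

All three calls go through the same method instance (the argument types are identical), so the single return type inference computes for `catdim` has to cover both `Array{Float64,2}` and `Array{Float64,3}`; it cannot be concrete, which is what the `Body::Any` in the 0.7 output above reflects.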


Slightly more info: `@code_warntype` assumes that its arguments are not constant, but often you’ll write the `dims` keyword as a constant. You can see this behavior by introspecting a simple wrapper function where you’ve written it as the literal `3`:

```julia
julia> cat3(args...) = cat(args...; dims=3)
cat3 (generic function with 1 method)

julia> @code_warntype cat3(zeros(2,2), A)
Body::Array{Float64,3}
…
```
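A small sketch contrasting the wrapper above with a variant that takes `dims` at run time. `cat3` is the wrapper from the reply; `catn` is a hypothetical name used only for comparison. Both produce the same array; they differ only in what inference can prove. The check uses `Base.return_types`, a reflection utility whose exact answer can vary across Julia versions, so its result is printed rather than asserted:

```julia
cat3(args...) = cat(args...; dims=3)        # dims is a literal constant in the body
catn(d, args...) = cat(args...; dims=d)     # hypothetical: dims only known at run time

A = Array{Float64,3}(reshape(1:2*2*3, (2, 2, 3)))

B = cat3(zeros(2, 2), A)       # the 2×2 matrix is treated as a 2×2×1 slice
C = catn(3, zeros(2, 2), A)    # same values, but `d` is just an Int to inference

# With the literal `3`, constant propagation can see the value (the reply above
# reports Array{Float64,3} on 0.7); with a runtime `d::Int` it cannot.
println(Base.return_types(cat3, (Matrix{Float64}, Array{Float64,3})))
println(Base.return_types(catn, (Int, Matrix{Float64}, Array{Float64,3})))
```

This is also why `@code_warntype cat(zeros(2,2), A; dims=3)` at the REPL looks unstable: the macro analyses `cat` for an arbitrary `Int` keyword, just like `catn`.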

Thanks for the help! The answers make a lot of sense. @mbauman, I noticed the type instability in a function where I used `dims=3` with `cat`. How should I understand this, given your example?

Specifically, my code is:

```julia
function addones(x :: Array{T, 3} where T <: Real)
    shape = (size(x, 1), size(x, 2))
    return cat(ones(shape), x; dims=3)
end
```

This function doesn’t seem to be type-stable for me (on 0.7):

```julia
julia> @code_warntype cat3(zeros(2,2), A)
Body::Any
```

What is stable, however, is using `dims=Val(3)`:

```julia
julia> @code_warntype cat(zeros(2,2), A; dims=Val(3))
Body::Array{Float64,3}
```
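Applying the `dims=Val(3)` suggestion to the `addones` function from earlier in the thread might look like the sketch below. The name `addones_val` is hypothetical, and the sketch also switches to `ones(T, shape)` so the element type follows `x` (the original `ones(shape)` always yields `Float64`, which promotes integer inputs); drop that change if the promotion is intended.

```julia
# Hypothetical variant of `addones`: the concatenation dimension is passed
# as `Val(3)`, so it is part of the argument's type rather than a runtime Int.
function addones_val(x::Array{T,3}) where {T<:Real}
    shape = (size(x, 1), size(x, 2))
    # `ones(T, shape)` keeps the element type of `x`.
    return cat(ones(T, shape), x; dims=Val(3))
end

A = Array{Float64,3}(reshape(1:2*2*3, (2, 2, 3)))
B = addones_val(A)            # 2×2×4: a slice of ones followed by the slices of A
```

To inspect it outside the REPL, load `InteractiveUtils` first (`using InteractiveUtils; @code_warntype addones_val(A)`); the exact inferred body may differ between Julia versions.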