Avoid writing exponentially many methods for a function that may take vectors or matrices

Hi, in my current project I need to compute a function, say f, of N inputs, each of which may be either a vector or a matrix. In terms of interpretation: a vector is an ‘observation’, and a matrix is a set of observations, where each row is one observation. If I have a mixture of vectors and matrices, I wish to “extend” every vector to a matrix so that the sizes agree, without actually doing so, because that would cause unnecessary allocations and memory usage. In every case except the “all vector” case, the output is a matrix; in the “all vector” case, the output is a vector.

What I wish to avoid is having to write a specific method for each of the 2^N possible combinations of types. Two clear approaches for this came to my mind:

  1. Write a method for “all vector”, and then write a separate method for “the rest” which uses broadcasting. This is hard, because I would need to determine the size of the output matrix to initialise, and which values to extract from each input (entry d of a vector, column d of a matrix).
  2. Write a method for “all vector” and a method for “all matrix”. Then, whenever I need to call the function with a mix of vectors and matrices, first make each vector into a matrix using repeat_vec_fixed (defined in the MWE), after which the method for “all matrix” is used.
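For approach 1, the output size and type can be worked out with a few dispatch-based helpers. This is a minimal sketch, assuming a vector always counts as a single observation (one row) and that all matrix inputs must agree on their number of rows; the names `num_obs`, `out_rows` and `all_vectors` are hypothetical, not from any package:

```julia
# Hypothetical helper: number of observations (rows) an input contributes.
num_obs(::AbstractVector) = 1             # a vector is one observation
num_obs(a::AbstractMatrix) = size(a, 1)   # a matrix holds one observation per row

# Number of output rows: the largest row count among the inputs. Every input must
# contribute either a single observation or exactly that many rows.
function out_rows(args::AbstractVecOrMat...)
    n = maximum(num_obs, args)
    all(a -> num_obs(a) in (1, n), args) ||
        throw(DimensionMismatch("matrix inputs disagree on their number of rows"))
    return n
end

# The output is a vector only in the "all vector" case.
all_vectors(args...) = all(a -> a isa AbstractVector, args)

x, y, z = rand(5, 3), rand(3), rand(5, 2)
out_rows(x, y, z)      # 5
all_vectors(x, y, z)   # false
```

The size check is cheap and fails early, before any output is allocated.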

I was wondering if there is some elegant way of performing 1, or a way of doing 2 without actually turning each vector into a matrix first. Or perhaps there is some other method I’m not aware of?

Please see the MWE below, which illustrates exactly what I mean:

using Random
Random.seed!(42)

function f(x::AbstractVector{T}, y::AbstractVector{U}, z::AbstractVector{V}) where {T<:Real,U<:Real,V<:Real}
    vec_out = Vector{promote_type(T, U, V)}(undef, length(x))

    vec_out[1] = x[1] + y[1] + z[1]
    vec_out[2] = x[2] * z[2]
    vec_out[3] = x[3] * y[2]
    return vec_out
end

function f(x::AbstractMatrix{T}, y::AbstractMatrix{U}, z::AbstractMatrix{V}) where {T<:Real,U<:Real,V<:Real}
    mat_out = Matrix{promote_type(T, U, V)}(undef, size(x))

    @views mat_out[:, 1] .= x[:, 1] .+ y[:, 1] .+ z[:, 1]
    @views mat_out[:, 2] .= x[:, 2] .* z[:, 2]
    @views mat_out[:, 3] .= x[:, 3] .* y[:, 2]
    return mat_out
end

function repeat_vec_fixed(a::AbstractVector{T}, n::Int64) where {T}
    mat = Matrix{T}(undef, n, length(a))
    for (i, ai) in enumerate(a)
        mat[:, i] .= ai
    end
    return mat
end

function test(N)
    x_test = rand(N, 3)
    y_test = rand(N, 3)
    z_test = rand(N, 2)

    # Compute f on the test data
    display(f(x_test, y_test, z_test))
    # Compute f on the test data, iterating over rows and then stacking
    display(stack([f(x, y, z) for (x, y, z) in zip(eachrow(x_test), eachrow(y_test), eachrow(z_test))], dims = 1))
    
    z_test_2 = rand(2)
    # Compute f on the test data, but now we have one row z, which we wish to treat as a matrix:
    display(f(x_test, y_test, repeat_vec_fixed(z_test_2, N)))
end

# The structure that a more general method would have: 
function f(x, y, z)
    # First problem: the number of rows can only be taken from the matrix (non-vector) inputs, while the number of
    # columns is length(x) if x is a vector but size(x, 2) if x is a matrix; size(x) below is just a placeholder
    mat_out = Matrix{promote_type(eltype(x), eltype(y), eltype(z))}(undef, size(x))

    mat_out[:, 1] .= NaN # Second problem: what goes here? I can't write the broadcast directly, as I don't know whether x/y/z are vectors or matrices
    mat_out[:, 2] .= NaN # Same
    mat_out[:, 3] .= NaN # Same
    return mat_out
end

test(10)
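Before writing a generic method, it may help to pin down the intended semantics as executable checks rather than displayed output: the matrix method should agree with applying the vector method to each row and stacking the results, including when a vector has been materialised with repeat_vec_fixed. This sketch restates the two MWE methods so it runs on its own; `rowwise` is a hypothetical helper, and `permutedims(reduce(hcat, ...))` stands in for `stack` so it also runs on Julia versions before 1.9:

```julia
# Same two methods as in the MWE above, restated so this snippet runs on its own.
function f(x::AbstractVector{T}, y::AbstractVector{U}, z::AbstractVector{V}) where {T<:Real,U<:Real,V<:Real}
    vec_out = Vector{promote_type(T, U, V)}(undef, length(x))
    vec_out[1] = x[1] + y[1] + z[1]
    vec_out[2] = x[2] * z[2]
    vec_out[3] = x[3] * y[2]
    return vec_out
end

function f(x::AbstractMatrix{T}, y::AbstractMatrix{U}, z::AbstractMatrix{V}) where {T<:Real,U<:Real,V<:Real}
    mat_out = Matrix{promote_type(T, U, V)}(undef, size(x))
    @views mat_out[:, 1] .= x[:, 1] .+ y[:, 1] .+ z[:, 1]
    @views mat_out[:, 2] .= x[:, 2] .* z[:, 2]
    @views mat_out[:, 3] .= x[:, 3] .* y[:, 2]
    return mat_out
end

function repeat_vec_fixed(a::AbstractVector{T}, n::Int) where {T}
    mat = Matrix{T}(undef, n, length(a))
    for (i, ai) in enumerate(a)
        mat[:, i] .= ai
    end
    return mat
end

# Hypothetical reference: apply the vector method to each observation (row), then stack the results as rows.
rowwise(X, Y, Z) = permutedims(reduce(hcat, [f(x, y, z) for (x, y, z) in zip(eachrow(X), eachrow(Y), eachrow(Z))]))

N = 10
X, Y, Z = rand(N, 3), rand(N, 3), rand(N, 2)
@assert f(X, Y, Z) ≈ rowwise(X, Y, Z)        # matrix method agrees with the row-by-row vector method

z = rand(2)
Zrep = repeat_vec_fixed(z, N)                 # one observation, materialised N times
@assert f(X, Y, Zrep) ≈ rowwise(X, Y, Zrep)
@assert all(rowwise(X, Y, Zrep)[:, 2] .≈ X[:, 2] .* z[2])   # every row used the same z
```

Any generic method for mixed inputs can then be checked against `rowwise` in the same way.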

A quick fix is to reshape vectors into 1-row matrices; reshape shares the vector’s memory instead of copying it, so this should incur little or no cost:

promote_shape(v::AbstractVector) = reshape(v, 1, :)
promote_shape(a::AbstractMatrix) = a
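To see why this is cheap: reshaping an Array returns an array that shares the original data, so nothing is copied, and broadcasting then stretches the singleton first dimension to match the other inputs. A small illustration with arbitrary values:

```julia
v = [10.0, 20.0, 30.0]   # a single observation
M = reshape(v, 1, :)     # 1×3 matrix sharing v's memory; no data is copied
v[1] = 11.0
M[1, 1]                  # 11.0, because M and v share the same data

X = ones(4, 3)           # four observations
S = X .+ M               # the single row of M is broadcast against all 4 rows of X
size(S)                  # (4, 3)
S[end, :]                # [12.0, 21.0, 31.0]
```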

Is there any particular reason why this is unappealing to you?

No, I just didn’t think about reshaping… I had considered writing two get_elem_index methods as below, but that did not seem very appealing because they would pop up everywhere in my code. Your approach suits my purposes: I can write a general f(x, y, z) that calls f(promote_shape(x), promote_shape(y), promote_shape(z)), and because the “all vector” and “all matrix” methods are more specific, dispatch will call them directly and skip promote_shape in those cases. Thanks!

get_elem_index(v::AbstractVector, d::Int64) = v[d]
get_elem_index(a::AbstractMatrix, d::Int64) = view(a, :, d)
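For comparison, approach 1 can be written directly with these two get_elem_index methods: a vector contributes a scalar, which broadcasting repeats down the output column, while a matrix contributes a column view. This is a minimal sketch; `num_obs` and `f_mixed` are hypothetical names, and the output is fixed at three columns because that is what f computes in the MWE:

```julia
get_elem_index(v::AbstractVector, d::Int) = v[d]            # a vector contributes a scalar
get_elem_index(a::AbstractMatrix, d::Int) = view(a, :, d)   # a matrix contributes column d

# Hypothetical helper: a vector is one observation, a matrix has one per row.
num_obs(::AbstractVector) = 1
num_obs(a::AbstractMatrix) = size(a, 1)

function f_mixed(x, y, z)
    T = promote_type(eltype(x), eltype(y), eltype(z))
    n = max(num_obs(x), num_obs(y), num_obs(z))
    mat_out = Matrix{T}(undef, n, 3)   # f computes three outputs per observation
    # Scalars (from vector inputs) are repeated down each column by broadcasting.
    mat_out[:, 1] .= get_elem_index(x, 1) .+ get_elem_index(y, 1) .+ get_elem_index(z, 1)
    mat_out[:, 2] .= get_elem_index(x, 2) .* get_elem_index(z, 2)
    mat_out[:, 3] .= get_elem_index(x, 3) .* get_elem_index(y, 2)
    return mat_out
end

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
Y = [1.0 1.0 1.0; 2.0 2.0 2.0]
f_mixed(X, Y, [10.0, 20.0])   # [12.0 40.0 3.0; 16.0 100.0 12.0]
```

This keeps a single method, but every expression has to be routed through get_elem_index, which is the clutter mentioned above.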

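BTW, there is already a promote_shape function in Base, which I was not aware of! It checks two array shapes for compatibility (allowing trailing singleton dimensions) and returns whichever shape has more dimensions, so it does something a bit similar, but not quite; consider using a different name.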
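One caveat when combining the reshape idea with the MWE: the all-matrix method allocates its output with size(x), so if x is the reshaped vector, the output has a single row and the broadcast assignment throws a DimensionMismatch. Below is a minimal sketch of the general f(x, y, z) wrapper, assuming the row count is instead taken from all three inputs; `as_matrix` is a hypothetical name used in place of promote_shape to avoid the clash with Base:

```julia
# `as_matrix` is a hypothetical name, chosen to avoid clashing with Base.promote_shape.
as_matrix(v::AbstractVector) = reshape(v, 1, :)   # one observation -> 1-row matrix (no copy)
as_matrix(a::AbstractMatrix) = a

function f(x::AbstractVector{T}, y::AbstractVector{U}, z::AbstractVector{V}) where {T<:Real,U<:Real,V<:Real}
    vec_out = Vector{promote_type(T, U, V)}(undef, length(x))
    vec_out[1] = x[1] + y[1] + z[1]
    vec_out[2] = x[2] * z[2]
    vec_out[3] = x[3] * y[2]
    return vec_out
end

function f(x::AbstractMatrix{T}, y::AbstractMatrix{U}, z::AbstractMatrix{V}) where {T<:Real,U<:Real,V<:Real}
    # Assumption: the row count comes from all inputs, since a reshaped vector has only one row.
    n = max(size(x, 1), size(y, 1), size(z, 1))
    mat_out = Matrix{promote_type(T, U, V)}(undef, n, size(x, 2))
    # 1-row inputs are broadcast down the column; n-row inputs are used as-is.
    @views mat_out[:, 1] .= x[:, 1] .+ y[:, 1] .+ z[:, 1]
    @views mat_out[:, 2] .= x[:, 2] .* z[:, 2]
    @views mat_out[:, 3] .= x[:, 3] .* y[:, 2]
    return mat_out
end

# Mixed inputs: reshape any vectors, then hand over to the all-matrix method.
f(x::AbstractVecOrMat{<:Real}, y::AbstractVecOrMat{<:Real}, z::AbstractVecOrMat{<:Real}) =
    f(as_matrix(x), as_matrix(y), as_matrix(z))

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
Y = [1.0 1.0 1.0; 2.0 2.0 2.0]
f(X, Y, [10.0, 20.0])                        # [12.0 40.0 3.0; 16.0 100.0 12.0]
f([1.0, 2.0, 3.0], Y, [10.0 20.0; 0.0 1.0])  # [12.0 40.0 3.0; 3.0 2.0 6.0]
```

Because the all-vector and all-matrix methods are more specific than the `AbstractVecOrMat{<:Real}` fallback, dispatch selects them directly, and `as_matrix` only runs for mixed inputs. Restricting the fallback to `Real` element types means unsupported inputs produce a MethodError instead of the fallback calling itself forever.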
