Efficiently remove diagonal from Array

I can confirm that Yakir's suggestions hold up in practice.

A quick benchmark on a 1000×1000 `Float64` matrix (`@btime` output; full code below):

# Original function
57.002 ms (19002 allocations: 75.87 MiB)
# nilshg function
  14.625 ms (2 allocations: 7.62 MiB)
# Yakir 1
  14.943 ms (2 allocations: 7.62 MiB)
# Yakir 2
  7.595 ms (8 allocations: 7.64 MiB)
# Yakir 3
  20.609 ms (13 allocations: 23.82 MiB)

Note that all three of Yakir's functions return an `Adjoint` (a lazy transposed wrapper around a matrix) rather than a plain `Matrix`. In all likelihood this is fine, but depending on your use case you might have to `collect` it into a regular array, which — at least for Yakir 2 — can wipe out the gains:

# Yakir 2 plus collect
16.184 ms (10 allocations: 15.26 MiB)

The full benchmark code is below, so people can check whether I've messed up the benchmarking somehow:


"""
    remove_diagonal(x)

Return an `n × (n-1)` `Matrix{Float64}` whose row `i` holds row `i` of `x` with the
diagonal entry `x[i, i]` removed, where `n = size(x, 1)`.

This is the baseline implementation: for each row it builds the list of kept column
indices with `setdiff` and copies that row into the result. Only `size(x, 1)` is
consulted, so `x` is expected to be square.
"""
function remove_diagonal(x)
    n = size(x, 1)
    out = Array{Float64}(undef, n, n - 1)
    for row in 1:n
        keep = setdiff(1:n, row)   # every column index except the diagonal one
        out[row, :] .= x[row, keep]
    end
    return out
end


"""
    remove_diagonal2(x)

Return an `n × (n-1)` `Matrix{Float64}` whose row `i` holds row `i` of `x` with the
diagonal entry `x[i, i]` removed, where `n = size(x, 1)` (`x` is expected to be square).

The result is filled one column at a time with the row index in the inner loop, so both
`x` and the output are read and written in Julia's column-major storage order.
"""
function remove_diagonal2(x)
    n = size(x, 1)
    mat = Array{Float64}(undef, n, n - 1)
    for j in axes(mat, 2)
        # For rows i ≤ j the diagonal x[i, i] lies at or left of column j, so output
        # column j is source column j + 1.
        for i in 1:j
            mat[i, j] = x[i, j + 1]
        end
        # For rows i > j the diagonal lies to the right, so the column index is unchanged.
        for i in (j + 1):n
            mat[i, j] = x[i, j]
        end
    end
    return mat
end


"""
    remove_diagonal3(x)

Return an `n × (n-1)` matrix whose row `i` holds row `i` of `x` with the diagonal entry
`x[i, i]` removed (`x` is expected to be square). The element type of `x` is kept.

Each row is copied and the diagonal element deleted; the rows are then stacked as columns
and the result is returned as a lazy `Adjoint` of that `(n-1) × n` matrix.
"""
function remove_diagonal3(x)
    rows = Vector{Vector{eltype(x)}}(undef, size(x, 1))
    for (i, r) in enumerate(eachrow(x))
        v = copy(r)        # eachrow yields views of x; copy to get a resizable Vector
        deleteat!(v, i)    # drop the diagonal entry x[i, i]
        rows[i] = v
    end
    # reduce(hcat, ...) has a specialized method and avoids splatting n arguments.
    return reduce(hcat, rows)'
end

"""
    remove_diagonal4(x)

Return an `n × (n-1)` matrix whose row `i` holds row `i` of `x` with the diagonal entry
`x[i, i]` removed (`x` is expected to be square).

Works on a flat row-major copy of `x`: the diagonal positions are deleted in one
`deleteat!` call, and the remainder is reshaped and returned as a lazy `Adjoint`.
"""
function remove_diagonal4(x)
    n = size(x, 1)
    lin = LinearIndices(x)
    # Linear positions of x[1, 1], x[2, 2], ...; for a square matrix these are the same
    # positions in x and in its transpose.
    diagpos = map(i -> lin[i, i], 1:n)
    flat = collect(vec(x'))   # row 1 of x, then row 2, ...
    deleteat!(flat, diagpos)
    return reshape(flat, n - 1, n)'
end

# Materialize the lazy `Adjoint` returned by `remove_diagonal4` into a plain `Matrix`.
rd4_c(x) = Matrix(remove_diagonal4(x))

"""
    remove_diagonal5(x)

Return an `n × (n-1)` matrix whose row `i` holds row `i` of `x` with the diagonal entry
`x[i, i]` removed (`x` is expected to be square).

Gathers the off-diagonal entries through a vector of Cartesian indices, then reshapes
them and returns the result as a lazy `Adjoint`.
"""
function remove_diagonal5(x)
    n = size(x, 1)
    # Every off-diagonal Cartesian index of x, in column-major order.
    offdiag = [idx for idx in CartesianIndices(x) if idx[1] != idx[2]]
    # Indexing x' at (r, c) reads x[c, r], so this collects x row by row.
    picked = x'[offdiag]
    return reshape(picked, n - 1, n)'
end

using BenchmarkTools
xl = rand(1000, 1000)

# (label, function) pairs, in the order the results are reported.
variants = (
    ("Original function", remove_diagonal),
    ("nilshg function", remove_diagonal2),
    ("Yakir 1", remove_diagonal3),
    ("Yakir 2", remove_diagonal4),
    ("Yakir 3", remove_diagonal5),
    ("Yakir 2 plus collect", rd4_c),
)

# Timings are only comparable if every variant computes the same thing, so check each
# result against the original implementation before benchmarking.
reference = remove_diagonal(xl)
for (label, f) in variants
    f(xl) == reference || error("$label does not match the original function")
end

for (label, f) in variants
    println(label)
    # Interpolate both the function and the input so neither is a non-const global
    # inside the timed expression.
    @btime $f($xl)
end
5 Likes