Iteration protocol and multidimensional arrays

#1

Hello,

I really like the new iteration protocol, and I’m happy to see it spreading across packages. However, I often run into the problem that I work with multidimensional arrays, which are frequently not supported directly. In the MWE below, I use gauss_seidel! from IterativeSolvers.jl. The method accepts 1-d and 2-d arrays, but mine are 2-d and 3-d, so I need to copy and slice the arrays inside a for loop.

Are there more idiomatic ways of doing the same?

using IterativeSolvers

n = 10
m = 3
A = rand(n,n,m)
b = rand(n,m)
x1 = zeros(n)

# 2d
@time gauss_seidel!(x1, A[:,:,1], b[:,1]; maxiter=10)

# 3d
x2 = zeros(n,m)
xt = zeros(n)
@time for i = 1:m
    xt .= x2[:,i]
    gauss_seidel!(xt, A[:,:,i], b[:,i]; maxiter=10)
    x2[:,i] .= xt
end

x1 == x2[:,1]

Thanks!

EDIT:
Benchmarking gives 3 and 27 allocations, respectively.


#2

You may be able to just write foreach(gauss_seidel!, eachcol(x2), eachslice(A, dims=3), eachcol(b)) here, or pass (x...) -> gauss_seidel!(x...; maxiter=10) as the function to insert the keyword.
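To make the foreach pattern concrete without depending on IterativeSolvers.jl, here is a rough sketch using only Base Julia. The function step! and its scale keyword are hypothetical stand-ins that only mimic the call shape of an in-place solver (x, A, b; keyword); they are not the Gauss-Seidel algorithm.

```julia
# Hypothetical stand-in for an in-place solver such as gauss_seidel!;
# it only mimics the call shape (x, A, b; keyword), not the algorithm.
function step!(x, A, b; scale = 1.0)
    x .= scale .* (A * b)
    return x
end

n, m = 4, 3
A = rand(n, n, m)
b = rand(n, m)
x = zeros(n, m)

# foreach walks the three iterators in lockstep; each element is a view,
# so step! writes directly into the columns of x.
foreach((args...) -> step!(args...; scale = 2.0),
        eachcol(x), eachslice(A, dims = 3), eachcol(b))

# Reference result computed with explicit slicing (which makes copies).
expected = hcat([2.0 .* (A[:, :, i] * b[:, i]) for i in 1:m]...)
```

After this runs, x should agree with expected up to floating-point rounding (x ≈ expected), which suggests the anonymous function forwards the keyword and the column views are mutated in place.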

eachslice and friends make views, whereas b[:,i] makes a copy. Your “3d” loop could perhaps just write into view(x2,:,i) instead of messing with xt; it may also be faster to turn all of the slices into views (by putting @views in front).
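A small Base-only sketch of the copy versus view distinction, in case it helps. The names M, c and v are illustrative and do not come from the code above.

```julia
# Base Julia only; M, c and v are illustrative names, not from the thread.
M = zeros(3, 2)

c = M[:, 1]                  # slicing with [:, 1] allocates a copy
c .= 1.0                     # changes the copy only
untouched = all(iszero, M)   # M itself has not changed

v = view(M, :, 2)            # a view shares memory with M
v .= 2.0                     # writes through to the second column of M

# @views makes every slice in the expression a view, so the loop body
# updates M in place without copying its columns.
for i in 1:2
    @views M[:, i] .+= 1.0
end
```

This is why an in-place function handed b[:,i] works on a temporary array, while the same function handed view(b,:,i) or a slice under @views works on the original storage.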


#3

Thanks for the tips! Benchmarking shows the best result for the 3d case with a simple @views loop. So unless I’ve made a mistake in benchmarking, that’s a fairly simple implementation…

using IterativeSolvers
using BenchmarkTools

function bar(x,A,b)
    gauss_seidel!(x, A, b)
end

function foo(x,A,b)
    foreach((x...) -> gauss_seidel!(x...; maxiter=10), eachcol(x), eachslice(A,dims=3),eachcol(b))
end

function baz(x,A,b)
    n,m = size(x)
    for i = 1:m
        @views gauss_seidel!(x[:,i], A[:,:,i], b[:,i]; maxiter=10)
    end
end

n = 10
m = 3
A = rand(n,n,m)
b = rand(n,m)

x1 = zeros(n)
x2 = zeros(n,m)
x3 = zeros(n,m)

bar(x1,A[:,:,1],b[:,1])
foo(x2,A,b)
baz(x3,A,b)

x1 == x2[:,1] == x3[:,1]
x2 == x3

test1() = @btime bar($(zeros(n)),$(A[:,:,1]),$(b[:,1]));
test2() = @btime foo($(zeros(n,m)),$A,$b);
test3() = @btime baz($(zeros(n,m)),$A,$b);

test1()
test2()
test3()

Output:

1.298 μs (1 allocation: 48 bytes) #bar
7.976 μs (36 allocations: 1.28 KiB) #foo
4.750 μs (12 allocations: 624 bytes) #baz