Lazy vcat of matrices just before multiplication

Dear All,

I am implementing a neural network for a hyper-multigraph (this is not a joke; the graph has multiple types of edges with arities higher than 3). It is a "fairly" straightforward extension of the Mill.jl framework. The graph I process has about 10^5 vertices, and because I have multiple types of edges, I concatenate the outputs of the message pass over each edge type and feed the result to a dense layer. Looking at the profiler, I noticed that the vertical concatenation of matrices takes a noticeable amount of time, so I thought I would do it lazily.

So to be concrete, I want to compute something like

x = randn(16,105625)
y = randn(16,105625)
z = randn(16,105625)
w = randn(16,3*16)

w * vcat(x,y,z)

without explicitly constructing the output of vcat.

I have found LazyArrays, which does have this functionality, but it propagates the lazy computation further and does not lead to a speedup when I use it as Array(w * LazyArrays.Vcat(x,y,z)) (concrete measurements are below).

Finally, I have written something myself:

struct LazyVCat{N,T<:AbstractMatrix}
	xs::NTuple{N,T}
	function LazyVCat(xs::NTuple{N,T}) where {N,T<:AbstractMatrix}
		n = size(first(xs), 2)
		all(n == size(x, 2) for x in xs) || error("all matrices should have the same number of columns")
		new{N,T}(xs)
	end
end

import Base.*
using LinearAlgebra
function *(A::Matrix, B::LazyVCat)
	T = eltype(A)
	o = similar(A, size(A, 1), size(first(B.xs), 2))
	fill!(o, zero(T))  # mul! below accumulates into o, so it must start zeroed
	offset = 0
	for x in B.xs
		# o += A[:, block] * x for each row block of the virtual vcat
		LinearAlgebra.mul!(o, view(A, :, offset+1:offset+size(x, 1)), x, one(T), one(T))
		offset += size(x, 1)
	end
	o
end

which gives me the speedup I was looking for.
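The L in the third benchmark below is just this wrapper applied to the three blocks; with the x, y, z and w from above it would be constructed roughly as:

L = LazyVCat((x, y, z))   # wraps the blocks, no copy is made
w * L                     # same result as w * vcat(x, y, z)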

The output of BenchmarkTools is below:

using LinearAlgebra, BenchmarkTools, LazyArrays
julia> @benchmark Array(w * LazyArrays.Vcat(x,y,z))
BenchmarkTools.Trial: 195 samples with 1 evaluation.
 Range (min … max):  18.112 ms … 109.953 ms  ┊ GC (min … max):  0.00% … 82.59%
 Time  (median):     25.102 ms               ┊ GC (median):    18.01%
 Time  (mean ± σ):   25.652 ms ±   8.802 ms  ┊ GC (mean ± σ):  21.05% ± 12.97%

                █ ▄        ▄  ▁
  ▇█▅▅▅▆▃▇▄▆▄▁▄▇█▆█▃▆▅▆██▇▇█▆▆█▆█▄▃▄▄▄▆▆▄▄▅▃▃▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▃ ▃
  18.1 ms         Histogram: frequency by time         37.7 ms <

 Memory estimate: 51.58 MiB, allocs estimate: 161.

julia> @benchmark w * vcat(x,y,z)
BenchmarkTools.Trial: 198 samples with 1 evaluation.
 Range (min … max):  17.451 ms … 131.668 ms  ┊ GC (min … max):  0.00% … 83.29%
 Time  (median):     23.815 ms               ┊ GC (median):    20.01%
 Time  (mean ± σ):   25.258 ms ±  11.110 ms  ┊ GC (mean ± σ):  22.70% ± 15.27%

  ▄▇▂     █▂    ▆▃▂
  ███▄▄▅▄▆███▅▅▅███▅▅▅▃▃▃▁▁▁▁▁▁▁▃▃▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▃
  17.5 ms         Histogram: frequency by time           57 ms <

 Memory estimate: 51.57 MiB, allocs estimate: 4.

julia> @benchmark w * L
BenchmarkTools.Trial: 495 samples with 1 evaluation.
 Range (min … max):   8.240 ms … 20.951 ms  ┊ GC (min … max):  0.00% … 38.86%
 Time  (median):      8.735 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   10.085 ms ±  2.322 ms  ┊ GC (mean ± σ):  11.81% ± 15.18%

   ▆█▇▄▁ ▁                            ▃▄▃▂▁
  █████████▇█▇▄▆▅▄▄▁▁▄▄▁▅▄▁▁▁▁▄▁▁▅▁▁▅▅█████▅▇▄▅▅▄▄▁▄▁▁▁▁▄▄▁▄▄ ▇
  8.24 ms      Histogram: log(frequency) by time      16.2 ms <

 Memory estimate: 12.89 MiB, allocs estimate: 23.

Now of course I could finish the implementation by writing an rrule for ChainRules and doing all the work around it. But I do not want to repeat someone else's work, so I want to ask if someone knows about a package which has implemented this, is fast, and works with AD.
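For concreteness, a rough, untested sketch of what I imagine such an rrule would look like (the Tangent for the struct in particular may need adjusting to whatever AD sits on top, e.g. Zygote):

using ChainRulesCore, LinearAlgebra

function ChainRulesCore.rrule(::typeof(*), A::Matrix, B::LazyVCat)
	y = A * B
	function lazyvcat_mul_pullback(ȳ)
		Ȳ = ChainRulesCore.unthunk(ȳ)
		Ā = similar(A)
		x̄s = map(similar, B.xs)
		offset = 0
		for (x, x̄) in zip(B.xs, x̄s)
			cols = offset+1:offset+size(x, 1)
			LinearAlgebra.mul!(view(Ā, :, cols), Ȳ, x')   # block of the gradient w.r.t. A
			LinearAlgebra.mul!(x̄, view(A, :, cols)', Ȳ)   # gradient w.r.t. this block of B
			offset += size(x, 1)
		end
		return NoTangent(), Ā, Tangent{typeof(B)}(xs = x̄s)
	end
	return y, lazyvcat_mul_pullback
end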

Thanks in advance for any answers.


You might be looking for:

Perhaps mcabbott will confirm, but it does mention ChainRules compatibility.

This LazyVCat looks sensible to me, modulo typos; I don't know of a packaged version. Although if you can supply a pre-allocated space for vcat(x,y,z), one call to * may be quicker? (A sketch of that is after the timings below.)

julia> @btime $w * vcat($x,$y,$z);
  min 14.890 ms, mean 16.542 ms (6 allocations, 51.57 MiB)

julia> @btime $w * $L;
  min 4.006 ms, mean 4.769 ms (21 allocations, 12.89 MiB)

julia> @btime $w * $(vcat(x,y,z));
  min 1.509 ms, mean 2.917 ms (3 allocations, 12.89 MiB)
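
Re-using a pre-allocated buffer could look roughly like this (buffer is just an illustrative name; the row offsets assume the 16-row blocks from the example above):

buffer = similar(x, size(x, 1) + size(y, 1) + size(z, 1), size(x, 2))
copyto!(view(buffer, 1:16, :), x)    # write each 16-row block in place
copyto!(view(buffer, 17:32, :), y)
copyto!(view(buffer, 33:48, :), z)
w * buffer                            # one ordinary * on the materialised matrix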

Thanks a lot for the answer. I cannot preallocate, as the number of vertices differs between states.