Dear All,
I am implementing a neural network for a hyper-multigraph (this is not a joke: the graph has multiple types of edges with arities higher than 3). This is a “fairly” straightforward extension of the Mill.jl framework. The graph I process has about 10^5 vertices, and because I have multiple types of edges, I concatenate the outputs of message passing over each edge type and then feed the result to a dense layer. Looking at the profiler, I noticed that the vertical concatenation of matrices takes a noticeable amount of time, so I thought I would do it lazily.
So to be concrete, I want to compute something like
```julia
x = randn(16,105625)
y = randn(16,105625)
z = randn(16,105625)
w = randn(16,3*16)
w * vcat(x,y,z)
```
without explicitly constructing the output of `vcat`.
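The reason a lazy `vcat` can help is the block identity `w * vcat(x, y, z) == w[:, 1:16] * x + w[:, 17:32] * y + w[:, 33:48] * z`, so the product can be accumulated one block at a time. A minimal sketch of that idea with plain column views and five-argument `mul!`, assuming all blocks share the same number of columns (the helper name `blockmul` is hypothetical and the sizes are reduced for illustration):

```julia
using LinearAlgebra

# Hypothetical helper: computes w * vcat(xs...) without materializing vcat(xs...).
function blockmul(w::AbstractMatrix, xs::AbstractMatrix...)
    ncols = size(first(xs), 2)
    all(size(x, 2) == ncols for x in xs) ||
        throw(DimensionMismatch("all blocks must have the same number of columns"))
    sum(x -> size(x, 1), xs) == size(w, 2) ||
        throw(DimensionMismatch("size(w, 2) must equal the total number of block rows"))
    T = promote_type(eltype(w), map(eltype, xs)...)
    out = zeros(T, size(w, 1), ncols)  # start from zero: each mul! below accumulates
    offset = 0
    for x in xs
        rows = offset+1:offset+size(x, 1)
        # out = view(w, :, rows) * x + out, without copying the column slice of w
        mul!(out, view(w, :, rows), x, true, true)
        offset += size(x, 1)
    end
    return out
end

x, y, z = randn(16, 1000), randn(16, 1000), randn(16, 1000)
w = randn(16, 3 * 16)
blockmul(w, x, y, z) ≈ w * vcat(x, y, z)  # true
```

The only extra memory is the output itself; the price is one BLAS call per block instead of a single large one.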
I have found LazyArrays, which does have this functionality, but it propagates the lazy computation further and does not lead to a speedup when I use it as `Array(w * LazyArrays.Vcat(x,y,z))` (concrete measurements are below).
Finally, I have written something myself:
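```julia
using LinearAlgebra
import Base: *

# Lazy vertical concatenation of matrices that all have the same number of columns.
struct LazyVCat{N,T<:AbstractMatrix}
    xs::NTuple{N,T}
    function LazyVCat(xs::NTuple{N,T}) where {N,T<:AbstractMatrix}
        n = size(first(xs), 2)
        all(n == size(x, 2) for x in xs) || error("all matrices should have the same number of columns")
        new{N,T}(xs)
    end
end

# A * vcat(xs...) computed as the sum of A[:, block_i] * xs[i], without materializing the vcat.
function *(A::Matrix, B::LazyVCat)
    size(A, 2) == sum(x -> size(x, 1), B.xs) ||
        throw(DimensionMismatch("size(A, 2) must equal the total number of rows of the blocks"))
    T = promote_type(eltype(A), eltype(first(B.xs)))
    # zero-initialized, because every mul! below accumulates into `o` (β = 1)
    o = zeros(T, size(A, 1), size(first(B.xs), 2))
    offset = 0
    for x in B.xs
        LinearAlgebra.mul!(o, view(A, :, offset+1:offset+size(x, 1)), x, one(T), one(T))
        offset += size(x, 1)
    end
    o
end
```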
which gives me the speedup I was looking for.
The BenchmarkTools output is below (with `L = LazyVCat((x, y, z))`):
```
using LinearAlgebra, BenchmarkTools, LazyArrays
julia> @benchmark Array(w * LazyArrays.Vcat(x,y,z))
BenchmarkTools.Trial: 195 samples with 1 evaluation.
Range (min … max): 18.112 ms … 109.953 ms ┊ GC (min … max): 0.00% … 82.59%
Time (median): 25.102 ms ┊ GC (median): 18.01%
Time (mean ± σ): 25.652 ms ± 8.802 ms ┊ GC (mean ± σ): 21.05% ± 12.97%
█ ▄ ▄ ▁
▇█▅▅▅▆▃▇▄▆▄▁▄▇█▆█▃▆▅▆██▇▇█▆▆█▆█▄▃▄▄▄▆▆▄▄▅▃▃▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▃ ▃
18.1 ms Histogram: frequency by time 37.7 ms <
Memory estimate: 51.58 MiB, allocs estimate: 161.
julia> @benchmark w * vcat(x,y,z)
BenchmarkTools.Trial: 198 samples with 1 evaluation.
Range (min … max): 17.451 ms … 131.668 ms ┊ GC (min … max): 0.00% … 83.29%
Time (median): 23.815 ms ┊ GC (median): 20.01%
Time (mean ± σ): 25.258 ms ± 11.110 ms ┊ GC (mean ± σ): 22.70% ± 15.27%
▄▇▂ █▂ ▆▃▂
███▄▄▅▄▆███▅▅▅███▅▅▅▃▃▃▁▁▁▁▁▁▁▃▃▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▃
17.5 ms Histogram: frequency by time 57 ms <
Memory estimate: 51.57 MiB, allocs estimate: 4.
julia> @benchmark w * L
BenchmarkTools.Trial: 495 samples with 1 evaluation.
Range (min … max): 8.240 ms … 20.951 ms ┊ GC (min … max): 0.00% … 38.86%
Time (median): 8.735 ms ┊ GC (median): 0.00%
Time (mean ± σ): 10.085 ms ± 2.322 ms ┊ GC (mean ± σ): 11.81% ± 15.18%
▆█▇▄▁ ▁ ▃▄▃▂▁
█████████▇█▇▄▆▅▄▄▁▁▄▄▁▅▄▁▁▁▁▄▁▁▅▁▁▅▅█████▅▇▄▅▅▄▄▁▄▁▁▁▁▄▄▁▄▄ ▇
8.24 ms Histogram: log(frequency) by time 16.2 ms <
Memory estimate: 12.89 MiB, allocs estimate: 23.
```
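As a back-of-the-envelope check on the memory estimates, assuming `Float64` entries: the `vcat` numbers are consistent with allocating the 48×105625 concatenated buffer plus the 16×105625 result, while the lazy version allocates only the result:

```julia
n = 105625
bytes_per_float = sizeof(Float64)             # 8 bytes
vcat_buffer = 3 * 16 * n * bytes_per_float    # the 48×n matrix built by vcat(x, y, z)
output = 16 * n * bytes_per_float             # the 16×n result of the product
mib(b) = b / 2^20                             # bytes -> MiB

println(round(mib(vcat_buffer + output), digits = 2))  # 51.57
println(round(mib(output), digits = 2))                # 12.89
```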
Of course, I could finish the implementation by writing an `rrule` for ChainRules and doing all the surrounding plumbing. But I do not want to repeat someone else’s work, so I would like to ask whether anyone knows of a package that implements this, is fast, and works with AD.
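For reference, a rough sketch of what such an `rrule` might look like, assuming ChainRulesCore.jl. It redefines a minimal, hypothetical `LazyVCat` (no constructor checks) and a hypothetical helper `blockranges` so the snippet stands alone, and it has not been tried with any particular AD backend. The pullback uses `∂A = ΔY * vcat(xs...)'`, assembled block by block, and `∂xᵢ = A[:, rowsᵢ]' * ΔY`:

```julia
using LinearAlgebra, ChainRulesCore

# Minimal stand-in for the lazy concatenation type (hypothetical, for illustration).
struct LazyVCat{N,T<:AbstractMatrix}
    xs::NTuple{N,T}
end

# Column ranges of A that multiply each block of the lazy vcat (hypothetical helper).
function blockranges(B::LazyVCat)
    ranges = UnitRange{Int}[]
    offset = 0
    for x in B.xs
        push!(ranges, offset+1:offset+size(x, 1))
        offset += size(x, 1)
    end
    return ranges
end

function Base.:*(A::AbstractMatrix, B::LazyVCat)
    sum(x -> size(x, 1), B.xs) == size(A, 2) ||
        throw(DimensionMismatch("size(A, 2) must equal the total number of block rows"))
    T = promote_type(eltype(A), eltype(first(B.xs)))
    out = zeros(T, size(A, 1), size(first(B.xs), 2))  # accumulated into below
    for (rows, x) in zip(blockranges(B), B.xs)
        mul!(out, view(A, :, rows), x, true, true)
    end
    return out
end

function ChainRulesCore.rrule(::typeof(*), A::AbstractMatrix, B::LazyVCat)
    ranges = blockranges(B)
    function lazyvcat_mul_pullback(ΔY)
        ΔY = unthunk(ΔY)
        # ∂A = ΔY * vcat(xs...)', written one column block at a time
        ΔA = similar(A, promote_type(eltype(A), eltype(ΔY)))
        for (rows, x) in zip(ranges, B.xs)
            mul!(view(ΔA, :, rows), ΔY, x')
        end
        # ∂xᵢ = A[:, rowsᵢ]' * ΔY
        Δxs = map((rows, x) -> view(A, :, rows)' * ΔY, Tuple(ranges), B.xs)
        ΔB = Tangent{typeof(B)}(; xs = Tangent{typeof(B.xs)}(Δxs...))
        return NoTangent(), ΔA, ΔB
    end
    return A * B, lazyvcat_mul_pullback
end
```

Whether a given AD package picks this rule up, and how it handles the `Tangent` for the tuple field, would still need to be checked.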
Thanks in advance for any answers.