GPU performance and switching from tabular to recurrent data format for Flux.jl

Indeed, that does work without doing scalar indexing on GPU array…

However, the problem I have in this case is that each slice returned by eachslice(...) is a plain 1-D vector, whereas I need each slice to be a 1 × sample_size matrix (a single row). I also tried reshape(x, 1, :), but the scalar indexing still happens.

EDIT

Okay, looking at the source code for eachslice, here’s a solution that avoids scalar indexing and still performs decently on the GPU (~2.2 ms on CPU and ~3 ms on GPU according to the above benchmark).

tabular2rnn(X) = [view(X, i:i, :) for i ∈ 1:size(X, 1)]
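
For anyone reading along, here is a small CPU-only sanity check of the shapes. The 3 × 4 matrix and the names X and seq are made up for illustration; I have not included a GPU array here, so this only shows the layout, not the performance.

```julia
# Same definition as above: one 1 × n view per row of X.
tabular2rnn(X) = [view(X, i:i, :) for i ∈ 1:size(X, 1)]

# Hypothetical toy data: 3 features (rows) × 4 samples (columns).
X = reshape(collect(Float32, 1:12), 3, 4)
seq = tabular2rnn(X)

length(seq)            # 3, one element per row of X
size(seq[1])           # (1, 4), a 1 × sample_size matrix
seq[2] == X[2:2, :]    # true, same values as the copying slice

# For comparison, eachslice drops the sliced dimension:
size(first(eachslice(X; dims=1)))   # (4,), a 1-D vector
```

The range index i:i is what keeps the first dimension; indexing with a plain i (which is what eachslice does) drops it and gives a 1-D view.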

Thanks for all the help.
