GPU performance and switching from tabular to recurrent data format for Flux.jl

Thanks!

Using

tabular2rnn(X) = [permutedims(x) for x ∈ eachslice(X, dims=1)]

instead of what I was using brings the CPU run time down to roughly 2.5 ms and the GPU run time to 3.4 ms. Quite the improvement. If your links give me ideas for speeding it up further, I'll post an update here; until then, this is a good solution!
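
For anyone finding this later, here is a small sketch of what the one-liner produces. The toy matrix `X` is made up for illustration, and reading its rows as time steps and its columns as batch entries (with a single feature) is my assumption about the layout, not something fixed by the function itself. It only uses Base Julia, so it runs without Flux.

```julia
# Same definition as above; `in` is interchangeable with `∈`.
tabular2rnn(X) = [permutedims(x) for x in eachslice(X, dims=1)]

# Hypothetical toy input: 3 rows (assumed: time steps) × 4 columns (assumed: batch).
X = Float32[1 2 3 4; 5 6 7 8; 9 10 11 12]

seq = tabular2rnn(X)

# One element per row of X.
println(length(seq))

# Each row slice is a vector; permutedims turns it into a 1×4 matrix,
# i.e. (features × batch) under the layout assumed above.
println(size(seq[1]))

# The second element holds the second row of X.
println(seq[2] == Float32[5 6 7 8])
```

The result is a plain `Vector` of small matrices, one per row, so each element can be fed to a recurrent layer in turn. `permutedims` allocates a new matrix for every slice, so there may still be room to trim allocations.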
