Flux: Embeddings on GPU

Have a look at the thread "How to implement embeddings in Flux that aren't tragically slow?", in particular post #2 by dhairyagandhi96. Embeddings are an interesting case because they're trivial to implement as a loop on GPU but extremely difficult to express as a vectorized computation. If you just want something that works, embed.jl in chengchingwen/Transformers.jl on GitHub (master branch) has a working implementation, and there's a PR out to add something like it to NNlib.
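To make the loop-versus-vectorized contrast concrete, here is a rough sketch of an embedding lookup's forward pass written three ways. It is plain CPU Julia with no Flux or CUDA dependency, the function names (`embed_loop`, `embed_index`, `embed_onehot`) are made up for illustration, and it is not the code from Transformers.jl or the NNlib PR. It also says nothing about GPU performance or about the backward pass.

```julia
# Hypothetical helper names; plain Julia on the CPU, no Flux or CUDA required.

# Loop form: copy column idx[j] of the embedding table W into column j of the output.
function embed_loop(W::AbstractMatrix, idx::AbstractVector{<:Integer})
    out = similar(W, size(W, 1), length(idx))
    for (j, i) in enumerate(idx)
        for k in axes(W, 1)
            out[k, j] = W[k, i]
        end
    end
    return out
end

# Indexing form: the same gather expressed as a single array operation.
embed_index(W::AbstractMatrix, idx::AbstractVector{<:Integer}) = W[:, idx]

# One-hot form: a dense matrix multiply that gives the same result,
# but does work proportional to the vocabulary size for every token.
function embed_onehot(W::AbstractMatrix, idx::AbstractVector{<:Integer})
    onehot = zeros(eltype(W), size(W, 2), length(idx))
    for (j, i) in enumerate(idx)
        onehot[i, j] = one(eltype(W))
    end
    return W * onehot
end

W = reshape(Float32.(1:12), 3, 4)   # 3-dimensional embeddings, vocabulary of 4
idx = [2, 4, 2]                     # token ids; repeats are allowed

@show embed_loop(W, idx) == embed_index(W, idx)
@show embed_onehot(W, idx) == embed_index(W, idx)
```

All three agree on the result. The loop is the per-element form that maps naturally onto a GPU kernel, while the one-hot multiply is the easy vectorized form that wastes most of its work on zeros.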
