How to implement embeddings in Flux that aren't tragically slow?

Part of the issue is the use of scalar semantics, which leads to a lot of extra copies; that is inefficient in reverse mode (but good for Zygote.Forward).

I have implemented a slightly modified version of the same, but I might have confused myself with some dimensions.

using Flux, Flux.Zygote, Flux.Optimise
using LinearAlgebra
using Flux: @functor

"""
    Embedding{A}

Embedding layer that wraps a single lookup table `W`.

# Fields
- `W`: the table indexed by the layer; the two-argument constructor in this
  file builds it as an `embedding_size × vocab_size` matrix.
"""
struct Embedding{A}
    W::A
end

"""
    Embedding(vocab_size::Integer, embedding_size::Integer)

Build an `Embedding` whose table is an `embedding_size × vocab_size` matrix of
`Float32` values drawn with `randn`, i.e. one column per vocabulary entry.
"""
function Embedding(vocab_size::Integer, embedding_size::Integer)
    table = randn(Float32, embedding_size, vocab_size)
    return Embedding(table)
end

# Make `Embedding` a Flux functor so that `params` (used in `main`) can traverse
# into the struct and collect its table `W` as a trainable array.
@functor Embedding

# Forward pass: select the columns of `m.W` addressed by `x`, keeping every row.
# One whole-column indexing call handles a single index or an array of indices,
# so no per-element loop is needed.
(m::Embedding)(x) = getindex(m.W, :, x)

"""
    DotProduct{F}

Pairs two callables whose outputs are combined with a dot product.

# Fields
- `fᵤ`: callable applied to the first index argument.
- `fᵥ`: callable applied to the second index argument.

Both fields share the single type parameter `F`, so the two callables must have
the same concrete type.
"""
struct DotProduct{F}
    fᵤ::F
    fᵥ::F
end

# Make `DotProduct` a Flux functor so that `params` can recurse through `fᵤ` and
# `fᵥ` and reach the arrays held by the wrapped layers.
@functor DotProduct

# Score a single `(u, v)` index pair given as a tuple: apply `fᵤ` to the first
# index, `fᵥ` to the second, and return the dot product of the two results.
function (m::DotProduct)(x::Tuple{Integer,Integer})
    u_index, v_index = x
    return dot(m.fᵤ(u_index), m.fᵥ(v_index))
end

# Two scalar indices: a single dot product, returned as a scalar (same value the
# original all-dimension `sum` produced for this case).
(m::DotProduct)(x::Integer, y::Integer) = dot(m.fᵤ(x), m.fᵥ(y))

# Batched scoring: return one dot product per (x, y) pair.
#
# The outputs of `fᵤ` and `fᵥ` are broadcast against each other and reduced over
# the FIRST dimension only. This assumes the first dimension of both outputs is
# the embedding dimension, which holds for `Embedding` in this file because it
# indexes `W[:, x]`. The reduced dimension is then dropped, so the result has the
# shape of the broadcast index arrays (e.g. 16 indices -> 16 scores, a 16×15
# index matrix -> a 16×15 score matrix).
#
# Summing over every dimension, as this method previously did, collapses all
# pairs into one number. A caller applying `log.(sigmoid.(…))` to the result then
# gets the log-sigmoid of the total, not one term per pair.
function (m::DotProduct)(x, y)
    products = m.fᵤ(x) .* m.fᵥ(y)
    return dropdims(sum(products; dims = 1); dims = 1)
end

"""
    main(vocab_length = 10_000, embedding_size = 300;
         iterations = 50, batch_size = 16, negatives = 15)

Time a short training run of a skip-gram style model with negative sampling.

Two `Embedding` tables (encoder and decoder) are combined in a `DotProduct`
model and trained with plain gradient descent (`Descent()`) on randomly drawn
indices. The whole training loop is wrapped in `@time`, so the elapsed time and
allocation statistics are printed. Because the timed region runs inside this
call, the first invocation also includes compilation time.

# Arguments
- `vocab_length`: number of entries in each embedding table.
- `embedding_size`: length of each embedding vector.

# Keywords
- `iterations`: number of gradient steps to run.
- `batch_size`: number of context indices drawn per step.
- `negatives`: number of negative-sample indices drawn per context index.

# Returns
`nothing`.
"""
function main(vocab_length = 10_000, embedding_size = 300;
              iterations = 50, batch_size = 16, negatives = 15)
    encoder = Embedding(vocab_length, embedding_size)
    decoder = Embedding(vocab_length, embedding_size)
    model = DotProduct(encoder, decoder)

    opt = Descent()

    # The parameter arrays are updated in place by `update!`, so the collection
    # is loop-invariant and only needs to be built once.
    ps = params(model)

    # Negative-sampling loss: reward high scores for (target, context) pairs and
    # low scores for (target, negative sample) pairs.
    function loss(model,
                  target_word_index,
                  context_word_index,
                  negative_sample_word_indices)
        positive_scores = model(target_word_index, context_word_index)
        negative_scores = model(target_word_index, negative_sample_word_indices)

        # `logsigmoid` avoids the `log(0) == -Inf` that `log.(sigmoid.(x))`
        # yields once `sigmoid` underflows for large negative Float32 scores.
        # It is broadcast, so it works whether the model returns a scalar or an
        # array of scores.
        return -sum(logsigmoid.(positive_scores)) - sum(logsigmoid.(-negative_scores))
    end

    @time begin
        for _ in 1:iterations
            target_idx = rand(1:vocab_length)
            context_idx = rand(1:vocab_length, batch_size)
            neg_idx = rand(1:vocab_length, batch_size, negatives)

            gs = gradient(ps) do
                loss(model, target_idx, context_idx, neg_idx)
            end
            Flux.Optimise.update!(opt, ps, gs)
        end
    end

    return nothing
end

This runs significantly faster for me.

vocab = 10,000

Original:

352.716729 seconds (2.70 M allocations: 1.046 TiB, 7.46% gc time)

Modified:

1.991359 seconds (11.85 k allocations: 3.411 GiB, 5.06% gc time)
7 Likes