How to implement embeddings in Flux that aren't tragically slow?

I have recently started learning Julia and Flux for machine learning. As a starting point I am trying to implement word2vec, but the gradient calculation slows down dramatically as the vocabulary grows. With negative sampling, the cost of each iteration should be independent of the vocabulary size.

I use two custom layers:

using Flux, LinearAlgebra
using Flux: @functor

struct Embedding
    W::Array{Float32,2}
end
Embedding(vocab_size::Integer, embedding_size::Integer) = Embedding(randn(Float32, vocab_size, embedding_size))
@functor Embedding
(m::Embedding)(x::Integer) = @view m.W[x, :]   # look up one word's embedding (a row of W)

struct DotProduct
    fᵤ::Embedding
    fᵥ::Embedding
end
@functor DotProduct
(m::DotProduct)(x::Tuple{Integer,Integer}) = m.fᵤ(x[1]) ⋅ m.fᵥ(x[2])   # dot product of the two word embeddings

And then compose the model and the loss function:

encodder = Embedding(vocab_length, embedding_size)
decodder = Embedding(vocab_length, embedding_size)
model = DotProduct(encodder, decodder)

model_zip(x::Integer, y::Integer) = model((x, y))
loss(target_word_index, context_word_index, negative_sample_word_indices) = (
    - sum(log.(sigmoid.(model_zip.(target_word_index, context_word_index))))              # positive (context) pairs
    - sum(log.(sigmoid.(-model_zip.(target_word_index, negative_sample_word_indices))))   # negative samples
)

Calculating the loss is fast, but computing the gradient is slow. Breaking the model down and timing each part at different model sizes, I found that the Embedding layer is the slow part for large vocabularies.

My guess is that gradients are being calculated for all 2 × vocab_size × embedding_size parameters instead of just the rows used in the loss function, which in effect negates the speed-up from negative sampling.
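
One way to sanity-check that guess (a hypothetical snippet, reusing model, loss and encodder from above): the gradient Zygote hands back for the embedding weights should be a dense array covering the whole vocabulary, even though only a handful of rows are touched.

# Hypothetical check, reusing model/loss/encodder from above
ps = Flux.params(model)
gs = Flux.gradient(ps) do
    loss(1, [2, 3], [4 5; 6 7])    # a few arbitrary word indices
end
size(gs[encodder.W])               # expected: (vocab_length, embedding_size), i.e. the full matrix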

I tried to give Flux a hint by specifying the relevant parameters to calculate gradients for:

@time begin
    for idx in 1:50
        target_idx = rand(1:vocab_length)        # one target word
        context_idx = rand(1:vocab_length, 16)   # 16 context words
        neg_idx = rand(1:vocab_length, 16, 15)   # 15 negative samples per context word

        # e1_params / e2_params: the two embedding weight matrices (defined outside this snippet)
        e1_relevant_params = e1_params[idx, :]
        e2_relevant_params = e2_params[[unique(neg_idx); unique(context_idx)], :]
        ps = params(e1_relevant_params, e2_relevant_params)
        
        gs = gradient(ps) do
            l = loss(target_idx, context_idx, neg_idx)
        end
        update!(opt, ps, gs)
    end
end

But no luck. Timings for the above code with embedding_size = 300:

vocab = 10,000: 91.587846 seconds (15.64 M allocations: 573.023 GiB, 10.84% gc time)
vocab = 1,000:  8.610751 seconds (2.27 M allocations: 57.374 GiB, 13.34% gc time)

So, am I doing something silly? Is this something that can be solved within Flux, or would I have to implement the gradients manually so that only the targeted calculations are performed?

Thanks in advance for any help.

So part of the issue is the use of scalar semantics, which leads to a bunch of extra copies; that is inefficient in reverse mode (but would be fine for Zygote.Forward).
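
Concretely, the contrast is roughly the one below (an illustrative sketch, separate from the rewrite that follows): one copy per index versus a single batched getindex.

W = randn(Float32, 300, 10_000)
idxs = rand(1:10_000, 16)

lookup_scalar(W, idxs) = [W[:, i] for i in idxs]   # scalar-style: one small copy (and, under Zygote, one pullback) per index
lookup_batched(W, idxs) = W[:, idxs]               # batched: a single getindex over all requested columns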

I have implemented a slightly modified version of the same code, though I might have confused myself with some of the dimensions.

using Flux, Flux.Zygote, Flux.Optimise
using LinearAlgebra
using Flux: @functor

struct Embedding{T}
    W::T
end

# Store W as embedding_size × vocab_size so each word's embedding is a contiguous column
Embedding(vocab_size::Integer, embedding_size::Integer) = Embedding(randn(Float32, embedding_size, vocab_size))

@functor Embedding

(m::Embedding)(x) = m.W[:, x]   # works for a single index or a whole array of indices

struct DotProduct{T}
    fᵤ::T
    fᵥ::T
end

@functor DotProduct

(m::DotProduct)(x::Tuple{Integer,Integer}) = m.fᵤ(x[1]) ⋅ m.fᵥ(x[2])

(m::DotProduct)(x,y) = sum(m.fᵤ(x) .* m.fᵥ(y))   # batched: broadcasts over columns and sums over all dimensions

function main(vocab_length = 10_000, embedding_size = 300)
  encodder = Embedding(vocab_length, embedding_size)
  decodder = Embedding(vocab_length, embedding_size)
  model = DotProduct(encodder, decodder)
  model_zip(x::Integer, y::Integer) = model((x, y))

  opt = Descent()

  function loss(model,
                target_word_index,
                context_word_index,
                negative_sample_word_indices)

    l1 = - sum(log.(sigmoid.(model(target_word_index,
                                   context_word_index))))

    l2 = - sum(log.(sigmoid.(-model(target_word_index,
                                    negative_sample_word_indices))))
    l1 + l2
  end

  @time begin
     for idx in 1:50
        target_idx = rand(1:vocab_length)
        context_idx = rand(1:vocab_length, 16)
        neg_idx = rand(1:vocab_length, 16, 15)

        ps = params(model)

        gs = gradient(ps) do
          l = loss(model, target_idx, context_idx, neg_idx)
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
  end

end

This runs significantly faster for me.

vocab = 10,000

Original:

352.716729 seconds (2.70 M allocations: 1.046 TiB, 7.46% gc time)

Modified:

1.991359 seconds (11.85 k allocations: 3.411 GiB, 5.06% gc time)

It may be illuminating to see how Transformers.jl implements embeddings. Gather/scatter are on the table to be added in NNlib, but the implementations aren’t too difficult to parse since they’re mostly simple loops at heart.
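
For a flavour of what those loops look like, here is a rough sketch (the name gather_cols is made up for illustration; this is not the Transformers.jl or NNlib implementation): a column gather whose hand-written pullback only loops over the columns that were actually used.

using Flux
using Flux.Zygote: @adjoint

gather_cols(W::AbstractMatrix, idxs) = W[:, idxs]   # forward pass: plain column indexing

@adjoint function gather_cols(W::AbstractMatrix, idxs)
    y = gather_cols(W, idxs)
    function gather_cols_pullback(Δ)
        dW = zero(W)                                # zero buffer of W's size
        Δmat = reshape(Δ, size(W, 1), :)            # flatten any extra sample dimensions
        for (k, i) in enumerate(vec(collect(idxs)))
            @views dW[:, i] .+= Δmat[:, k]          # scatter-add only into the used columns
        end
        (dW, nothing)                               # no gradient w.r.t. the indices
    end
    return y, gather_cols_pullback
end

# It could then be dropped into the Embedding layer as, e.g., (m::Embedding)(x) = gather_cols(m.W, x)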
