How to implement embeddings in Flux that aren't tragically slow?

I have recently started learning Julia and Flux for machine learning. As a starting point I am trying to implement word2vec, but the gradient calculations slow down dramatically as the vocabulary grows. With negative sampling, the cost of each iteration should be independent of the vocabulary size.

I use two custom layers:

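using Flux, LinearAlgebra   # LinearAlgebra provides ⋅ (dot)
using Flux: @functor

struct Embedding
    W::Array{Float32,2}
end
Embedding(vocab_size::Integer, embedding_size::Integer) = Embedding(randn(Float32, vocab_size, embedding_size))
@functor Embedding
# One word vector (a row of W) per scalar index
(m::Embedding)(x::Integer) = @view m.W[x, :]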

struct DotProduct
    fᵤ::Embedding
    fᵥ::Embedding
end
@functor DotProduct
(m::DotProduct)(x::Tuple{Integer,Integer}) = m.fᵤ(x[1]) ⋅ m.fᵥ(x[2])

I then compose the model and the loss function:

encodder = Embedding(vocab_length, embedding_size)
decodder = Embedding(vocab_length, embedding_size)
model = DotProduct(encodder, decodder)

model_zip(x::Integer, y::Integer) = model((x, y))
loss(target_word_index, context_word_index, negative_sample_word_indices) = (
    - sum(log.(sigmoid.(model_zip.(target_word_index, context_word_index))))
    - sum(log.(sigmoid.(-model_zip.(target_word_index, negative_sample_word_indices))))
)
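Written out, the loss above is the following (notation mine: $\mathbf{u}_k$ and $\mathbf{v}_k$ are row $k$ of the encoder and decoder matrices, $t$ is the target index, $c_1,\dots,c_{16}$ are the context indices and $n_{ij}$ are the $16 \times 15$ negative-sample indices used in the training loop below):

$$
\mathcal{L} = -\sum_{i=1}^{16} \log \sigma\left(\mathbf{u}_t^\top \mathbf{v}_{c_i}\right) - \sum_{i=1}^{16} \sum_{j=1}^{15} \log \sigma\left(-\mathbf{u}_t^\top \mathbf{v}_{n_{ij}}\right)
$$

Only $\mathbf{u}_t$ and the $\mathbf{v}_k$ with $k \in \{c_i\} \cup \{n_{ij}\}$ appear, so every other row of both matrices has zero gradient, and in principle the work per iteration does not need to grow with the vocabulary size.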

Calculating the loss is fast, but the gradient is slow. Breaking the model down to test each part, and timing it with different model sizes, I found that the Embedding layer is the slow part for large vocabularies.

I would guess that this is because gradients are being calculated for all 2 × vocab size × embedding size parameters, instead of just the rows used in the loss function, which negates the speed-up from negative sampling.
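A back-of-the-envelope check of that guess, using the sizes from the timing loop below (vocabulary 10,000, embedding size 300, 16 context words and 16 × 15 negative samples per iteration). The last lines rest on an assumption, not a measurement: that every single-index lookup produces its own dense gradient the size of the embedding matrix (in line with the reply below about extra copies). Under that assumption the dense gradients alone come to about 6 GB per iteration, roughly 300 GB over 50 iterations, which is the same order of magnitude as the 573 GiB allocation reported below.

```julia
vocab, dim = 10_000, 300
n_ctx, n_neg = 16, 16 * 15

# Parameters in both embedding matrices (encoder + decoder).
total_params = 2 * vocab * dim                  # 6_000_000

# Rows one iteration can touch: 1 target row in the encoder,
# at most n_ctx + n_neg rows in the decoder.
max_rows_touched = 1 + n_ctx + n_neg           # 257
touched_params = max_rows_touched * dim        # 77_100

# ASSUMPTION: each scalar lookup materialises a dense, W-sized gradient.
bytes_per_dense_grad = vocab * dim * sizeof(Float32)   # 12_000_000 bytes
lookups_per_iter = 2 * (n_ctx + n_neg)                 # 512 (encoder + decoder)
dense_bytes_per_iter = lookups_per_iter * bytes_per_dense_grad
```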

I tried to give Flux a hint by specifying the relevant parameters to calculate gradients for:

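# e1_params and e2_params are the encoder and decoder weight matrices
# (assumed: encodder.W and decodder.W); opt is an optimiser, e.g. Descent()
@time begin
    for idx in 1:50
        target_idx = rand(1:vocab_length)
        context_idx = rand(1:vocab_length, 16)
        neg_idx = rand(1:vocab_length, 16, 15)

        # Only the rows this iteration actually uses
        e1_relevant_params = e1_params[target_idx, :]
        e2_relevant_params = e2_params[[unique(neg_idx); unique(context_idx)], :]
        ps = params(e1_relevant_params, e2_relevant_params)

        gs = gradient(ps) do
            loss(target_idx, context_idx, neg_idx)
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
end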

But no luck. Timings for the code above with embedding size 300:

vocab = 10,000: 91.587846 seconds (15.64 M allocations: 573.023 GiB, 10.84% gc time)
vocab = 1,000: 8.610751 seconds (2.27 M allocations: 57.374 GiB, 13.34% gc time)
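One likely reason the hint has no effect: in Julia, indexing such as `W[rows, :]` returns a new array, so the arrays passed to `params` are copies that the loss never reads (their gradients should come back as `nothing`), while the loss itself still indexes the full embedding matrices. A small Base-only illustration, with arbitrary sizes:

```julia
W = zeros(Float32, 5, 3)       # stand-in for an embedding matrix (vocab × dim)
rows = [2, 4]

row_copy = W[rows, :]          # getindex allocates a new 2 × 3 array
row_copy .= 1                  # W is unchanged: still all zeros

row_view = @view W[rows, :]    # a view shares memory with W
row_view .= 1                  # rows 2 and 4 of W are now ones
```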

So, am I doing something silly? Can this be solved with Flux, or would I have to implement the gradients manually so that only the targeted calculations are performed?

Thanks in advance for any help.


Part of the issue is the use of scalar semantics (one index lookup at a time), which leads to a lot of extra copies; that is inefficient in reverse mode (but good for Zygote.Forward).
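To see where the extra copies come from, here is a tiny example (sizes made up) of the reverse-mode gradient of a single column lookup. As far as I can tell, the indexing rule Zygote uses returns a dense array the same size as `W`, almost all zeros; a scalar-style model pays for one such array per lookup, whereas a batched lookup pays for one per indexing call.

```julia
using Zygote

W = randn(Float32, 4, 5)   # tiny stand-in for an embedding matrix (dim × vocab)

# Gradient of a loss that reads only column 2 of W.
g = Zygote.gradient(W -> sum(W[:, 2]), W)[1]

size(g)             # same size as W, not just one column
count(!iszero, g)   # only the 4 entries of column 2 are non-zero
```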

I have implemented a slightly modified version of your code, though I might have confused myself with some of the dimensions.

using Flux, Flux.Zygote, Flux.Optimise
using LinearAlgebra
using Flux: @functor

# Column-major layout: each column of W is the vector for one word.
struct Embedding{T}
    W::T
end

Embedding(vocab_size::Integer, embedding_size::Integer) =
    Embedding(randn(Float32, embedding_size, vocab_size))

@functor Embedding

# Indexing with a vector (or matrix) of word indices gathers all columns at once.
(m::Embedding)(x) = m.W[:, x]

struct DotProduct{T}
    fᵤ::T
    fᵥ::T
end

@functor DotProduct

# Scalar version: a single (target, context) pair.
(m::DotProduct)(x::Tuple{Integer,Integer}) = m.fᵤ(x[1]) ⋅ m.fᵥ(x[2])

# Batched version: reduce only over the embedding dimension (dims = 1),
# so every (target, word) pair keeps its own score.
(m::DotProduct)(x, y) = sum(m.fᵤ(x) .* m.fᵥ(y); dims = 1)

function main(vocab_length = 10_000, embedding_size = 300)
    encodder = Embedding(vocab_length, embedding_size)
    decodder = Embedding(vocab_length, embedding_size)
    model = DotProduct(encodder, decodder)

    opt = Descent()

    function loss(model, target_word_index, context_word_index,
                  negative_sample_word_indices)
        l1 = -sum(log.(sigmoid.(model(target_word_index, context_word_index))))
        l2 = -sum(log.(sigmoid.(-model(target_word_index, negative_sample_word_indices))))
        return l1 + l2
    end

    # The parameter arrays are the same every iteration, so collect them once.
    ps = params(model)

    @time begin
        for idx in 1:50
            target_idx = rand(1:vocab_length)
            context_idx = rand(1:vocab_length, 16)
            neg_idx = rand(1:vocab_length, 16, 15)

            gs = gradient(ps) do
                loss(model, target_idx, context_idx, neg_idx)
            end
            Flux.Optimise.update!(opt, ps, gs)
        end
    end
end

main()

This runs significantly faster for me.

With vocab = 10,000:

Original: 352.716729 seconds (2.70 M allocations: 1.046 TiB, 7.46% gc time)
Modified: 1.991359 seconds (11.85 k allocations: 3.411 GiB, 5.06% gc time)
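The batched `(m::DotProduct)(x, y)` method relies on broadcasting one target vector against a block of gathered columns. A quick shape check with made-up sizes (vocabulary 1,000) and plain arrays standing in for the `Embedding` layers: reducing with `dims = 1` keeps one score per pair, which is what the per-pair `log.(sigmoid.(...))` in the original loss expects, while a bare `sum` collapses everything into a single scalar.

```julia
using LinearAlgebra

vocab, dim = 1_000, 300
W_u = randn(Float32, dim, vocab)   # encoder: one column per word
W_v = randn(Float32, dim, vocab)   # decoder: one column per word

t   = 7                            # target word index
ctx = rand(1:vocab, 16)            # 16 context indices
neg = rand(1:vocab, 16, 15)        # 16 × 15 negative-sample indices

u = W_u[:, t]                      # length-300 vector

# Broadcast u against the gathered columns, then reduce over the embedding dim.
s_ctx = sum(u .* W_v[:, ctx]; dims = 1)   # size (1, 16)
s_neg = sum(u .* W_v[:, neg]; dims = 1)   # size (1, 16, 15)

# Without dims = 1, sum returns one scalar for all pairs together.
s_all = sum(u .* W_v[:, ctx])
```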

It may be illuminating to see how Transformers.jl implements embeddings. Gather/scatter operations are on the table to be added to NNlib, but the Transformers.jl implementations aren't too difficult to parse, since they're mostly simple loops at heart.
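To make "simple loops at heart" concrete, here is a minimal sketch of a column gather whose reverse rule is a scatter-add. `gather_cols` and `scatter_add_cols!` are hypothetical names for illustration, not the Transformers.jl or NNlib API; the rule is written with ChainRulesCore, which (as far as I know) recent Zygote versions pick up automatically. The pullback still allocates one dense `W`-sized buffer, but only once per gather call rather than once per looked-up index.

```julia
using ChainRulesCore

# Hypothetical helper: copy the columns of W selected by idx into a new matrix.
function gather_cols(W::AbstractMatrix, idx::AbstractVector{<:Integer})
    out = similar(W, size(W, 1), length(idx))
    for (j, k) in enumerate(idx)
        out[:, j] .= @view W[:, k]
    end
    return out
end

# Hypothetical helper: add column j of dy into column idx[j] of dW.
# Repeated indices accumulate, as the gradient of a gather requires.
function scatter_add_cols!(dW::AbstractMatrix, dy::AbstractMatrix,
                           idx::AbstractVector{<:Integer})
    for (j, k) in enumerate(idx)
        dW[:, k] .+= @view dy[:, j]
    end
    return dW
end

# Reverse rule: gather forward, scatter-add backward; indices get no tangent.
function ChainRulesCore.rrule(::typeof(gather_cols), W::AbstractMatrix,
                              idx::AbstractVector{<:Integer})
    y = gather_cols(W, idx)
    function gather_cols_pullback(ȳ)
        dW = scatter_add_cols!(zero(W), unthunk(ȳ), idx)
        return NoTangent(), dW, NoTangent()
    end
    return y, gather_cols_pullback
end
```

For example, `gather_cols(W, [3, 1, 3])` returns columns 3, 1 and 3 of `W`, and the pullback of a matrix of ones puts a 1 in column 1 and a 2 in column 3 of the gradient.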


Sorry about the slow reply; I've been having computer troubles, so I didn't have time to work on this project. This seems to have done the trick, though. I'm still having convergence issues, but I think that's something else.

Thanks for all the help