I have recently started learning Julia and Flux for machine learning. I am trying to implement word2vec as a starting point, but the gradient calculations slow down dramatically as the vocabulary grows. With negative sampling, the cost of each iteration should be independent of the vocabulary size.

I use two custom layers:

```
using Flux
using Flux: @functor
using LinearAlgebra: ⋅

struct Embedding
    W::Matrix{Float32}
end

# randn defaults to Float64, so request Float32 to match the field type
Embedding(vocab_size::Integer, embedding_size::Integer) =
    Embedding(randn(Float32, vocab_size, embedding_size))

@functor Embedding

# look up one word vector without copying the row
(m::Embedding)(x::Integer) = @view m.W[x, :]

struct DotProduct
    fᵤ::Embedding
    fᵥ::Embedding
end

@functor DotProduct

# score a (target, context) pair as the dot product of their embeddings
(m::DotProduct)(x::Tuple{Integer,Integer}) = m.fᵤ(x[1]) ⋅ m.fᵥ(x[2])
```

And then I compose the model and the loss function:

```
encoder = Embedding(vocab_length, embedding_size)
decoder = Embedding(vocab_length, embedding_size)
model = DotProduct(encoder, decoder)
model_zip(x::Integer, y::Integer) = model((x, y))

loss(target_word_index, context_word_index, negative_sample_word_indices) = (
    - sum(log.(sigmoid.(model_zip.(target_word_index, context_word_index))))
    - sum(log.(sigmoid.(-model_zip.(target_word_index, negative_sample_word_indices))))
)
```
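For reference, the loss can be exercised directly on random indices once `vocab_length` and `embedding_size` are set (one target word, 16 context words, 15 negative samples per context word, matching the timing loop further down):

```
target_idx  = rand(1:vocab_length)
context_idx = rand(1:vocab_length, 16)
neg_idx     = rand(1:vocab_length, 16, 15)
loss(target_idx, context_idx, neg_idx)   # returns a scalar; this forward pass is fast
```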

Calculating the loss is fast, but the gradient is slow. Breaking the model down to test each part and timing with different model sizes, I have found that the `Embedding` layer is the slow part for large vocabularies.

I would guess that this is because gradients are being calculated for all 2 × vocab_size × embedding_size parameters, instead of only the rows used in the loss function — in effect negating the speed-up of negative sampling.
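A minimal sketch of this suspicion, separate from the word2vec model (the sizes here are arbitrary): differentiating through an indexing operation makes Zygote materialise a gradient array the size of the full weight matrix, even when only one row is touched.

```
using Zygote

W = randn(Float32, 10_000, 300)

# sum over a single row, then differentiate with respect to the whole matrix
g = gradient(W -> sum(W[42, :]), W)[1]

size(g)   # (10_000, 300): only row 42 is nonzero, yet the entire
          # dense array is allocated for the gradient
```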

I tried to give `Flux` a hint by specifying the relevant parameters to calculate gradients for:

```
@time begin
    for _ in 1:50
        target_idx = rand(1:vocab_length)
        context_idx = rand(1:vocab_length, 16)
        neg_idx = rand(1:vocab_length, 16, 15)
        # e1_params / e2_params are the weight matrices of the two embeddings
        e1_relevant_params = e1_params[target_idx, :]
        e2_relevant_params = e2_params[[unique(neg_idx); unique(context_idx)], :]
        ps = params(e1_relevant_params, e2_relevant_params)
        gs = gradient(ps) do
            loss(target_idx, context_idx, neg_idx)
        end
        update!(opt, ps, gs)
    end
end
```

But no luck.

The above code with `embedding_size = 300`:

- vocab = 10,000: `91.587846 seconds (15.64 M allocations: 573.023 GiB, 10.84% gc time)`
- vocab = 1,000: `8.610751 seconds (2.27 M allocations: 57.374 GiB, 13.34% gc time)`

So, am I doing something silly? Is this something that can be solved within Flux, or would I have to implement the gradients manually so that only the targeted calculations are performed?

Thanks in advance for any help.