So part of the issue is the scalar semantics, which lead to a lot of extra copies; that is inefficient in reverse mode (though fine for Zygote.Forward).
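For illustration (a quick sketch I made up, with placeholder names W and idxs), indexing the weight matrix with a whole vector of indices produces one sliced array in a single operation, whereas per-scalar lookups produce many small copies, each of which Zygote has to track separately on the reverse pass:

W = randn(Float32, 300, 10_000)   # embedding_size × vocab_size
idxs = rand(1:10_000, 16)

cols_scalar = [W[:, i] for i in idxs]   # 16 separate 300-element copies, 16 pullbacks
cols_batched = W[:, idxs]               # one 300×16 slice, one pullback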
I have implemented a slightly modified version of the same model, though I might have confused myself with some of the dimensions.
using Flux, Flux.Zygote, Flux.Optimise
using LinearAlgebra
using Flux: @functor
# A minimal embedding layer: column i of W is the embedding of word i.
struct Embedding{T}
    W::T
end

Embedding(vocab_size::Integer, embedding_size::Integer) =
    Embedding(randn(Float32, embedding_size, vocab_size))

@functor Embedding

# Look up embeddings by column; x may be a scalar index or an array of indices.
(m::Embedding)(x) = m.W[:, x]
struct DotProduct{T}
    fᵤ::T
    fᵥ::T
end

@functor DotProduct

# Score a single (word, word) pair by the dot product of its two embeddings.
(m::DotProduct)(x::Tuple{Integer,Integer}) = m.fᵤ(x[1]) ⋅ m.fᵥ(x[2])

# Array-valued indices broadcast the two lookups together and sum everything
# down to one scalar; sum(...; dims = 1) would give one dot product per pair.
(m::DotProduct)(x, y) = sum(m.fᵤ(x) .* m.fᵥ(y))
function main(vocab_length = 10_000, embedding_size = 300)
    encoder = Embedding(vocab_length, embedding_size)
    decoder = Embedding(vocab_length, embedding_size)
    model = DotProduct(encoder, decoder)
    # Unused in the training loop below, but handy for scoring a single pair.
    model_zip(x::Integer, y::Integer) = model((x, y))
    opt = Descent()

    # Negative-sampling style loss: pull the context words towards the target
    # and push the negative samples away.
    function loss(model,
                  target_word_index,
                  context_word_index,
                  negative_sample_word_indices)
        l1 = -sum(log.(sigmoid.(model(target_word_index,
                                      context_word_index))))
        l2 = -sum(log.(sigmoid.(-model(target_word_index,
                                       negative_sample_word_indices))))
        l1 + l2
    end

    @time begin
        for idx in 1:50
            target_idx = rand(1:vocab_length)        # one target word
            context_idx = rand(1:vocab_length, 16)   # 16 context words
            neg_idx = rand(1:vocab_length, 16, 15)   # 15 negative samples per context word
            ps = params(model)
            gs = gradient(ps) do
                loss(model, target_idx, context_idx, neg_idx)
            end
            Flux.Optimise.update!(opt, ps, gs)
        end
    end
end
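In case it helps with the dimension question, here is a quick sanity check of the shapes flowing through the loss (a sketch reusing the definitions above with the same index sizes as in main; not part of the benchmark):

enc = Embedding(10_000, 300)
dec = Embedding(10_000, 300)
m = DotProduct(enc, dec)

t = rand(1:10_000)           # target word index (scalar)
c = rand(1:10_000, 16)       # context word indices
n = rand(1:10_000, 16, 15)   # negative-sample word indices

@assert size(m.fᵤ(t)) == (300,)          # one embedding column
@assert size(m.fᵥ(c)) == (300, 16)       # one column per context word
@assert size(m.fᵥ(n)) == (300, 16, 15)   # one column per negative sample
@assert m(t, c) isa Number               # broadcast .* plus sum collapses to a scalar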
The modified version runs significantly faster for me, with vocab_length = 10_000:
Original:
352.716729 seconds (2.70 M allocations: 1.046 TiB, 7.46% gc time)
Modified:
1.991359 seconds (11.85 k allocations: 3.411 GiB, 5.06% gc time)