I haven’t checked this carefully, but I think that re-writing this so it makes no slices at all can get us to about 2 ms:
julia> function loss(
           Wmid, Wctx,
           tok_mid, tok_ctx, x
       )
           # gather all the needed rows in one indexing operation, no per-pair slices
           tmp = sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2) |> vec
           -mean(@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp)))
       end
loss (generic function with 1 method)
julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
1.892 ms (274 allocations: 4.62 MiB)
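As an aside (my own untested sketch, not part of the benchmark above): log(logistic_sigmoid(tmp)) underflows to -Inf when tmp is a large negative number, and the same cross-entropy can be written without any log of a sigmoid via the identity -x*log(σ(t)) - (1-x)*log(1-σ(t)) = softplus(t) - x*t, where softplus(t) = log(1 + exp(t)). The names softplus and loss_stable here are mine:

using Statistics: mean  # already in scope in the session above

softplus(t) = max(t, zero(t)) + log1p(exp(-abs(t)))  # overflow-safe log(1 + exp(t))

function loss_stable(Wmid, Wctx, tok_mid, tok_ctx, x)
    tmp = sum(Wmid[tok_mid, :] .* Wctx[tok_ctx, :]; dims=2) |> vec
    # algebraically equal to the loss above, but never takes log(0)
    mean(@. softplus(tmp) - x * tmp)
end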
julia> using Tullio # first way I thought of
julia> function loss(
           Wmid, Wctx,
           tok_mid, tok_ctx, x
       )
           @tullio tmp[k] := Wmid[tok_mid[k], c] * Wctx[tok_ctx[k], c] # sum over c
           -mean(@. x * log(logistic_sigmoid(tmp)) + (1 - x) * log(1 - logistic_sigmoid(tmp))) # sum over k
       end
loss (generic function with 1 method)
julia> grad = @btime train(Random.Xoshiro(5), 100277, 100, 2);
1.916 ms (348 allocations: 4.62 MiB)
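For anyone new to Tullio: an index that appears only on the right-hand side, like c above, is summed over implicitly. A toy illustration of the convention (mine, not from the benchmark):

using Tullio

A = [1 2; 3 4]
@tullio rowsum[i] := A[i, j]  # j appears only on the right, so it is summed: rowsum == [3, 7]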
All of these definitions of loss give me a zero gradient, so there could be mistakes.
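If someone wants to chase the zero gradients, here is a rough sanity check I would try (untested sketch, with loss as defined above and logistic_sigmoid assumed to be the usual 1/(1+exp(-t)) from earlier in the thread): perturb one weight that the batch actually touches and compare a central finite difference against the matching entry of the AD gradient.

using Random, Statistics

logistic_sigmoid(t) = 1 / (1 + exp(-t))  # assumed to match the definition used above

rng = Random.Xoshiro(5)
Wmid, Wctx = randn(rng, 10, 4), randn(rng, 10, 4)
tok_mid, tok_ctx = rand(rng, 1:10, 8), rand(rng, 1:10, 8)
x = Float64.(rand(rng, Bool, 8))

h = 1e-6
i = tok_mid[1]                 # perturb a row the batch actually uses
Wp = copy(Wmid); Wp[i, 1] += h
Wm = copy(Wmid); Wm[i, 1] -= h
fd = (loss(Wp, Wctx, tok_mid, tok_ctx, x) - loss(Wm, Wctx, tok_mid, tok_ctx, x)) / (2h)
# if fd is clearly nonzero while the AD gradient entry at [i, 1] is 0, the AD call is suspect;
# if fd is also ~0, the loss really is flat in that coordinate and the gradient is honest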