Consider implementing the DIFF Transformer from the October 2024 paper (it seems better across the board). I started translating its pseudocode to Julia:
using NNlib: softmax   # softmax along a chosen dimension

# Single-sequence translation (no batch dim). Paper shapes: Qi, Ki: [b, n, d]; V: [b, n, 2d].
function DiffAttn(X, W_q, W_k, W_v, λ)
    Q = X * W_q; K = X * W_k              # n × 2d
    d = size(Q, 2) ÷ 2
    Q1, Q2 = Q[:, 1:d], Q[:, d+1:end]     # split the feature dim into two halves
    K1, K2 = K[:, 1:d], K[:, d+1:end]
    V = X * W_v                           # n × 2d
    s = 1 / sqrt(d)                       # torch.rsqrt in the PyTorch code
    A1 = Q1 * K1' * s                     # n × n attention logits
    A2 = Q2 * K2' * s
    return (softmax(A1; dims=2) - λ * softmax(A2; dims=2)) * V
end
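A quick shape check (arbitrary sizes and an arbitrary λ; in the paper λ is a learned per-layer scalar):

n, d_model, d = 8, 32, 16                 # sequence length, model dim, per-half head dim
X = randn(n, d_model)
W_q = randn(d_model, 2d); W_k = randn(d_model, 2d); W_v = randn(d_model, 2d)
Y = DiffAttn(X, W_q, W_k, W_v, 0.5)
size(Y)                                   # (8, 32), i.e. n × 2d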
I leave the Python pseudocode as is:
def MultiHead(X, W_q, W_k, W_v, W_o, λ):
    O = GroupNorm([DiffAttn(X, W_qi, W_ki, W_vi, λ) for i in range(h)])
    O = O * (1 - λ_init)
    return Concat(O) @ W_o
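A possible continuation of the translation — a minimal Julia sketch only, not Microsoft's implementation. The names MultiHeadDiffAttn and headnorm are mine, per-head weights are assumed to be passed as vectors of matrices, λ_init is passed explicitly (the paper sets it per layer), and headnorm is a scale-free RMS-style stand-in for the paper's per-head GroupNorm:

# Mirrors the pseudocode above. W_q, W_k, W_v are vectors of per-head projection
# matrices; W_o projects the concatenated heads back to the model dimension.
function MultiHeadDiffAttn(X, W_q, W_k, W_v, W_o, λ, λ_init)
    heads  = [DiffAttn(X, W_q[i], W_k[i], W_v[i], λ) for i in eachindex(W_q)]
    normed = [headnorm(O) .* (1 - λ_init) for O in heads]   # per-head norm, then fixed rescale
    return reduce(hcat, normed) * W_o                       # Concat(O) @ W_o
end

# Scale-free stand-in for the per-head normalization (RMS over each token's features).
headnorm(O; ε = 1e-5) = O ./ sqrt.(sum(abs2, O; dims=2) ./ size(O, 2) .+ ε)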
Microsoft’s Python code is here: https://github.com/microsoft/unilm/tree/master/Diff-Transformer
It allows for better post-training quantization: at 4-bit, the standard Transformer suffers far more (see Figure 8). The best quantization I’ve seen goes below 4 bits, though I think never post-training, and I expect this to work at least as well, probably better, when combined with such methods:
The results indicate that DIFF Transformer natively mitigates activation outliers in attention scores, providing new opportunities for low-bit FlashAttention [8] implementations.
https://arxiv.org/pdf/2410.05258
Figure 1: Transformer often over-attends to irrelevant context (i.e., attention noise). DIFF Transformer amplifies attention to answer spans and cancels noise, enhancing the capability of context modeling.
In this [Microsoft] paper, we introduce Differential Transformer (a.k.a. DIFF Transformer), a foundation architecture for large language models. […] Specifically, we partition the query and key vectors into two groups and compute two separate softmax attention maps. […] The differential attention mechanism eliminates attention noise, encouraging models to focus on critical information. The approach is analogous to noise-canceling headphones and differential amplifiers [19] in electrical engineering, where the difference between two signals cancels out common-mode noise. In the middle of Figure 1, we also present the normalized distribution of attention scores for DIFF Transformer. We observe that DIFF Transformer assigns significantly higher scores to the correct answer and much lower scores to irrelevant context compared to Transformer. […] We conduct extensive experiments on language modeling. We scale up DIFF Transformer in terms of parameter count, training tokens, and context length. The scaling curves indicate that DIFF Transformer requires only about 65% of model size or training tokens needed by Transformer to achieve comparable language modeling performance. Moreover, DIFF Transformer outperforms Transformer in various downstream tasks. The long-sequence evaluation also shows that DIFF Transformer is highly effective in utilizing the increasing context. In addition, the experimental results demonstrate that DIFF Transformer has intriguing advantages for large language models. For example, the proposed method substantially outperforms Transformer in key information retrieval, hallucination mitigation, and in-context learning. DIFF Transformer also reduces outliers in model activations, which provides new opportunities for quantization. The findings establish DIFF Transformer as an effective and distinctive foundation architecture for large language models.
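In equation form, the differential attention map described here is exactly what the DiffAttn code above computes:

$$
\mathrm{DiffAttn}(X) = \Big(\operatorname{softmax}\big(\tfrac{Q_1 K_1^{\top}}{\sqrt{d}}\big) - \lambda\,\operatorname{softmax}\big(\tfrac{Q_2 K_2^{\top}}{\sqrt{d}}\big)\Big) V,
\qquad [Q_1\ Q_2] = X W_q,\ \ [K_1\ K_2] = X W_k,\ \ V = X W_v,
$$

with λ a learnable scalar.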
Figure 3: Language modeling loss of scaling up parameter count and training tokens. DIFF Transformer requires only about 65% of model size or training tokens to match Transformer’s performance.
3.5 In-Context Learning
We evaluate in-context learning from two perspectives, including many-shot classification and robustness of in-context learning. In-context learning is a fundamental capability of language models, which indicates how well a model can utilize input context.
[…]
The results show that DIFF Transformer consistently outperforms Transformer across datasets and varying numbers of demonstration samples. Moreover, the improvement in average accuracy is substantial, ranging from 5.2% to 21.6%.
Robustness of In-Context Learning Figure 7 compares the robustness of in-context learning between Transformer and DIFF Transformer. […]
The results indicate that our approach is more robust for in-context learning. In contrast, Transformer tends to be distracted by order permutations [25], resulting in a huge margin between the best and worst results.
[…]
Compared with Transformer, our method mitigates contextual hallucination on summarization and question answering. The performance improvement possibly stems from DIFF Transformer’s better focus on essential information needed for the task, instead of irrelevant context. This aligns with previous observation [16] that one primary reason for contextual hallucination in Transformer is the misallocation of attention scores.
F Gradient Flow of DIFF Transformer
We show that the gradient flow in differential attention is similar to that of conventional softmax attention. With this property, the same hyperparameters used in Transformer can be applied directly to the corresponding DIFF Transformer without concerns about training instability.
They use an “internal version of [21]”, meaning SuperScaler (but it seems optional, only relevant for training speed):
SuperScaler: Supporting Flexible DNN Parallelization via a Unified Abstraction
https://arxiv.org/pdf/2301.08984
As a result, SuperScaler can not only generate empirical parallelization plans, but also construct new plans that achieve up to 3.5× speedup compared to state-of-the-art solutions like DeepSpeed, Megatron and Alpa.
torch.rsqrt — PyTorch 2.5 documentation (that’s the 1/√d scaling in the code above).
RMSNorm is also available in PyTorch (RMSNorm — PyTorch 2.5 documentation), so why do they implement their own? I suppose an equivalent is available in Julia, as is most of the (foundation) code they rely on.
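If not, it’s only a couple of lines anyway — a minimal sketch with a learned per-feature scale γ, normalizing over the feature dimension (columns, to match the n × d layout above); not the exact variant from the DIFF Transformer code:

using Statistics: mean

struct RMSNorm{V<:AbstractVector}
    γ::V          # learned per-feature scale
    ε::Float64
end
RMSNorm(d::Integer; ε = 1e-6) = RMSNorm(ones(Float32, d), ε)
# Normalize each row (token) by its root-mean-square over the features, then rescale.
(m::RMSNorm)(x) = (x ./ sqrt.(mean(abs2, x; dims=2) .+ m.ε)) .* m.γ'
# (Mark it with e.g. Flux.@functor RMSNorm if γ should be trainable with Flux.)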
The paper references (and uses) Group Normalization [paper] and:
Magneto: A foundation Transformer. In International Conference on Machine Learning, pp. 36077–36092. PMLR, 2023.
and: