It’s probably not better to go pure Julia (rather than use state-of-the-art code and algorithms), except as a learning exercise. But if you do, consider “1-bit networks” (from 2023 and from this week):
It’s very likely that if you redo some software, you will reimplement an outdated approach. E.g. transformers in their current form are likely going away.
With such 1- and 2-bit networks we’ve likely reached the end of the line for quantization, and it helps keep model size down. To stay competitive in training you need thousands of GPUs, and software that can target that many, so pure Julia seems out of the question there. But maybe you can go halfway: leave out some parts, like distribution across many GPUs, and use DeepSpeed or something similar for that.
Training from scratch is still very costly, and there’s no need to, since you can fine-tune a model for Julia use. But then you need to choose the best model to start from, and formats/quantization as in llama.cpp or the new bitnet.cpp from Microsoft. See the former (and its relation to Llama2.jl):
KAN networks (they can be a drop-in replacement for the MLP part of transformers, if I recall) are worthwhile to reimplement in Julia:
KAN networks are likely not compatible with 1-bit networks (I mean their weights are larger), but they might still be worthwhile if you can get away with fewer of them. They’re also not entirely contradictory, since you can still have 1-bit weights in the transformer parts where KAN does not replace the MLP. But isn’t the MLP the largest part of the total?
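As a back-of-the-envelope answer to that last question (assuming a standard block with a 4× MLP expansion, and ignoring gated MLP variants, embeddings and norms):

d = 4096                 # hidden size, roughly Llama-2-7B-sized
attn = 4d^2              # W_q, W_k, W_v, W_o
mlp  = 2 * d * 4d        # up- and down-projections at 4× expansion
mlp / (attn + mlp)       # ≈ 0.67, so yes, the MLP is roughly two thirds of a block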
I think it’s also worthwhile to help with this:
The best models will likely use new ways of multiplying that aren’t yet in software (but you could emulate them, slowly(?), for compatibility until hardware catches up, or maybe just use FP8 or bfloat, I don’t recall which might be compatible with it): https://arxiv.org/html/2410.00907v2#S2
Would there be demand for pure Julia implementations? Definitely!
My understanding was that people are stretched so thin on existing projects that we need more people interested and willing to hack!
Personally, I’m crazy about the applications of GenAI and building on top of it rather than training, but I bet that differs for everyone.
If you’re keen to hack deeper than pure inference but still want an easy start, maybe you want to dip your toes into Entropix? Have you played with it? It does a lot of clever stuff with really small models, striking a nice balance of performance, practicality, and local runnability. That could be a fun starter!
Funny how the link there is to a paper from this month, so “known for a little while” in AI research means what, about 3 weeks?! [Or does the new paper reference older papers/ideas?] Not really too surprising with the rapid changes, even if it does mean 3 weeks.
Thanks, I didn’t know of Entropix, seems interesting.
Last I heard, Entropix is splitting the repository: one effort going toward huge models and pushing the limits of where this can go, the other focused on local LLMs, squeezing out every last drop of intelligence.
Consider implementing DIFF Transformer from the October paper (it seems better across the board). I started translating it to Julia:
using NNlib: softmax   # row/column-wise softmax

# Split an [n, 2d] matrix into two [n, d] halves along the feature dimension.
splithalf(M) = (M[:, 1:end÷2], M[:, end÷2+1:end])

function DiffAttn(X, W_q, W_k, W_v, λ)
    Q1, Q2 = splithalf(X * W_q)
    K1, K2 = splithalf(X * W_k)
    V = X * W_v
    # Qi, Ki: [n, d]; V: [n, 2d]  (single head; batch dimension omitted here)
    s = 1 / sqrt(size(Q1, 2))                   # torch.rsqrt(d)
    A1 = Q1 * K1' .* s                          # [n, n] attention logits
    A2 = Q2 * K2' .* s
    return (softmax(A1; dims=2) .- λ .* softmax(A2; dims=2)) * V
end
I leave the Python pseudocode in as-is:
def MultiHead(X, W_q, W_k, W_v, W_o, λ):
    O = GroupNorm([DiffAttn(X, W_qi, W_ki, W_vi, λ) for i in range(h)])
    O = O * (1 - λ_init)
    return Concat(O) @ W_o
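For comparison, here is a rough Julia sketch of that MultiHead pseudocode, assuming the per-head weight slices are passed as vectors of matrices; groupnorm is a placeholder for whatever GroupNorm layer you use, and the names are illustrative, not from any package:

function MultiHead(X, W_qs, W_ks, W_vs, W_o, λ, λ_init; h = length(W_qs))
    heads = [DiffAttn(X, W_qs[i], W_ks[i], W_vs[i], λ) for i in 1:h]
    O = [groupnorm(Oi) .* (1 - λ_init) for Oi in heads]   # GroupNorm per head, then scale
    return reduce(hcat, O) * W_o                          # Concat(O) @ W_o
end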
It allows for more “post-training quantization”, where the current Transformer suffers far more at 4-bit (see Figure 8). The best quantization I’ve seen goes even below 4-bit, though I think never post-training, and I think this will work no less well, probably even better, when combined with such methods:
The results indicate that DIFF Transformer natively mitigates activation outliers in attention scores, providing new opportunities for low-bit FlashAttention [8] implementations.
Figure 1: Transformer often over-attends to irrelevant context (i.e., attention noise). DIFF Transformer amplifies attention to answer spans and cancels noise, enhancing the capability of context modeling.
In this [Microsoft] paper, we introduce Differential Transformer (a.k.a. DIFF Transformer), a foundation architecture for large language models. […] Specifically, we partition the query and key vectors into two groups and compute two separate softmax attention maps. […] The differential attention mechanism eliminates attention noise, encouraging models to focus on critical information. The approach is analogous to noise-canceling headphones and differential amplifiers [19] in electrical engineering, where the difference between two signals cancels out common-mode noise. In the middle of Figure 1, we also present the normalized distribution of attention scores for DIFF Transformer. We observe that DIFF Transformer assigns significantly higher scores to the correct answer and much lower scores to irrelevant context compared to Transformer. […] We conduct extensive experiments on language modeling. We scale up DIFF Transformer in terms of parameter count, training tokens, and context length. The scaling curves indicate that DIFF Transformer requires only about 65% of model size or training tokens needed by Transformer to achieve comparable language modeling performance. Moreover, DIFF Transformer outperforms Transformer in various downstream tasks. The long-sequence evaluation also shows that DIFF Transformer is highly effective in utilizing the increasing context. In addition, the experimental results demonstrate that DIFF Transformer has intriguing advantages for large language models. For example, the proposed method substantially outperforms Transformer in key information retrieval, hallucination mitigation, and in-context learning. DIFF Transformer also reduces outliers in model activations, which provides new opportunities for quantization. The findings establish DIFF Transformer as an effective and distinctive foundation architecture for large language models.
Figure 3: Language modeling loss of scaling up parameter count and training tokens. DIFF Transformer requires only about 65% of model size or training tokens to match Transformer’s performance.
3.5 In-Context Learning
We evaluate in-context learning from two perspectives, including many-shot classification and robustness of in-context learning. In-context learning is a fundamental capability of language models, which indicates how well a model can utilize input context.
[…]
The results show that DIFF Transformer consistently outperforms Transformer across datasets and varying numbers of demonstration samples. Moreover, the improvement in average accuracy is substantial, ranging from 5.2% to 21.6%.
Robustness of In-Context Learning Figure 7 compares the robustness of in-context learning between Transformer and DIFF Transformer. […]
The results indicate that our approach is more robust for in-context learning. In contrast, Transformer tends to be distracted by order permutations [25], resulting in a huge margin between the best and worst results.
[…]
Compared with Transformer, our method mitigates contextual hallucination on summarization and question answering. The performance improvement possibly stems from DIFF Transformer’s better focus on essential information needed for the task, instead of irrelevant context. This aligns with previous observation [16] that one primary reason for contextual hallucination in Transformer is the misallocation of attention scores.
F Gradient Flow of DIFF Transformer
We show that the gradient flow in differential attention is similar to that of conventional softmax attention. With this property, the same hyperparameters used in Transformer can be applied directly to the corresponding DIFF Transformer without concerns about training instability.
They use “internal version of [21]” meaning SuperScaler (but it seems optional, only for training speed):
SuperScaler: Supporting Flexible DNN Parallelization via a Unified Abstraction https://arxiv.org/pdf/2301.08984
As a result, SuperScaler can not only generate empirical parallelization plans, but also construct new plans that achieve up to 3.5× speedup compared to state-of-the-art solutions like DeepSpeed, Megatron and Alpa.

torch.rsqrt — PyTorch 2.5 documentation
RMSNorm — PyTorch 2.5 documentation is also available, so why do they implement their own RMSNorm? I suppose the same is available in Julia, along with most of the (foundation) code they rely on.
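For what it’s worth, RMSNorm is only a handful of lines in Julia anyway; a minimal sketch (features along the first dimension, not tied to any particular package’s layer API):

struct RMSNorm{T}
    γ::Vector{T}   # learnable scale
    ε::T
end
RMSNorm(d::Integer; ε = 1f-6) = RMSNorm(ones(Float32, d), Float32(ε))

function (m::RMSNorm)(x::AbstractMatrix)
    # x: [d, n], one column per token; normalize each column by its RMS
    rms = sqrt.(sum(abs2, x; dims = 1) ./ size(x, 1) .+ m.ε)
    return m.γ .* (x ./ rms)
end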
It references (and uses) Group Normalization [paper] and:
Magneto: A foundation Transformer. In International Conference on Machine Learning, pp. 36077–36092. PMLR, 2023.
Also an intriguing paper from a few days ago, which likely can be combined with the paper I just posted:
In this work, we introduce BitNet a4.8, a hybrid quantization and sparsification strategy that enables 4-bit activations for 1-bit LLMs. By carefully analyzing the activation distribution of 1-bit LLMs, we selectively apply 4-bit quantization or sparsification based on the distribution patterns of these activations. Specifically, as shown in Figure 1, BitNet a4.8 employs 4-bit activations for the inputs to attention and FFN, while utilizing sparsification with 8 bits for intermediate states. To improve the training efficiency, BitNet a4.8 is trained from 8-bit to 4-bit activations with a two-stage recipe, which requires only a few training tokens to adapt BitNet b1.58 to the low-bit activations at the end of training. Extensive experiments demonstrate that BitNet a4.8 achieves competitive performance to BitNet b1.58 with the same training cost while being significantly more efficient at inference time.
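As a toy illustration of the kind of step involved (per-tensor absmax quantization of activations to a signed 4-bit range; a sketch, not the paper’s exact scheme):

function quantize_int4(x::AbstractArray{<:Real})
    s = maximum(abs, x) / 7              # scale so the largest magnitude maps to ±7
    s = iszero(s) ? one(s) : s           # guard against all-zero input
    q = clamp.(round.(Int8, x ./ s), Int8(-8), Int8(7))
    return q, s                          # dequantize with q .* s
end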
Our proposed technique can be directly applied to existing LLMs without any modifications to the pre-training setup or additional fine-tuning.
Now, you can process 1M context 10x faster in a single A100 using Long-context LLMs like LLaMA-3-8B-1M, GLM-4-1M, with even better accuracy, try MInference 1.0 right now!
Rather cool:
Text-to-Video Generation (1280x768, 10s, 24fps)
Others do longer, up to 2 min, but is that at 24 fps (I think way lower)? What’s the max fps?
Have there been other key/substantial improvements recently in the other components across the LLM stack? Are most of the key changes in the transformer blocks? Is the attention mechanism mostly the same across implementations?
If you’re asking me, there’s so much I could write on this, but I want to be confident in my answers, so I’ll let others answer.
About training: I thought that training from scratch would be basically an impossible task, and that we’d also be missing infrastructure code (in Julia, if we can’t use what’s already available). This seems like a huge deal from August:
This is the repository for DisTrO (Distributed Training Over-The-Internet), a family of low latency distributed optimizers that reduce inter-GPU communication requirements by three to four orders of magnitude.
This means people could help out together, I think with their home GPUs, but it’s unclear whether it lowers total GPU requirements, so we likely can’t get 300,000 GPUs or something (or that many people) to help.
Either you fine-tune (doable) or train from scratch, but I’m wondering: do only those extremes exist? I suppose you don’t start totally from scratch every time; it would be best to start from some early checkpoint.
I’ve not looked into DisTrO closely: is it a replacement for Adam, Lion etc. (likely not), or does it build on them? It probably at least replaces DeepSpeed.
Do we have all the variants? Surely not, but do you mean things like:
Better performance with lower precision: FlashAttention-3 can work with lower precision numbers (FP8) while maintaining accuracy.
Such lower precision requires recent hardware (GPUs), and I don’t think we can compete if we don’t use/target it. We are also behind on even more heavily quantized models. We have SafeTensors.jl, but that format and code is limited to bfloat16 and FP8 at the smallest (or so it seems; maybe not inherently, and it will support smaller?). It uses DLFP8Types.jl, so FP8 seems to be implemented in software (and thus slowly).
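To see why software FP8 is slow: decoding even a single E4M3 byte takes several bit-field operations per element, presumably something along these lines (a hedged sketch; the NaN encoding with all exponent and mantissa bits set is ignored):

function e4m3_to_f32(b::UInt8)
    s = (b >> 7) & 0x1                                # sign bit
    e = (b >> 3) & 0x0f                               # 4 exponent bits, bias 7
    m = b & 0x07                                      # 3 mantissa bits
    mag = e == 0 ? (m / 8) * 2.0f0^-6 :               # subnormals
                   (1 + m / 8) * 2.0f0^(Int(e) - 7)   # normals
    return s == one(s) ? -Float32(mag) : Float32(mag)
end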
Did you mean we need Grouped Query Attention? I’m not sure; we might have it already.
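In case it helps pin down what is meant: a minimal, illustrative GQA sketch (not from any package), where each group of query heads shares one key/value head, which is the whole point, namely a smaller KV cache:

using NNlib: softmax

function gqa(Qs, Ks, Vs)
    # Qs: n_q matrices of size [n, d]; Ks, Vs: n_kv matrices of size [n, d]
    n_q, n_kv = length(Qs), length(Ks)
    g = n_q ÷ n_kv                                         # query heads per shared KV head
    d = size(Qs[1], 2)
    outs = map(1:n_q) do h
        kv = (h - 1) ÷ g + 1                               # KV head this query head uses
        A = softmax(Qs[h] * Ks[kv]' ./ sqrt(d); dims = 2)  # [n, n] attention weights
        A * Vs[kv]                                         # [n, d] per-head output
    end
    return reduce(hcat, outs)                              # concatenated heads: [n, n_q * d]
end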
… Tri’s publication history has been leaning toward SSM and Mamba style architectures recently. Unlike Flash Attention which has quadratic time complexity wrt sequence length, these latest algorithms are subquadratic. Thus they do much less computation, instead of just doing it more efficiently a la Flash Attention.
Dao and Gu published a really long paper this year which demonstrated (among other things) how Mamba/SSM can be formulated such that it’s amenable to acceleration using the same hardware primitives that Transformers benefit from. …
…Until the strong exponential hypothesis is (dis-)proven, the quadratic cost is required or you have to give something up. Just the cost of exhaustive search.
As (dis-)proving SETH will resolve the P vs NP problem, I wouldn’t hold my breath. …
Maybe Mamba, SSM, Jamba, or some linear transformer will take over, but it seems to me that just sticking with the quadratic transformer is a safe bet (and even if it’s not good enough, the code can be changed later?).
We have many (most?) optimizers here, e.g. many Adam variants:
With version 0.4 the default update rule for AdamW has changed to match the pytorch implementation.
We even have API · Optimisers.jl (once the best, I thought; no longer?), but clicking through some of its docs shows strange (placeholder?) text:
In addition to the main course, you may wish to order some of these condiments:
I thought we redundantly had them at SciML, but its docs actually link over to flux.ml:
In terms of ‘competing’, I am thinking of entering the race from behind and, with a good stack design/architecture, being able to make iterative contributions to improve upon it. Almost like an end-to-end scaffolding. As long as the modularity and separation of concerns of the stack components are well thought through, it should allow others to plug in their improved components.
At the moment, what are the components which don’t exist yet? It would be great if there were a ‘roadmap’ / list of ‘milestones’; then we could see what needs to be done and get updates. Most importantly, though, I think the ‘design’ for how the components link together is vital for the operation of the stack.
What is the list of ‘components’ we need and how will they interface together?
Just looking at the diversity and breadth of libraries and packages in the equivalent Python ecosystem, I don’t think there will be a single design or architecture to rule them all here. It might be better to start with a concrete example and use case that includes all the pieces you want to see in the stack, then identify and fill in gaps from there. Llama2.jl is a great example of this: it tried to fill the niche of a llama.cpp equivalent in the Julia ecosystem.
You are very right that a single approach is not very ‘organic’ and that only an ecosystem will prevail in producing a great wealth of libraries. I totally agree with a ‘concrete example and use case that includes all the pieces’.
From that, the question is: what is the list of ‘repos’/‘pieces’ that would be good to have to cover a full stack? A list would be great, or a diagram, with a few high-level bullet points, knowing that it is subject to change etc. Then the interfaces should be defined so that others who want to substitute a component can easily do so. E.g. if the new ‘distributed’ approach is sought after for training, a good abstraction within the stack and a clearly defined interface should allow a substitution to try that out. A big monolithic approach may hamper specialized development efforts, but of course it should not be too ‘shallow’ and scattered either. Do you have an idea for a basic outline of a stack we could work on?
This difficulty of analysis and requirements gathering is precisely why I’d recommend focusing on one (maybe two) concrete examples first. For example, maybe you want to pre-train a new open source LLM on a specific set of datasets, evaluate it on a common benchmark suite and quantize it to run inference. All of those steps currently have gaps in the ecosystem, but the only way to figure out what exactly the gaps are is to start writing code. It’s less work to extract a common stack out from those afterwards than it is to try to come up with an ideal one before any code exists to use it.
Reactant would be the way to go if you want good performance for these workloads. If you want a starter code, we have some WIP versions scattered across PRs atm
Even the quantized ops needed for inference exist in StableHLO land, but we haven’t hooked them up on the Julia side yet; it is definitely doable.
Thank you so much for referring to my recent Medium article here! I’m excited to share that this is part of an ongoing series where I translate the Python code from Sebastian Raschka’s Build a Large Language Model (from Scratch) into Julia.
Apart from the recent article on the self-attention mechanism (section 3.3.1 of the book), I’ve also written 7 more Medium articles where the Python and Julia code from the second chapter is translated and explained. You can find the full code in the GitHub repo, which I’ve just set up to accompany the article series.
I’d love to connect on Bluesky and X to exchange ideas and collaborate further. If you find this work useful, please feel free to star/share the GitHub repo to reach more users who might want to join this effort of translating Python to Julia for building an LLM from scratch.
P.S.: I always share both the member link and the (free) friend link for the Medium articles, in case you’re not a Medium member.
To deliver high performance while meeting accuracy goals, Trainium is optimized for FP32, TF32, BF16, FP16, UINT8, and the new configurable FP8 (cFP8) data type.
Despite AWS having their own chip, I think they’re behind the curve of new research. Nvidia has FP8 and INT4, and Blackwell has FP4 Tensor Cores (I think it also has FP6), but even that is arguably already outdated.
Something to emulate in Julia, or get access to such software:
1-Bit FQT: Pushing the Limit of Fully Quantized Training to 1-bit
To explore the ultimate limit of FQT (the lowest achievable precision), we make a first attempt to 1-bit FQT. We provide a theoretical analysis of FQT based on Adam and SGD, revealing that the gradient variance influences the convergence of FQT. Building on these theoretical results, we introduce an Activation Gradient Pruning (AGP) strategy. […] Additionally, we propose Sample Channel joint Quantization (SCQ), which utilizes different quantization strategies in the computation of weight gradients and activation gradients to ensure that the method is friendly to low-bitwidth hardware. Finally, we present a framework to deploy our algorithm. For fine-tuning VGGNet-16 and ResNet-18 on multiple datasets, our algorithm achieves an average accuracy improvement of approximately 6%, compared to per-sample quantization. Moreover, our training speedup can reach a maximum of 5.13× compared to full precision training. Our code is available at GitHub - Gaochang-bjtu/1-bit-FQT
As the training numerical precision continues to decrease, a natural question arises:
What is the ultimate limit of FQT (i.e., the minimum achievable bitwidth)?
Answering this question not only advances our understanding of FQT but also provides a crucial direction for future hardware design strategies. Ideally, if we can push the bitwidth down to 1-bit, the training can be implemented with binary operations, such as XNOR and bitcounting operations Courbariaux et al. [2016], and hardware design might be greatly simplified. Binary computation is already shown possible for inference acceleration, such as XNOR-Net Rastegari et al. [2016], but 1-bit training remains unexplored.
Reducing the bitwidth for FQT is challenging because of (1) the lack of theoretical understanding, especially how gradient quantization affects the convergence; (2) the large quantization error of gradients, which causes a sharp performance drop or even divergence when reducing gradient bitwidth lower than 4-bit (Fig. 1).
Due to these challenges, current research frontier is still 4-bit FQT. In this work, we make a first attempt towards achieving 1-bit FQT.
Specifically, our analysis reveals that Adam is more suitable for FQT than SGD in the low-bitwidth regime, due to their different sensitivity to gradient variance.
Inspired by the above theory, we propose a hardware-friendly algorithm for 1-bit FQT. […]
We examine the potential of 1-bit FQT on transfer learning tasks in both vision and NLP domain. […]
On all the datasets, our 1-bit FQT algorithm can successfully converge and demonstrate significantly superior performance compared to directly applying the previous FQT method to the task. The average accuracy drop on visual classification datasets is approximately 5%, compared to training the binary model with full-precision gradients. Notably, the average accuracy loss is negligible (less than 1%) on Flowers Nilsback and Zisserman [2008] dataset and Pets Parkhi et al. [2012] dataset, indicating that 1-bit FQT might indeed be useful in some cases. We implement our algorithm on Hygon and Raspberry Pi devices as a PyTorch-based library, binop. Accelerated on-device training can be achieved with simple layer substitution, e.g., replace torch.nn.Conv2d with binop.Conv2d. In practice, our method can achieve up to 5.13× speedup, compared to FP32 PyTorch. These results indicate that in some specific tasks, FQT precision can be pushed to the ultimate 1-bit.
[…] 3.1 Quantized Training
Here, we describe Quantization-Aware Training (QAT) and Fully Quantized Training (FQT). QAT is employed to accelerate inference, while FQT is designed to accelerate both inference and training.
[…] 8-bit PSQ vs. Ours. To demonstrate the advantages of our method over other high-bit-width FQT methods, we compare our approach with 8-bit PSQ in terms of both speedup and classification performance (there is no 4-bit format among the standard data types).
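The XNOR-and-bitcount style binarization mentioned above is conceptually simple; a per-tensor sketch (XNOR-Net proper uses per-filter scales, so this is a simplification):

function binarize(W::AbstractMatrix)
    α = sum(abs, W) / length(W)   # scale = mean |w|
    B = sign.(W)                  # entries in {-1, 0, +1}; exact zeros are rare
    return B, α                   # W ≈ α .* B, so matmuls reduce to sign flips plus one scale
end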
This repository contains the official implementation of Half-Quadratic Quantization (HQQ) presented in our articles:
HQQ is a fast and accurate model quantizer that skips the need for calibration data. Quantize the largest models, without calibration data, in just a few minutes at most.
FAQ: Why should I use HQQ instead of other quantization methods?
HQQ is very fast to quantize models.
It supports 8,4,3,2,1 bits.
You can use it on any model (LLMs, Vision, etc.).
[…]
What is the quality of the quantized models?
We have detailed benchmarks on both language and vision models. Please refer to our blog posts: HQQ, HQQ+.
What is the speed of the quantized models?
4-bit models with axis=1 can use optimized inference fused kernels like torchao’s int4_gemm. This is the same kernel used in gpt-fast and based on our benchmarks, it’s the fastest kernel available right now. We also support the Marlin kernel. Moreover, we focus on making hqq fully compatible with torch.compile which speeds-up both training and inference.
In case 1-bit (up to 3-bit) isn’t optimal just yet, standard floating point isn’t either, nor is FP4, but:
Learning from Students: Applying t-Distributions to Explore Accurate and Efficient Formats for LLMs
… Yet recently, alternative formats such as Normal Float (NF4) have increased model accuracy at the cost of increased chip area. In this work, we first conduct a large-scale analysis of LLM weights and activations across 30 networks and conclude that most distributions follow a Student’s t-distribution. We then derive a new theoretically optimal format, Student Float (SF4), that improves over NF4 across modern LLMs, for example increasing the average accuracy on LLaMA2-7B by 0.76% across tasks
That is incredible research. I was unaware of that progress. Thanks for sharing.
It should be possible to allow the flexibility to work with both the 1-bit model and the typical FP32/16? Will that be possible with the current Flux setup? This approach would definitely make it possible for training to be done without the serious funding or corporate support that might be necessary for anything beyond a GPT-2-scale version produced just for demonstrative purposes.