In short: no, I think, not for quantized models, and in practice we might even be limited to FP32 (more on this at the bottom). I don’t think Float32 is used much anymore in mainstream work or research (it is 2x, or 4x+, slower), except still by Julia?! But even Julia’s Float16 might be OK for KANs:
My thinking is that if we do a “Julia LLM from scratch” project at all, then for training (and even for inference alone) it needs to be competitive and use recent methods, or why bother? Even better, it could be ahead of the curve, e.g. by using KANs. There’s no need to implement old ideas, and it isn’t even helpful, since naive or old approaches take much longer to train.
Quantization has been the go-to method for faster inference (it fits more parameters into the same fixed amount of GPU/TPU memory), and by now I believe it can also help with training.
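To make the memory argument concrete, here is a minimal sketch of symmetric per-tensor int8 quantization, in plain Python only because the KAN code discussed here is Python. It is one common scheme, chosen for illustration, not necessarily what any particular framework or paper does: each weight becomes a signed byte plus one shared float scale, so weight storage drops roughly 4x versus Float32 and 2x versus bfloat16. The function names and example weights are made up.

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: w ≈ scale * q with q in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # round() then clamp, so every code fits in a signed byte.
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q, scale):
    """Recover approximate float weights from the int8 codes and the shared scale."""
    return [scale * v for v in q]


weights = [0.5, -1.27, 0.003, 1.0]          # made-up example weights
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
# Storage: 1 byte per weight (plus one float for scale) instead of 4 bytes for Float32.
packed = bytes(v & 0xFF for v in q)
```

The rounding error per weight is at most half a step (scale / 2), which is why a few outliers that inflate max_abs hurt everyone else; finer-grained (per-channel or per-group) scales are the usual remedy.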
So I’ve been wondering whether quantization applies, or even can apply, to KANs (or to other new ideas; I don’t want to lead us down the wrong path). It seems so:
Hardware Acceleration of Kolmogorov–Arnold Network (KAN) for Lightweight Edge Inference
https://arxiv.org/pdf/2409.11418
Acceptance date: September 2, 2024
Recently, a novel model named Kolmogorov-Arnold Networks (KAN) has been proposed with the potential to achieve the functionality of traditional deep neural networks (DNNs) using orders of magnitude fewer parameters by parameterized B-spline functions with trainable coefficients. However, the B-spline functions in KAN present new challenges for hardware acceleration. Evaluating the B-spline functions can be performed by using look-up tables (LUTs) to directly map the B-spline functions, thereby reducing computational resource requirements. However, this method still requires substantial circuit resources (LUTs, MUXs, decoders, etc.). For the first time, this paper employs an algorithm-hardware co-design methodology to accelerate KAN. The proposed algorithm-level techniques include Alignment-Symmetry and PowerGap KAN hardware aware quantization, KAN sparsity aware mapping strategy, and circuit-level techniques […]
with analog-CIM (ACIM) circuits. The impact of non-ideal effects, such as partial sum errors caused by the process variations, has been evaluated with the statistics measured from the TSMC 22nm RRAM-ACIM prototype chips. With the best searched hyperparameters of KAN and the optimized circuits implemented in 22 nm node, we can reduce hardware area by 41.78x, energy by 77.97x with 3.03% accuracy boost compared to the traditional DNN hardware.
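For reference, a minimal sketch of what “evaluating the B-spline functions” means, and of the LUT idea in its simplest software form (this is not the paper’s circuit): the Cox-de Boor recursion gives all basis values at x, and once the input is an 8-bit code there are only 256 possible x values, so every basis value can be precomputed into a table. The grid (5 uniform intervals on [-1, 1], extended by k knots per side), the cubic order k = 3, the code-to-x mapping and all names are my assumptions for illustration.

```python
def bspline_basis(x, knots, k):
    """Cox-de Boor recursion: values of all degree-k B-splines B_{i,k}(x).

    There are len(knots) - k - 1 basis functions; at most k + 1 are nonzero at any x.
    """
    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    b = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for d in range(1, k + 1):
        nxt = []
        for i in range(len(knots) - d - 1):
            left_den = knots[i + d] - knots[i]
            right_den = knots[i + d + 1] - knots[i + 1]
            left = (x - knots[i]) / left_den * b[i] if left_den else 0.0
            right = (knots[i + d + 1] - x) / right_den * b[i + 1] if right_den else 0.0
            nxt.append(left + right)
        b = nxt
    return b


def uniform_knots(grid_size, k, lo=-1.0, hi=1.0):
    """Uniform grid on [lo, hi], extended by k knots on each side (a common choice for KAN grids)."""
    h = (hi - lo) / grid_size
    return [lo + (j - k) * h for j in range(grid_size + 2 * k + 1)]


K = 3                                   # cubic splines
KNOTS = uniform_knots(5, K)             # 5 grid intervals -> 5 + 3 = 8 basis functions

# Hypothetical 8-bit input: code q in 0..255 stands for x = -1 + 2q/256 in [-1, 1).
LUT = [bspline_basis(-1.0 + 2.0 * q / 256, KNOTS, K) for q in range(256)]


def basis_from_lut(q):
    """One table read replaces the whole recursion once inputs are 8-bit codes."""
    return LUT[q]
```

A naive table like this stores 256 × 8 values, most of them zero; real hardware would keep only the at most k + 1 nonzero entries per code, which is where tricks like the paper’s come in.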
We of course don’t have access to exotic analog prototype chips, but I think that’s OK: we could start with non-quantized KANs, as has been done in Python (only 285 lines in the code below), and later add 8-bit quantization on regular hardware, otherwise as done here:
We propose an Alignment-Symmetry and PowerGap KAN hardware aware quantization that, for the first time, investigates the interaction between quantization grid and knot grid in KAN. The proposed method significantly minimizes the cost of LUTs, MUXs, decoders for B(X) function.
[…]
our focus is on accelerating w_s·spline(x) computation. In our implementation, w_s is multiplied with c_i and becomes c_i’, then is quantized to 8-bit, transforming the formula to equation (3)
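A small sketch of the folding step quoted above, assuming the KAN paper’s form of each activation, φ(x) = w_b·b(x) + w_s·spline(x) with spline(x) = Σ_i c_i·B_i(x). Since w_s·Σ_i c_i·B_i(x) = Σ_i (w_s·c_i)·B_i(x), the multiply by w_s can be absorbed into c_i’ = w_s·c_i, which is then stored as int8. The paper’s equation (3) is not reproduced here; the shared per-layer scale, the made-up numbers and the function names are my assumptions.

```python
def cubic_uniform_basis(u):
    """The 4 nonzero uniform cubic B-spline values at local offset u in [0, 1)."""
    return [(1 - u) ** 3 / 6,
            (3 * u**3 - 6 * u**2 + 4) / 6,
            (-3 * u**3 + 3 * u**2 + 3 * u + 1) / 6,
            u**3 / 6]


def fold_and_quantize(w_s, coeffs):
    """c_i' = w_s * c_i, then symmetric int8 codes sharing one scale (assumed per layer)."""
    folded = [w_s * c for c in coeffs]
    max_abs = max(abs(c) for c in folded) or 1.0
    scale = max_abs / 127
    return [round(c / scale) for c in folded], scale


def spline_term_float(w_s, coeffs, basis):
    """Reference: w_s * spline(x) = w_s * sum_i c_i B_i(x)."""
    return w_s * sum(c * b for c, b in zip(coeffs, basis))


def spline_term_int8(q_coeffs, scale, basis):
    """Quantized: scale * sum_i q_i B_i(x); w_s has disappeared into the codes."""
    return scale * sum(q * b for q, b in zip(q_coeffs, basis))


w_s = 0.8                                  # made-up spline weight
coeffs = [0.3, -1.1, 0.7, 0.25]            # made-up coefficients of the 4 active splines
basis = cubic_uniform_basis(0.5)           # [1/48, 23/48, 23/48, 1/48]
q, scale = fold_and_quantize(w_s, coeffs)
exact = spline_term_float(w_s, coeffs, basis)
approx = spline_term_int8(q, scale, basis)
```

Because the cubic basis values are non-negative and sum to 1, the error of the quantized term is at most half a quantization step (scale / 2).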
Note that KANs are a drop-in replacement for MLPs, which are a critical part of transformers and more.
It might seem that going down to only 8 bits isn’t good, since that’s 2x to 8x larger per parameter than mainstream quantization (4 bits down to 1 bit), but it’s a win if the parameter count is reduced by at least as much, e.g. 8x or more. And I think KANs must already have been considered a win (for space) before quantization was introduced: “orders of magnitude fewer parameters” is likely a huge win for space, but is it also for compute?
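The break-even arithmetic as a tiny sketch: weight storage is parameters × bits / 8 bytes, so 8-bit weights with 8x fewer parameters cost the same as 1-bit weights with the full count, and a quarter of 4-bit. The 8B parameter count and the 8x reduction are made-up numbers, not measured KAN results.

```python
def weight_gigabytes(n_params, bits_per_param):
    """Storage for the weights alone; ignores scales, activations and optimizer state."""
    return n_params * bits_per_param / 8 / 1e9


N = 8_000_000_000                 # hypothetical parameter count of a mainstream model
FEWER = 8                         # hypothetical KAN parameter reduction factor

mainstream_4bit = weight_gigabytes(N, 4)          # 4.0 GB
mainstream_1bit = weight_gigabytes(N, 1)          # 1.0 GB
kan_8bit = weight_gigabytes(N // FEWER, 8)        # 1.0 GB: same as 1-bit at 8x the parameters
```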
For backpropagation purposes, it is also convenient to store the output derivative with respect to the inputs:
However, all activation functions are linear combinations of a fixed set of basis functions, which are B-splines; given that, we can reformulate the computation as activating the input with the different basis functions and then combining them linearly. This reformulation can significantly reduce the memory cost and turn the computation into a straightforward matrix multiplication, and it works naturally with both the forward and backward passes.
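Here is a minimal pure-Python sketch of that reformulation, covering only the spline part of each activation (the residual w_b·b(x) term is left out, and w_s is assumed folded into the coefficients). The per-edge version recomputes the basis for every (output, input) pair; the reformulated one expands each input into its B-spline values once and then does a single matrix-vector product. Grid, sizes, coefficients and function names are my assumptions for illustration.

```python
def bspline_basis(x, knots, k):
    """Cox-de Boor recursion: all degree-k B-spline values at x (len(knots) - k - 1 of them)."""
    b = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for d in range(1, k + 1):
        nxt = []
        for i in range(len(knots) - d - 1):
            ld, rd = knots[i + d] - knots[i], knots[i + d + 1] - knots[i + 1]
            left = (x - knots[i]) / ld * b[i] if ld else 0.0
            right = (knots[i + d + 1] - x) / rd * b[i + 1] if rd else 0.0
            nxt.append(left + right)
        b = nxt
    return b


def kan_layer_per_edge(xs, coeffs, knots, k):
    """Direct form: out[j] = sum_i phi_ji(x_i), with phi_ji(x) = sum_m coeffs[j][i][m] B_m(x).
    The basis is recomputed for every (output, input) edge."""
    out = []
    for j in range(len(coeffs)):
        total = 0.0
        for i, x in enumerate(xs):
            basis = bspline_basis(x, knots, k)
            total += sum(c * b for c, b in zip(coeffs[j][i], basis))
        out.append(total)
    return out


def kan_layer_matmul(xs, coeffs, knots, k):
    """Reformulated: expand each input into its basis values once, then one matrix-vector
    product with the coefficients flattened to shape (out_features, in_features * n_basis)."""
    features = [b for x in xs for b in bspline_basis(x, knots, k)]
    weights = [[c for per_input in row for c in per_input] for row in coeffs]
    return [sum(w * f for w, f in zip(w_row, features)) for w_row in weights]


K = 3
KNOTS = [-1.0 + (j - K) * 0.4 for j in range(5 + 2 * K + 1)]   # 5 intervals on [-1, 1]
N_BASIS = len(KNOTS) - K - 1                                     # 8
IN_F, OUT_F = 3, 2
# Made-up coefficients, shape (OUT_F, IN_F, N_BASIS).
COEFFS = [[[0.1 * (j + 1) * (m - i) for m in range(N_BASIS)] for i in range(IN_F)]
          for j in range(OUT_F)]
xs = [-0.7, 0.05, 0.9]
```

With a batch, the feature vectors stack into a (batch, in_features × n_basis) matrix and the layer becomes one matrix multiplication, which is what makes this form fast on GPUs and easy to differentiate.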
The problem is in the sparsification, which is claimed to be critical to KAN’s interpretability.
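Why sparsification clashes with the reformulation: in the KAN paper, the L1 norm of an activation is its mean absolute output over the N_p input samples, |φ|_1 = (1/N_p)·Σ_s |φ(x_s)|; a layer’s L1 is the sum over its edges, and the entropy term is S(Φ) = -Σ (|φ|_1/|Φ|_1)·log(|φ|_1/|Φ|_1). The training loss adds λ·(μ_1·Σ_l |Φ_l|_1 + μ_2·Σ_l S(Φ_l)). Both need every per-edge output, i.e. the (batch, out_features, in_features) tensor that the matrix-multiplication form never builds. The sketch below computes these terms; the coefficient L1 at the end is only a hypothetical cheaper proxy, not the paper’s method.

```python
import math


def kan_l1_and_entropy(edge_outputs):
    """Per-layer sparsity terms as defined in the KAN paper.

    edge_outputs[s][j][i] = phi_{j,i}(x_i) for sample s, output j, input i, i.e. the
    (batch, out_features, in_features) tensor that the matmul reformulation avoids building.
    """
    n_s = len(edge_outputs)
    n_out, n_in = len(edge_outputs[0]), len(edge_outputs[0][0])
    # |phi_{j,i}|_1 = mean absolute output of that edge over the samples.
    edge_l1 = [[sum(abs(edge_outputs[s][j][i]) for s in range(n_s)) / n_s
                for i in range(n_in)] for j in range(n_out)]
    layer_l1 = sum(sum(row) for row in edge_l1)
    entropy = 0.0
    for row in edge_l1:
        for v in row:
            if v > 0:
                p = v / layer_l1
                entropy -= p * math.log(p)
    return layer_l1, entropy


def coeff_l1(coeffs):
    """Hypothetical cheaper proxy: L1 norm of the spline coefficients themselves,
    computable without any per-edge outputs."""
    return sum(abs(c) for per_out in coeffs for per_in in per_out for c in per_in)
```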
The activations were the bottleneck before; with KANs or anything newer, the bottleneck might shift. Still, current mainstream quantization (and pruning and other sparsification) might work for KANs too, or complement them, since KAN layers might only be used for some of the parameters.
My other worry: should we go for something more brain-inspired instead? (E.g. the brain doesn’t use backpropagation; it’s considered biologically implausible, though I’ve seen claims that it might actually be happening…) I.e., are current AI/ML and LLMs on the wrong path, as argued in Jeff Hawkins’s excellent book (which I’ve read; his earlier one I haven’t)? It’s not based on spiking neurons. There is already work on spiking neural networks in Julia, and some spiking neural network hardware is also available.
I haven’t read this one; $254 is pretty steep, though it’s cheaper used and $75 on Kindle:
https://arxiv.org/html/2408.14811v1
BitNet: LLM Quantization at its Extreme
Kolmogorov-Arnold Networks
Explaining Paper: [2410.23168] TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters
Julia has some native BFloat16 support by now (in 1.11; see “Adapt to upstream changes wrt. native support for BFloat16” by maleadt, Pull Request #51 in JuliaMath/BFloat16s.jl on GitHub), e.g. for the AMD EPYC 9554 CPU, but I’m not sure it’s good enough, since you want to use GPUs or TPUs anyway, and I see:
Also, if I understand correctly, CUDA.jl doesn’t fully support bfloat16, which is quite limiting.
I don’t think Flux uses mixed precision, so probably not. It is possible to configure CUDA.jl to use tensor cores more eagerly, at the expense of some precision, by starting Julia with fast math enabled or by calling
CUDA.math_mode!(CUDA.FAST_MATH)
which will e.g. use TF32 when doing an F32xF32 matmul. Further speed-ups are possible by setting CUDA.jl’s math precision to :BFloat16 or even :Float16. Ideally though, I guess Flux.jl would have an interface to use mixed-precision arithmetic.
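As said at the top, I don’t think Float32 is used much anymore (it’s 2x slower), except still by Julia?! Usually brain float (bfloat16) is used, which is NOT the same as Julia’s Float16; Float16 is not as good for training, since with only 5 exponent bits (vs. 8 in bfloat16 and Float32) its range is much smaller. bfloat16 is also outdated for inference, and I think by now for training too.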
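To see the difference concretely, a small standard-library Python sketch: bfloat16 keeps Float32’s 8 exponent bits (so its range) but only 7 mantissa bits, while IEEE Float16 (Julia’s Float16) has 5 exponent bits and 10 mantissa bits, so more precision but a maximum of 65504. Converting to bfloat16 by round-to-nearest-even on the Float32 bit pattern is my choice for the sketch; the function names are made up.

```python
import struct


def to_bfloat16(x):
    """Round x to bfloat16 and return it as a Python float.

    bfloat16 is the top 16 bits of an IEEE float32 (1 sign, 8 exponent, 7 mantissa bits),
    so we round the float32 bit pattern to nearest-even at bit 16. NaN inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)          # round to nearest, ties to even
    bits = ((bits >> 16) << 16) & 0xFFFFFFFF      # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_float16(x):
    """Round x to IEEE binary16, i.e. Julia's Float16 (1 sign, 5 exponent, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_float16(x):
    """struct raises OverflowError for values beyond binary16's range (max 65504)."""
    try:
        to_float16(x)
        return True
    except OverflowError:
        return False


tiny_step = 1 + 2**-10
print(to_bfloat16(1e5), fits_float16(1e5))            # 99840.0 False: bfloat16 keeps the range
print(to_bfloat16(tiny_step), to_float16(tiny_step))  # 1.0 1.0009765625: Float16 keeps the precision
```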
NOT needed to understand or implement:
KANQAS: Kolmogorov-Arnold Network for Quantum Architecture Search
https://arxiv.org/pdf/2406.17630
Quantum architecture search (QAS) is a promising direction for optimization and automated design of quantum circuits towards quantum advantage. Recent techniques in QAS focus on machine learning-based approaches from reinforcement learning, like deep Q-network. […]
Moreover, in noisy scenarios, KAN can achieve a better fidelity in approximating maximally entangled state than MLPs, where the performance of the MLP significantly depends on the choice of activation function. In tackling quantum chemistry problems, we enhance the recently proposed QAS algorithm by integrating Curriculum Reinforcement Learning (CRL) with a KAN structure instead of the traditional MLP. This modification allows us to design a parameterized quantum circuit that contains fewer 2-qubit gates and has a shallower depth, thereby improving the efficiency of finding the ground state of a chemical Hamiltonian. Further investigation reveals that KAN requires a significantly smaller number of learnable parameters compared to MLPs; however, the average time of executing each episode for KAN is higher.