What's Grouped-Query Attention (GQA)? A paper from Google Research
During autoregressive decoding with Transformer models, a major bottleneck is memory bandwidth: at every decoding step the model must load the decoder weights as well as all cached attention keys and values.
One way to reduce the overhead of loading keys and values is multi-query attention (MQA), which uses multiple query heads but only a single key/value head.
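To get a feel for why this matters, here is a minimal back-of-the-envelope sketch (not from the paper; the head counts, head dimension, and fp16 byte size below are illustrative assumptions) comparing the per-token, per-layer KV-cache size that must be read back at each decoding step:

```python
# Rough KV-cache size per token per layer:
# 2 (K and V) * num_kv_heads * head_dim * bytes_per_value.
# All dimensions below are illustrative assumptions, not figures from the GQA paper.

def kv_cache_bytes_per_token(num_kv_heads, head_dim=128, bytes_per_value=2):
    """Bytes of K/V cache stored (and re-read) per token per layer."""
    return 2 * num_kv_heads * head_dim * bytes_per_value

# Hypothetical model with 32 query heads:
print(kv_cache_bytes_per_token(num_kv_heads=32))  # MHA:        16384 bytes
print(kv_cache_bytes_per_token(num_kv_heads=1))   # MQA:          512 bytes
print(kv_cache_bytes_per_token(num_kv_heads=8))   # GQA (G=8):   4096 bytes
```

The cache (and therefore the bandwidth spent reloading it every step) scales with the number of key/value heads, which is exactly what MQA and GQA shrink.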
However, as the researchers note in the GQA paper, MQA has certain drawbacks: it can lead to quality degradation and training instability. Consequently, training distinct models optimised separately for quality and for inference is not a practical solution, as stated in the paper.
After all, the primary goal of MQA is to accelerate inference, so modifying the entire model architecture and training recipe just for that purpose is impractical.
The paper discusses two key concepts:
1) Uptraining of existing MHA checkpoints: The researchers propose a method to convert a pre-trained checkpoint that uses multi-head attention (MHA) into one that uses multi-query attention (MQA). The multiple key and value heads of the original MHA checkpoint are combined into a single key head and a single value head by mean pooling. This was found to work better than randomly initializing the key and value heads or selecting one head from the MHA checkpoint (a small sketch of this conversion appears after this list).
The converted checkpoint is then further pre-trained ("uptrained"), but only for a small fraction (α) of the original number of training steps, while following the same pre-training recipe used for the initial model (same data, learning rate schedule, optimizer, and other hyperparameters).
2) GQA (Grouped-Query Attention): The second method, GQA, strikes a balance between MHA and MQA. GQA partitions the query heads into G groups, with each group sharing a single key head and a single value head.
When converting a multi-head checkpoint to a GQA checkpoint, the key and value heads for each group are constructed by mean-pooling the original heads within that group. This provides a trade-off between the speed of MQA and the quality of MHA (see the attention sketch after this list).
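As a rough illustration of the checkpoint conversion described above, here is a minimal sketch in PyTorch (with hypothetical tensor shapes and function names, not the paper's actual code) of mean-pooling MHA key/value heads into groups. Setting num_groups=1 gives the MQA case, and num_groups equal to the number of heads recovers MHA:

```python
import torch

def mean_pool_kv_heads(kv: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Mean-pool MHA key (or value) heads into `num_groups` grouped heads.

    kv: tensor of shape (num_heads, head_dim, d_model) -- a hypothetical layout
        of the per-head K or V projection weights from an MHA checkpoint.
    Returns a tensor of shape (num_groups, head_dim, d_model).
    """
    num_heads = kv.shape[0]
    assert num_heads % num_groups == 0, "heads must divide evenly into groups"
    group_size = num_heads // num_groups
    # Split the heads into groups and average within each group (mean pooling).
    return kv.reshape(num_groups, group_size, *kv.shape[1:]).mean(dim=1)

# Example: 32 MHA key heads -> 8 GQA key heads, or 1 MQA key head.
mha_k = torch.randn(32, 128, 4096)
gqa_k = mean_pool_kv_heads(mha_k, num_groups=8)  # shape (8, 128, 4096)
mqa_k = mean_pool_kv_heads(mha_k, num_groups=1)  # shape (1, 128, 4096)
```

After this pooling step, the converted checkpoint still needs the short uptraining phase described above so the model can adapt to the new attention structure.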
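And here is a similarly hedged sketch of how the grouped attention itself can be computed, with each group of query heads attending over one shared key/value head (again, shapes and names are illustrative assumptions, not the paper's implementation):

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """q: (batch, num_q_heads, seq_len, head_dim)
       k, v: (batch, num_kv_heads, seq_len, head_dim),
       with num_q_heads divisible by num_kv_heads (num_kv_heads = G groups)."""
    num_q_heads, num_kv_heads = q.shape[1], k.shape[1]
    group_size = num_q_heads // num_kv_heads
    # Each shared K/V head serves `group_size` query heads: expand along the head axis.
    # (A naive expansion for clarity; real implementations avoid materializing the copies.)
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # scaled dot-product attention
    return F.softmax(scores, dim=-1) @ v

# Example: 32 query heads grouped over 8 shared K/V heads (G = 8).
q = torch.randn(1, 32, 16, 128)
k = torch.randn(1, 8, 16, 128)
v = torch.randn(1, 8, 16, 128)
out = grouped_query_attention(q, k, v)  # shape (1, 32, 16, 128)
```

Only the 8 shared key/value heads need to be cached and reloaded during decoding, which is where the bandwidth savings over full MHA come from.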
For a visual comparison, see the image below, taken from the official GQA paper.
For more interesting information like this, don't hesitate to follow me! :)