๐–๐ก๐š๐ญโ€™๐ฌ ๐†๐ซ๐จ๐ฎ๐ฉ๐ž๐-๐๐ฎ๐ž๐ซ๐ฒ ๐š๐ญ๐ญ๐ž๐ง๐ญ๐ข๐จ๐ง(๐†๐๐€) ? ๐š ๐ฉ๐š๐ฉ๐ž๐ซ ๐Ÿ๐ซ๐จ๐ฆ ๐†๐จ๐จ๐ ๐ฅ๐ž ๐‘๐ž๐ฌ๐ž๐š๐ซ๐œ๐ก

Ali Issa
Sep 26, 2023


During autoregressive decoding with Transformer models, the main bottleneck is memory bandwidth: at every decoding step, the decoder weights and all cached attention keys and values have to be loaded from memory.
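To get a feel for the scale involved, here is a back-of-the-envelope sketch. The model dimensions below are illustrative assumptions, not numbers from the paper.

```python
# Rough size of the per-sequence KV cache that must be read at every decoding step.
# All model dimensions are illustrative assumptions, not values from the paper.
n_layers   = 32     # decoder layers
n_kv_heads = 32     # key/value heads (equal to the query heads in standard MHA)
head_dim   = 128    # dimension per head
seq_len    = 2048   # tokens currently in context
bytes_per  = 2      # fp16/bf16 storage

# Keys *and* values are cached, hence the factor of 2.
kv_cache_bytes = 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per
print(f"KV cache per sequence: {kv_cache_bytes / 1e9:.2f} GB")  # ~1.07 GB
```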

One way to reduce the overhead of loading keys and values is **multi-query attention (MQA)**, which keeps multiple query heads but shares a single key/value head among them.
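As a minimal sketch of what “multiple query heads, one key/value head” means in terms of tensor shapes (the dimensions are arbitrary assumptions, and masking and output projections are ignored):

```python
import torch

# Minimal sketch of the MHA vs. MQA shape difference for one attention layer.
batch, seq_len, n_heads, head_dim = 1, 16, 8, 64

q = torch.randn(batch, n_heads, seq_len, head_dim)   # queries: one per head in both schemes

# Multi-head attention: every query head has its own key/value head.
k_mha = torch.randn(batch, n_heads, seq_len, head_dim)
v_mha = torch.randn(batch, n_heads, seq_len, head_dim)

# Multi-query attention: a single key/value head shared by all query heads.
k_mqa = torch.randn(batch, 1, seq_len, head_dim)
v_mqa = torch.randn(batch, 1, seq_len, head_dim)

# Broadcasting over the head dimension lets all query heads attend to the same K/V.
scores = (q @ k_mqa.transpose(-2, -1)) / head_dim ** 0.5   # (batch, n_heads, seq, seq)
out = torch.softmax(scores, dim=-1) @ v_mqa                # (batch, n_heads, seq, head_dim)
```

The cache that has to be kept and re-read during decoding shrinks by a factor of `n_heads`, which is exactly where the bandwidth saving comes from.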

However, as the researchers note in the **GQA paper**, MQA has certain drawbacks: it can lead to a decline in quality and introduce training instability. Consequently, training distinct models optimised separately for quality and for inference is not a practical solution, as stated in the paper.

This is because the primary goal of MQA is simply to accelerate inference, and modifying the entire model architecture and training procedure for that purpose alone is impractical.

The paper discusses two key concepts:

1) ๐”๐ฉ๐ญ๐ซ๐š๐ข๐ง๐ข๐ง๐  ๐จ๐Ÿ ๐ž๐ฑ๐ข๐ฌ๐ญ๐ข๐ง๐  ๐Œ๐‡๐€ ๐‚๐ก๐ž๐œ๐ค๐ฉ๐จ๐ข๐ง๐ญ๐ฌ: Researchers propose a method to transition from a pre-trained model checkpoint using multi-head attention (MHA) to one using multi-query attention (MQA). In this process, the original multiple heads for keys and values in MHA are combined into a single head for both through mean pooling. This approach was found to be superior to randomly initializing key and value heads or selecting one head from the MHA checkpoint.

The model is then further pre-trained from this MQA checkpoint, but only for a small fraction (α) of its original training steps, while following the same pre-training recipe used for the initial model (the same data, learning rate schedule, optimization algorithm, and other parameters). A minimal sketch of the conversion step follows below.
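Here is a small sketch of the mean-pooling conversion for the key projections. The shapes are assumptions for illustration; real checkpoints typically store these weights as a single fused matrix.

```python
import torch

# Collapse the H key-projection heads of an MHA checkpoint into a single head
# by mean pooling (the same is done for the value projections).
n_heads, d_model, head_dim = 8, 512, 64

# Per-head key projection weights from the pretrained MHA checkpoint.
wk_mha = torch.randn(n_heads, d_model, head_dim)

# Mean-pool across the head axis -> one shared key head for MQA.
wk_mqa = wk_mha.mean(dim=0)          # (d_model, head_dim)

# The resulting MQA checkpoint is then uptrained for a small fraction (alpha)
# of the original pre-training steps.
```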

2) ๐†๐๐€ (๐†๐ซ๐จ๐ฎ๐ฉ๐ž๐-๐๐ฎ๐ž๐ซ๐ฒ ๐€๐ญ๐ญ๐ž๐ง๐ญ๐ข๐จ๐ง): Another method introduced is GQA, which seeks to strike a balance between MHA and MQA. GQA partitions query heads into G groups, with each group sharing a single key head and value head.

During the transition from a multi-head checkpoint to a GQA checkpoint, the key and value heads for each group are created by mean-pooling the original heads within that group, giving a trade-off between the speed of MQA and the quality of MHA (see the sketch below).
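The group-wise version of the same pooling can be sketched as follows (again with assumed dimensions):

```python
import torch

# GQA conversion sketch: average the key heads within each group so that
# G groups of query heads share G key/value heads.
n_heads, n_groups, d_model, head_dim = 8, 2, 512, 64
heads_per_group = n_heads // n_groups

wk_mha = torch.randn(n_heads, d_model, head_dim)        # per-head key projections (MHA)

# Reshape into (groups, heads_per_group, ...) and mean-pool inside each group.
wk_gqa = wk_mha.view(n_groups, heads_per_group, d_model, head_dim).mean(dim=1)
print(wk_gqa.shape)   # torch.Size([2, 512, 64]) -> one key head per group

# G = 1 recovers MQA and G = n_heads recovers MHA; intermediate values of G
# interpolate between the two, as described in the paper.
```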

For better visualization, feel free to check the image below taken from the official GQA paper.

Source: 2305.13245v1.pdf (arxiv.org)

For more interesting information like this, donโ€™t hesitate to follow me! :)
