Multi-Query Attention

Ali Issa
2 min read · Sep 21, 2023



➡ MQA addresses a common challenge faced by models with large context sizes during inference. Typically, increasing the context size leads to higher computational costs.

➡ During inference, text is usually produced by incremental generation: tokens are fed into the network one at a time, and K (keys) and V (values) are computed across the tokens observed so far. However, this method becomes expensive with lengthy inputs.
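As a rough illustration of why this gets expensive, here is a toy single-head decoding step in PyTorch (the dimensions and random weights are made up for this sketch). Notice that K and V are recomputed over the entire prefix at every step:

```python
import torch

# Toy single-head projection weights (hypothetical sizes, for illustration only).
d_model = 64
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

def naive_step(tokens_so_far):  # tokens_so_far: [t, d_model]
    """One decoding step without caching: the newest token attends to the prefix."""
    q = tokens_so_far[-1:] @ W_q   # query for the newest token only
    k = tokens_so_far @ W_k        # keys recomputed for ALL tokens so far
    v = tokens_so_far @ W_v        # values recomputed for ALL tokens so far
    scores = (q @ k.T) / d_model ** 0.5
    return torch.softmax(scores, dim=-1) @ v
```

The k and v projections grow linearly with the prefix, so the total work across a full generation grows quadratically with context length.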

➡ To improve latency and reduce computational overhead, various solutions have been introduced. Some are inference-time techniques that leave the model's architecture untouched, such as K-V caching (maintaining the computed keys and values across decoding steps, as sketched below) or batching multiple sequences together.
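A minimal sketch of the K-V caching idea under the same toy setup (again, hypothetical shapes, not any particular library's API): each step projects only the newest token and appends the result to a cache, so past keys and values are never recomputed:

```python
import torch

d_model = 64
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

k_cache, v_cache = [], []  # persists across decoding steps

def cached_kv(new_token):  # new_token: [d_model]
    """Project K and V for the newest token only; reuse everything cached."""
    k_cache.append(new_token @ W_k)
    v_cache.append(new_token @ W_v)
    # Past entries are reused as-is; only one new row was computed this step.
    return torch.stack(k_cache), torch.stack(v_cache)  # each [t, d_model]
```

The trade-off is memory: the cache itself must live in GPU memory for the whole generation, which is exactly the cost MQA attacks.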

➡ Architecture-focused solutions, on the other hand, date back to 2019, when MQA was introduced. The technique has been adopted by several LLMs, such as StarCoder (a model trained on over 80 programming languages), Falcon, and, in its grouped-query generalization, Llama 2.

➡ MQA brings significant improvements in throughput, allowing the system to process more data within the same timeframe, while reducing latency for faster response times. Its primary objective is to reduce computation and memory traffic during inference.

โžก ๐Œ๐๐€ ๐•๐’ ๐Œ๐‡๐€
In the traditional ๐Œ๐ฎ๐ฅ๐ญ๐ข-๐‡๐ž๐š๐ ๐€๐ญ๐ญ๐ž๐ง๐ญ๐ข๐จ๐ง (MHA), each Q (query), K (key), and V (value) is divided into multiple vectors equal to the number of heads. Each head then performs the same procedure independently.
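A minimal PyTorch sketch of this head splitting (hypothetical dimensions; a real implementation would also add causal masking and an output projection):

```python
import torch

batch, seq_len, d_model, n_heads = 1, 10, 512, 8
head_dim = d_model // n_heads  # 64

x = torch.randn(batch, seq_len, d_model)
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

def split_heads(t):
    """Split the model dimension into n_heads vectors of size head_dim."""
    return t.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)

# Q, K, and V are ALL split per head: [batch, n_heads, seq_len, head_dim].
Q, K, V = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)

# Every head runs the same attention computation independently.
scores = Q @ K.transpose(-2, -1) / head_dim ** 0.5  # [batch, n_heads, seq, seq]
out = torch.softmax(scores, dim=-1) @ V             # [batch, n_heads, seq, head_dim]
```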

➡ In MQA, however, keys and values no longer have multiple heads. Only the queries keep multiple heads, and every query head attends over a single shared K and a single shared V, eliminating the need to split keys and values into per-head vectors.
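The same sketch, modified for MQA: only the query projection keeps all heads, while K and V are projected once into a single head that broadcasts across every query head:

```python
import torch

batch, seq_len, d_model, n_heads = 1, 10, 512, 8
head_dim = d_model // n_heads

x = torch.randn(batch, seq_len, d_model)
W_q = torch.randn(d_model, d_model)   # queries still get all n_heads heads
W_k = torch.randn(d_model, head_dim)  # ONE shared key head
W_v = torch.randn(d_model, head_dim)  # ONE shared value head

Q = (x @ W_q).view(batch, seq_len, n_heads, head_dim).transpose(1, 2)
K = (x @ W_k).unsqueeze(1)  # [batch, 1, seq_len, head_dim]
V = (x @ W_v).unsqueeze(1)  # [batch, 1, seq_len, head_dim]

# The single K/V head broadcasts across all n_heads query heads.
scores = Q @ K.transpose(-2, -1) / head_dim ** 0.5  # [batch, n_heads, seq, seq]
out = torch.softmax(scores, dim=-1) @ V             # [batch, n_heads, seq, head_dim]
```

Note that the only change from the MHA sketch is the shape of W_k and W_v; the attention math itself is untouched.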

The key distinction between these techniques is that MQA reads and writes far less data to and from memory.

➡ This has noteworthy implications for performance, particularly in terms of arithmetic intensity, the ratio of arithmetic operations performed to bytes moved to and from memory; higher intensity means each value loaded from memory is reused for more computation.

➡ This reuse is exactly what MQA exploits: the single K and V head is loaded once and shared by every query head, whereas in MHA each head loads its own keys and values. Additionally, MQA reduces memory usage by shrinking the amount of KV-cache data stored between iterations of the inference process.
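To make the saving concrete, here is a back-of-the-envelope comparison of KV-cache sizes. The dimensions (32 layers, 32 heads, head size 128, fp16, a 4096-token context) are hypothetical, chosen to resemble a 7B-class model:

```python
# Back-of-the-envelope KV-cache sizes (hypothetical model dimensions).
n_layers, n_heads, head_dim = 32, 32, 128
seq_len, bytes_per_value = 4096, 2  # fp16

# MHA caches K and V for every head; MQA caches them for one shared head.
mha_cache = 2 * n_layers * n_heads * seq_len * head_dim * bytes_per_value
mqa_cache = 2 * n_layers * 1 * seq_len * head_dim * bytes_per_value

print(f"MHA KV cache: {mha_cache / 2**30:.2f} GiB")   # 2.00 GiB
print(f"MQA KV cache: {mqa_cache / 2**30:.4f} GiB")   # 0.0625 GiB (32x smaller)
```

The cache shrinks by exactly the number of heads, which is what frees memory for larger batches and longer contexts.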

These insights are based on a comprehensive Medium article; the link appears at the end of this post, along with the official MQA paper reference.

Furthermore, if anyone wishes to contribute additional information or make modifications to what I have mentioned, please don't hesitate to do so in the comments section. Your input is welcome and appreciated.

Medium article

Official paper: https://arxiv.org/pdf/1911.02150v1.pdf
