Authors:
(1) Ben Athiwaratkun, AWS AI Labs;
(2) Sujan Kumar Gonugondla, AWS AI Labs;
(3) Sanjay Krishna Gouda, AWS AI Labs;
(4) Haifeng Qian, AWS AI Labs;
(5) Hantian Ding, AWS AI Labs;
(6) Qing Sun, AWS AI Labs;
(7) Jun Wang, AWS AI Labs;
(8) Jiacheng Guo, AWS AI Labs;
(9) Liangfu Chen, AWS AI Labs;
(10) Parminder Bhatia, GE HealthCare (work done at AWS);
(11) Ramesh Nallapati, Amazon AGI (work done at AWS);
(12) Sudipta Sengupta, AWS AI Labs;
(13) Bing Xiang, Goldman Sachs (work done at AWS).
3.1. Notation and 3.2. Language Model Inference
3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention
4. Context-Aware Bifurcated Attention and 4.1. Motivation
4.2. Formulation and 4.3. Memory IO Complexity
5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention
5.2. Latencies of Capabilities-Equivalent Models
D. Multi-Group Attention Family
E. Context-Aware Bifurcated Attention
F. Applications: Additional Results
G. Compatibility with Speculative Decoding and Fast Decoding techniques
Multi-query attention, proposed by Shazeer (2019), is an attention mechanism for transformer models that uses a single head for the key and value tensors, compared to h heads in traditional multi-head attention (Vaswani et al., 2017). This technique reduces the KV memory IO by a factor of h, which leads to higher inference efficiency during incremental decoding. In effect, the single-head key or value tensor is shared and attended to by all heads of the query, hence the name multi-query. This amounts to compressing the representation power of the key and value tensors; as we will see in the scaling-laws study (Section 5.1), it results in reduced expressiveness in terms of model parameter efficiency. This reduced expressiveness can be compensated for by scaling the model larger than its multi-head counterpart to match representation power.
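To make the h-fold reduction in KV memory IO concrete, the following back-of-the-envelope sketch compares KV-cache sizes under multi-head and multi-query attention. The batch size, context length, head count, head dimension, and fp16 storage are illustrative assumptions, not values from the paper.

```python
def kv_cache_bytes(batch, context_len, num_kv_heads, head_dim, bytes_per_elem=2):
    # Key and value caches each store a tensor of shape
    # [batch, num_kv_heads, context_len, head_dim].
    return 2 * batch * num_kv_heads * context_len * head_dim * bytes_per_elem

h, d = 32, 128                        # assumed number of query heads and head dimension
mha = kv_cache_bytes(1, 4096, h, d)   # multi-head: h KV heads  -> 64 MiB in fp16
mqa = kv_cache_bytes(1, 4096, 1, d)   # multi-query: one shared KV head -> 2 MiB in fp16
print(mha / mqa)                      # 32.0, i.e. an h-fold reduction in KV memory IO
```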
We can also extrapolate these insights to a generalized multi-query attention mechanism (Ainslie et al., 2023), which provides a framework for understanding both multi-query and multi-head attention, and everything in between. Here, the degree of KV compression is dictated by the number of attention groups g; we alternatively refer to generalized multi-query attention as multi-group attention. Each attention group can be interpreted as broadcast attention between a single key or value head and multiple query heads.
In this paradigm, multi-query attention is a special case where the number of groups g = 1; that is, there is exactly one such group. Conversely, multi-head attention is another special case where the number of attention groups matches the number of heads (g = h), in which case each head in the key or value tensor attends to one head in the query. More generally, the number of groups g can lie anywhere between 1 and h, indicating various degrees of compression. For practical purposes, it is most convenient when g divides h. The attention mechanism in this setting can be expressed in terms of Einstein summation as:
logits = ⟨q, K⟩ : einsum(bgpnk, bgmk) → bgpnm (1)
o = ⟨w, V⟩ : einsum(bgpnm, bgmv) → bgpnv (2)
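As a concrete illustration of Eqs. (1) and (2), here is a minimal PyTorch sketch of multi-group attention. The shapes follow the einsum indices above, under the assumption that b is the batch size, g the number of attention groups, p = h/g the query heads per group, n the number of query tokens, m the key/value length, and k, v the head dimensions; treating w as the scaled softmax of the logits is the standard choice and an assumption here, since this excerpt does not define w explicitly.

```python
import torch
import torch.nn.functional as F

def multi_group_attention(q, K, V):
    # q: [b, g, p, n, k], K: [b, g, m, k], V: [b, g, m, v]
    logits = torch.einsum("bgpnk,bgmk->bgpnm", q, K)      # Eq. (1)
    w = F.softmax(logits / q.shape[-1] ** 0.5, dim=-1)    # attention weights over m
    return torch.einsum("bgpnm,bgmv->bgpnv", w, V)        # Eq. (2)

b, h, n, m, d = 2, 8, 1, 128, 64
for g in (1, 4, h):                    # g = 1: multi-query; g = h: multi-head
    q = torch.randn(b, g, h // g, n, d)
    K, V = torch.randn(b, g, m, d), torch.randn(b, g, m, d)
    print(g, multi_group_attention(q, K, V).shape)        # [b, g, h//g, n, d]
```

Setting g = 1 recovers multi-query attention, g = h recovers multi-head attention, and intermediate values of g give the multi-group variants discussed above.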
This generalized multi-group attention mechanism thus provides a unified perspective on the design space of attention architectures. By adjusting the number of attention groups g, one can flexibly trade expressiveness against KV memory IO, potentially yielding new regimes of performance for transformer models. In Section 5.1, we examine this capability versus latency trade-off.
This paper is available on arXiv under CC BY 4.0 DEED license.