Authors:
(1) Ben Athiwaratkun, AWS AI Labs;
(2) Sujan Kumar Gonugondla, AWS AI Labs;
(3) Sanjay Krishna Gouda, AWS AI Labs;
(4) Haifeng Qian, AWS AI Labs;
(5) Hantian Ding, AWS AI Labs;
(6) Qing Sun, AWS AI Labs;
(7) Jun Wang, AWS AI Labs;
(8) Jiacheng Guo, AWS AI Labs;
(9) Liangfu Chen, AWS AI Labs;
(10) Parminder Bhatia, GE HealthCare (work done at AWS);
(11) Ramesh Nallapati, Amazon AGI (work done at AWS);
(12) Sudipta Sengupta, AWS AI Labs;
(13) Bing Xiang, Goldman Sachs (work done at AWS).
3.1. Notation and 3.2. Language Model Inference
3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention
4. Context-Aware Bifurcated Attention and 4.1. Motivation
4.2. Formulation and 4.3. Memory IO Complexity
5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention
5.2. Latencies of Capabilities-Equivalent Models
D. Multi-Group Attention Family
E. Context-Aware Bifurcated Attention
F. Applications: Additional Results
G. Compatibility with Speculative Decoding and Fast Decoding techniques
We show in Table 4 that the memory IO cost for ⟨q, K⟩ is dominated by the loading of K, which costs bmhk in the multi-head case where g = h. This cost is particularly high due to the coupling of the batch size b, the context length m, and the entire hidden dimension d. Compared to the number of computations, which has complexity bmd, this attention module requires roughly one memory IO per tensor operation (memory-IO bound). In contrast, other operations such as the feedforward layers have a much lower ratio of memory IO per compute (compute bound). These attention computations can therefore be the main bottleneck for incremental decoding, and our paper aims to tackle this problem.
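To make this ratio concrete, the Python sketch below tallies the memory IO and the multiply count of the ⟨q, K⟩ step for a single decoding token, comparing the multi-head (g = h) and multi-query (g = 1) cases. The helper `qk_step_costs` and the dimension values are illustrative assumptions for this note, not part of the paper's code.

```python
# A minimal sketch (assumed, not the paper's code) of the memory-IO vs. compute
# accounting for the incremental-decoding q·K step discussed above.

def qk_step_costs(b, m, h, k, g):
    """Return (memory_io, flops) for the q·K step of one decoding token.

    b: batch size, m: context length, h: number of query heads,
    k: per-head dimension, g: number of key/value groups
    (g = h is multi-head, g = 1 is multi-query).
    """
    d = h * k
    load_q = b * d            # load the single-token query, shape (b, h, k)
    load_K = b * m * g * k    # load the cached keys, shape (b, g, m, k)
    memory_io = load_q + load_K
    flops = b * m * d         # h * m * k dot-product multiplies per batch element
    return memory_io, flops

# Example with assumed sizes: b = 32, m = 2048, h = 32, k = 128.
for g, name in [(32, "multi-head (g = h)"), (1, "multi-query (g = 1)")]:
    io, fl = qk_step_costs(b=32, m=2048, h=32, k=128, g=g)
    print(f"{name}: memory IO ≈ {io:.3e}, compute ≈ {fl:.3e}, IO/compute ≈ {io / fl:.2f}")
```

With these assumed sizes, the multi-head case loads roughly one element per multiply (memory-IO bound), while the multi-query case loads far fewer, which is the contrast the paragraph above describes.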
The scaling laws by Kaplan et al. (2020) show that the model-related FLOPs during the forward pass are 2N, where N is the number of parameters (excluding the embeddings). We show that this holds for a general multi-group model as well. The only difference between the multi-group and the multi-head case lies in the projections PK and PV, which are of size dgk instead of dhk. Since these are linear layers, the forward-pass FLOPs for any input remain proportional to the projection size. Therefore, for any multi-group attention, including multi-head, the forward FLOPs are 2N, where N is the respective number of parameters.
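As a sanity check of this argument, the sketch below counts the non-embedding parameters and the per-token forward FLOPs of one transformer layer and confirms that their ratio is 2 for any number of groups g. The dimensions are assumed for illustration; biases, layer norms, and the attention dot products are ignored, as in the 2N accounting.

```python
# A minimal sketch (assumed, not the paper's code): for any multi-group layer,
# forward FLOPs per token = 2 * (non-embedding parameter count).

def layer_params(d, h, k, g, d_ff):
    """Non-embedding parameters of one layer with g key/value groups."""
    p_q = d * h * k          # query projection
    p_kv = 2 * d * g * k     # key and value projections, size dgk each
    p_o = h * k * d          # attention output projection
    p_mlp = 2 * d * d_ff     # two feedforward matrices
    return p_q + p_kv + p_o + p_mlp

def layer_forward_flops(d, h, k, g, d_ff):
    """Per-token forward FLOPs: each (in, out) linear does 2*in*out FLOPs."""
    f_q = 2 * d * (h * k)
    f_kv = 2 * (2 * d * g * k)
    f_o = 2 * (h * k) * d
    f_mlp = 2 * (2 * d * d_ff)
    return f_q + f_kv + f_o + f_mlp

# Assumed sizes: d = h*k = 4096, d_ff = 4*d. The ratio is 2 whether
# g = h (multi-head), g = 1 (multi-query), or anything in between.
d, h, k = 4096, 32, 128
for g in (32, 8, 1):
    n = layer_params(d, h, k, g, d_ff=4 * d)
    f = layer_forward_flops(d, h, k, g, d_ff=4 * d)
    print(f"g={g:2d}: N per layer = {n:,}, forward FLOPs per token = {f:,}, ratio = {f / n}")
```

Only the PK and PV terms depend on g, and both their parameter count and their FLOPs scale with dgk, which is why the 2N relation carries over unchanged.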
This section analyzes the latency change when we switch from an MH model to an MG model that is F times larger.
D.3.1. CONTEXT ENCODING
D.3.2. INCREMENTAL DECODING
This paper is available on arXiv under CC BY 4.0 DEED license.