
Understanding Multi-Group Attention in AI Models

by Batching, February 26th, 2025

Too Long; Didn't Read

Multi-group attention optimizes AI model efficiency by reducing memory IO costs. FLOPs remain proportional to parameters, ensuring scalability across architectures.


Authors:

(1) Ben Athiwaratkun, AWS AI Labs;

(2) Sujan Kumar Gonugondla, AWS AI Labs;

(3) Sanjay Krishna Gouda, AWS AI Labs;

(4) Haifeng Qian, AWS AI Labs;

(5) Hantian Ding, AWS AI Labs;

(6) Qing Sun, AWS AI Labs;

(7) Jun Wang, AWS AI Labs;

(8) Jiacheng Guo, AWS AI Labs;

(9) Liangfu Chen, AWS AI Labs;

(10) Parminder Bhatia, GE HealthCare (work done at AWS);

(11) Ramesh Nallapati, Amazon AGI (work done at AWS);

(12) Sudipta Sengupta, AWS AI Labs;

(13) Bing Xiang, Goldman Sachs (work done at AWS).

Abstract and 1 Introduction

2. Related Work

3. Background

3.1. Notation and 3.2. Language Model Inference

3.3. Multi-Query, Multi-Head and the Generalized Multi-Query Attention

4. Context-Aware Bifurcated Attention and 4.1. Motivation

4.2. Formulation and 4.3. Memory IO Complexity

5. Experiments

5.1. Comparing Capabilities of Multi-Head, Multi-Query, and Multi-Group Attention

5.2. Latencies of Capabilities-Equivalent Models

5.3. Applications

6. Conclusion and References


A. FAQs

B. Related Work

C. Setup

D. Multi-Group Attention Family

E. Context-Aware Bifurcated Attention

F. Applications: Additional Results

G. Compatibility with Speculative Decoding and Fast Decoding techniques

D. Multi-Group Attention Family

D.1. Detailed Analysis on Memory Access

We show in Table 4 that the memory IO cost for ⟨q, K⟩ is dominated by the loading of K, which costs bmhk in the multi-head case where g = h. This cost is particularly high due to the coupling of batch size b, context length m, and the entire hidden dimension d. Compared to the number of computations, which has complexity bmd, this attention module requires roughly one memory IO per tensor operation (memory-IO bound). In contrast, other operations such as the feedforward layers have a much lower ratio of memory IO per compute (compute bound). These attention computations can therefore be the main bottleneck during incremental decoding, and our paper aims to tackle this problem.
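To make the contrast concrete, here is a back-of-the-envelope sketch that tallies the memory IO and FLOPs of the ⟨q, K⟩ step for a single incremental-decoding step. The function name and the example dimensions are illustrative assumptions, not figures taken from Table 4.

```python
# Back-of-the-envelope sketch (not the paper's code): memory IO versus compute
# for the <q, K> step of one incremental-decoding step, per layer.

def qk_cost(b, m, h, g, k):
    """b: batch size, m: context length, h: query heads,
    g: key/value groups (g = h for multi-head, g = 1 for multi-query),
    k: head dimension."""
    mem_io = b * m * g * k       # elements of K loaded from memory
    flops = 2 * b * m * h * k    # multiply-adds for q.K^T across all heads
    return mem_io, flops

# Illustrative numbers (assumed):
for g, name in [(32, "multi-head"), (8, "multi-group"), (1, "multi-query")]:
    io, fl = qk_cost(b=16, m=2048, h=32, g=g, k=128)
    print(f"{name:11s}  K memory IO: {io:>12,} elems   FLOPs: {fl:,}   FLOPs/IO: {fl / io:.1f}")
```

With g = h the FLOPs-to-IO ratio is only about 2, which is why the multi-head ⟨q, K⟩ step is memory-IO bound; shrinking g raises the ratio and shifts the step toward the compute-bound regime.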


D.2. Model FLOPs

The scaling laws of Kaplan et al. (2020) show that the model-related FLOPs during the forward pass are 2N, where N is the number of parameters (excluding the embeddings). We show that this holds for a general multi-group model as well. The only difference between the multi-group and the multi-head case is the projections PK and PV, which are of size dgk instead of dhk. Since these are linear layers, the forward-pass FLOPs for any input remain proportional to the projection size. Therefore, for any multi-group attention, including multi-head, the forward FLOPs are 2N, where N is the respective number of parameters.
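As a sanity check, the short sketch below tallies the per-token forward FLOPs of one transformer block's linear layers independently of its parameter count, under an assumed standard block layout (PQ, PK, PV, output projection, two feed-forward matrices; biases and the context-dependent attention products are excluded, as in the 2N approximation). The ratio comes out to 2 for any number of groups g.

```python
# Minimal sketch (assumed standard block layout, not the paper's code):
# forward FLOPs per token are ~2x the non-embedding parameter count
# regardless of the number of key/value groups g.

def block_params(d, h, g, k, d_ff):
    # Linear-layer weights of one block; biases and layer norms ignored.
    return (d * h * k           # P_Q
            + 2 * d * g * k     # P_K and P_V: size dgk rather than dhk
            + h * k * d         # attention output projection
            + 2 * d * d_ff)     # two feed-forward matrices

def block_forward_flops(d, h, g, k, d_ff):
    # Each d_in x d_out linear layer costs 2 * d_in * d_out FLOPs per token
    # (one multiply and one add per weight); the context-dependent <q, K> and
    # attention-weighted-value products are excluded, as in the 2N estimate.
    return (2 * d * h * k
            + 2 * 2 * d * g * k
            + 2 * h * k * d
            + 2 * 2 * d * d_ff)

d, h, k, d_ff = 4096, 32, 128, 16384
for g in (32, 8, 1):  # multi-head, multi-group, multi-query
    n = block_params(d, h, g, k, d_ff)
    f = block_forward_flops(d, h, g, k, d_ff)
    print(f"g={g:2d}  params={n:,}  forward FLOPs/token={f:,}  ratio={f / n:.1f}")
```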

D.3. Comparing Capabilities-Equivalent Models

This section analyzes how latency changes when we switch from an MH model to an MG model that is F times its size.
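The sketch below illustrates the two regimes with a simplified cost model: context encoding is treated as compute-bound (latency roughly proportional to forward FLOPs, hence to the F times larger parameter count), while a single incremental-decoding step is treated as memory-IO bound (loading the weights plus a KV cache that scales with g rather than h). All hardware numbers, model sizes, and function names are assumptions for illustration only, not measurements from the paper's experiments.

```python
# Illustrative sketch (assumed hardware numbers and simplified cost model):
# compare an MH model against a capability-equivalent MG model F times larger.

BYTES = 2              # fp16
FLOPS_PER_S = 3e14     # assumed accelerator peak compute
BW_BYTES_PER_S = 2e12  # assumed memory bandwidth

def context_encoding_latency(n_params, b, m):
    # Compute-bound: forward FLOPs ~ 2 * N per token, b * m tokens in total.
    return 2 * n_params * b * m / FLOPS_PER_S

def incremental_decoding_latency(n_params, n_layers, b, m, g, k):
    # Memory-IO bound: load all weights once plus the KV cache per step.
    weight_io = n_params * BYTES
    kv_io = n_layers * 2 * b * m * g * k * BYTES   # K and V per layer
    return (weight_io + kv_io) / BW_BYTES_PER_S

# MH baseline vs an MG model F times its size (illustrative configurations).
F = 1.2
configs = {
    "MH": dict(n_params=1.0e9, n_layers=24, g=16, k=128),
    "MG": dict(n_params=F * 1.0e9, n_layers=26, g=4, k=128),
}
b, m = 32, 4096
for name, cfg in configs.items():
    enc = context_encoding_latency(cfg["n_params"], b, m)
    dec = incremental_decoding_latency(cfg["n_params"], cfg["n_layers"], b, m, cfg["g"], cfg["k"])
    print(f"{name}: context encoding ~{enc:.2f} s, per decoding step ~{dec * 1e3:.1f} ms")
```

Under these assumptions, context encoding latency grows roughly by the factor F, while the per-step decoding latency drops substantially because the KV cache IO shrinks with g.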


D.3.1. CONTEXT ENCODING



D.3.2. INCREMENTAL DECODING



Table 2: Model specifications for the three attention variants: multi-head (MH), multi-query (MQ), and multi-group (MG), including parameter count, number of attention groups, head dimension, and number of layers. The additional fanout-based MG variant is denoted MG + 2 × d.


Table 3: Model Specifications for Latency Experiment in Section 5.2.2.


This paper is available on arxiv under CC BY 4.0 DEED license.