著者:
(1)AWS AIラボのBen Athiwaratkun氏
(2) Sujan Kumar Gonugondla、AWS AI Labs。
(3)サンジェイ・クリシュナ・ゴーダ、AWS AIラボ
(4) Haifeng Qian、AWS AI Labs。
(5)サンジェイ・クリシュナ・ゴーダ、AWS AIラボ
(6)Hantian Ding、AWS AIラボ
(7)Qing Sun、AWS AIラボ
(8)ジュン・ワン、AWS AIラボ
(9) Jiacheng Guo、AWS AI Labs。
(10 Liangfu Chen、AWS AI ラボ、
(11)Parminder Bhatia、GEヘルスケア(AWSでの仕事)
(12) Ramesh Nallapati、Amazon AGI (AWS での作業)。
(13)Sudipta Sengupta、AWS AIラボ
(14)ビン・シアン、ゴールドマン・サックス(AWSでの仕事)。
リンク一覧
3.3. マルチクエリ、マルチヘッド、一般化マルチクエリアテンション
5.1. マルチヘッド、マルチクエリ、マルチグループアテンションの機能比較
2. 関連研究
文献には、推論のレイテンシやレイテンシを改善するための複数の方法が記載されています。量子化は、int8、int4、fp8 などの低ビット幅表現を使用してメモリ使用量を削減します (Wei et al., 2023; Yao et al., 2022; Dettmers et al., 2022; Frantar et al., 2022; Kuzmin et al., 2022; Xiao et al., 2022)。量子化をモデル パラメーターにのみ適用すると、シーケンス長が長くバッチ サイズが大きい場合のように、メモリ アクセスとドット積アテンションに関連する計算が全体的な推論レイテンシを支配するため、結果が減少します。
スパースアテンション(Beltagy et al., 2020; Child et al., 2019; Zaheer et al., 2020)は、アテンションの複雑さを軽減して、より長いコンテキストとより高速な推論を行う方法として広く研究されてきました。Pope et al.(2022)は、TPU(集合的アインサム)に最適化された多次元パーティショニング技術を使用して、レイテンシとモデル FLOP 使用率のパレートフロンティアを達成することにより、大規模言語モデルの生成推論効率を調査しています。この論文では、マルチクエリアテンションにより、高バッチサイズでの効率性を重視しながら、コンテキスト長を最大 32 倍まで拡張できることも示されています。ページングアテンション(Kwon et al., 2023)は、KV キャッシュをブロックに分割し、ブロックテーブルをマッピングの目的で使用することで、KV キャッシュのメモリ管理を強化します。このアプローチは、動的なワークロードシフトに効果的に対応し、プロンプトの KV キャッシュを複数の出力シーケンスで共有することにより、メモリストレージ要件を削減します。ただし、これによって KV キャッシュのメモリ読み取りは削減されません。
投機的デコードとそのバリエーションでは、より小規模なドラフト モデルを使用して複数の連続トークンを提案し、メイン モデルによってこれらのトークンが並列処理されて、そのようなトークンを受け入れるか拒否するかを決定します (Chen ら、2023 年、Leviathan ら、2022 年、Li ら、2024 年、Cai ら、2024 年、Fu ら、2023 年)。重要なアイデアは、すべてのステップで複数のトークンのデコードを可能にし、それによってメイン モデルのメモリ IO 使用量を償却することです。ただし、コンテキスト サイズが大きい場合、デコードのレイテンシは依然として KV キャッシュ I/O 帯域幅によって支配されますが、分岐アテンションによってデコード速度をさらに向上させることができます。つまり、増分デコードは、モデルの読み込みの償却メモリ IO を下げることに重点を置いていますが、マルチクエリと分岐アテンションによって KV キャッシュのメモリ IO が下がります。
3. 背景
3.1. 表記
本論文では、以下の表記法を使用します。
3.2. 言語モデル推論
言語モデルには、バッチ推論やシングルコンテキスト バッチ サンプリングなど、多くの推論シナリオがあります (図 1)。バッチ推論とは、複数の入力をバッチでまとめて処理し、バッチ インデックスごとに後続のトークンを個別に生成する場合を指します。バッチ サイズが 1 の場合、これはシングルコンテキスト推論に縮小されます。もう 1 つのシナリオは、単一のコンテキストに基づいて複数のシーケンスを生成するシングルコンテキスト バッチ サンプリングです。バッチ推論の場合との違いは、KV キャッシュを取得するために単一のコンテキストに対してのみ事前入力を実行し、その後他のバッチ インデックスにブロードキャストする必要があることです。
図1は、言語モデル推論の2つのフェーズ、(a)コンテキストエンコーディングまたは事前入力と(b)増分デコードも示しています。コンテキストエンコーディングは、コンテキスト内のすべてのトークン位置のキーテンソルと値テンソルを計算する単一のフォワードパスを指します。キーテンソルと値テンソルが計算されると、これらのキーテンソルと値テンソルをキャッシュし、増分デコードフェーズでアテンションメカニズムに使用します。増分デコードフェーズでは、トークンを1つずつ順番に生成します[2]。
コンテキスト エンコーディング フェーズでは、メモリ入出力 (IO) 操作に対する浮動小数点操作の数が多く、レイテンシが FLOP によって影響を受ける計算依存の領域に対応します。ただし、単一のクエリ トークンに注目する増分デコード中は、メモリ アクセスあたりの計算数がほぼ 1 対 1 になるメモリ依存の領域になります (詳細については付録 D.1 を参照)。メモリ IO とは、高帯域幅メモリ (HBM) (Jia ら、2018) から実際の計算が行われる高速オンチップ SRAM への読み取りおよび書き込み操作を指します。増分デコード自体のメモリ IO は、(1) モデル パラメーターの読み込みと (2) KV キャッシュの読み込みの 2 つのコンポーネントで構成されます。コンポーネント (1) はコンテキスト長 m またはバッチ サイズ b に関係なく一定ですが、コンポーネント (2) は m と b の両方に依存し、m または b が大きい場合は全体のメモリ IO を支配し、推論の大きなボトルネックになる可能性があります。私たちの作業は主にコンポーネント (2) の削減に焦点を当てています。
この論文は、CC BY 4.0 DEED ライセンスの下でarxiv で公開されています。