Autores:
(1) Ben Athiwaratkun, Laboratorios de inteligencia artificial de AWS;
(2) Sujan Kumar Gonugondla, Laboratorios de IA de AWS;
(3) Sanjay Krishna Gouda, Laboratorios de inteligencia artificial de AWS;
(4) Haifeng Qian, laboratorios de IA de AWS;
(5) Sanjay Krishna Gouda, Laboratorios de inteligencia artificial de AWS;
(6) Hantian Ding, Laboratorios de inteligencia artificial de AWS;
(7) Qing Sun, Laboratorios de inteligencia artificial de AWS;
(8) Jun Wang, Laboratorios de inteligencia artificial de AWS;
(9) Jiacheng Guo, laboratorios de IA de AWS;
(10 Liangfu Chen, laboratorios de IA de AWS;
(11) Parminder Bhatia, GE HealthCare (trabajo realizado en AWS);
(12) Ramesh Nallapati, Amazon AGI (trabajo realizado en AWS);
(13) Sudipta Sengupta, Laboratorios de inteligencia artificial de AWS;
(14) Bing Xiang, Goldman Sachs (trabajo realizado en AWS).
Tabla de enlaces
3.1. Notación y 3.2. Inferencia del modelo de lenguaje
3.3. Multi-Consulta, Multi-Encabezado y la Atención Multi-Consulta Generalizada
4. Atención bifurcada consciente del contexto y 4.1. Motivación
4.2. Formulación y 4.3. Complejidad de E/S de memoria
5.1. Comparación de las capacidades de atención multi-cabezal, multi-consulta y multi-grupo
5.2. Latencias de los modelos equivalentes a capacidades
D. Atención Multigrupal Familiar
E. Atención bifurcada consciente del contexto
F. Aplicaciones: Resultados adicionales
G. Compatibilidad con técnicas de decodificación especulativa y de decodificación rápida
2. Trabajo relacionado
En la literatura, existen múltiples vías para mejorar la latencia y/o la latencia de inferencia. La cuantificación reduce el uso de memoria mediante el uso de representaciones de bajo ancho de bits como int8, int4 y fp8 (Wei et al., 2023; Yao et al., 2022; Dettmers et al., 2022; Frantar et al., 2022; Kuzmin et al., 2022; Xiao et al., 2022). La cuantificación cuando se aplica solo a los parámetros del modelo ofrece resultados decrecientes como con longitudes de secuencia más largas y tamaños de lote grandes donde el acceso a la memoria y el cómputo asociado con la atención del producto escalar dominan la latencia de inferencia general.
La atención dispersa (Beltagy et al., 2020; Child et al., 2019; Zaheer et al., 2020) se ha estudiado ampliamente como una forma de reducir la complejidad de la atención para contextos más largos e inferencias más rápidas. Pope et al. (2022) investiga la eficiencia de la inferencia generativa de modelos de lenguaje grandes mediante el uso de técnicas de partición multidimensional optimizadas para TPU (einsum colectivo) para lograr una frontera de Pareto en la latencia y la utilización de FLOP del modelo. El artículo también muestra que la atención de múltiples consultas permite escalar hasta una longitud de contexto 32 veces mayor con énfasis en la eficiencia en lotes de gran tamaño. La atención paginada (Kwon et al., 2023) mejora la gestión de la memoria de la caché KV al dividirla en bloques y emplear una tabla de bloques para fines de mapeo. Este enfoque se adapta de manera efectiva a los cambios dinámicos de la carga de trabajo y reduce los requisitos de almacenamiento de memoria al compartir la caché KV del mensaje en múltiples secuencias de salida. Sin embargo, esto no reduce las lecturas de memoria caché KV.
La decodificación especulativa y sus variantes utilizan un modelo de borrador más pequeño para proponer múltiples tokens secuenciales, que son procesados en paralelo por el modelo principal para aceptar o rechazar dichos tokens (Chen et al., 2023; Leviathan et al., 2022; Li et al., 2024; Cai et al., 2024; Fu et al., 2023). La idea clave es permitir la decodificación de múltiples tokens en cada paso, amortizando así los usos de E/S de memoria del modelo principal. Sin embargo, la latencia de la decodificación seguirá estando dominada por el ancho de banda de E/S de caché KV en tamaños de contexto grandes, donde la atención bifurcada puede mejorar aún más la velocidad de decodificación. En resumen, la decodificación incremental se centra en reducir la E/S de memoria amortizada de la carga del modelo, mientras que la consulta múltiple y la atención bifurcada reducen la E/S de memoria de caché KV.
3. Antecedentes
3.1. Notación
Utilizamos la siguiente notación a lo largo del artículo.
3.2. Inferencia del modelo de lenguaje
Existen muchos escenarios de inferencia para el modelo de lenguaje, incluida la inferencia por lotes y el muestreo por lotes de un solo contexto (Figura 1). La inferencia por lotes se refiere al caso en el que procesamos múltiples entradas juntas en un lote y generamos tokens subsiguientes para cada índice de lote de forma independiente. En el caso en el que el tamaño del lote es 1, esto se reduce a la inferencia de un solo contexto. Otro escenario es el muestreo por lotes de un solo contexto, donde generamos múltiples secuencias basadas en un solo contexto, donde la diferencia con el caso de inferencia por lotes es que el prellenado solo debe realizarse para un solo contexto para obtener la caché KV, y luego transmitirse a otros índices de lotes.
La figura 1 también ilustra las dos fases de la inferencia del modelo de lenguaje: (a) la codificación o prellenado del contexto y (b) la decodificación incremental. La codificación del contexto se refiere a un único paso hacia adelante que calcula los tensores de clave y valor para todas las posiciones de token en el contexto. Una vez que se calculan los tensores de clave y valor, almacenamos en caché estos tensores de clave y valor para usarlos en el mecanismo de atención durante la fase de decodificación incremental, que genera secuencialmente un token a la vez[2].
Durante la fase de codificación de contexto, la cantidad de operaciones de punto flotante en relación con las operaciones de entrada/salida (IO) de memoria es alta, lo que corresponde al régimen limitado por el cálculo donde la latencia está influenciada por los FLOP. Sin embargo, durante la decodificación incremental donde prestamos atención a un solo token de consulta, esto cae en un régimen limitado por la memoria donde la cantidad de cálculos por acceso a la memoria es aproximadamente de 1 a 1 (consulte el Apéndice D.1 para obtener más detalles). El IO de memoria se refiere a las operaciones de lectura y escritura desde la memoria de alto ancho de banda (HBM) (Jia et al., 2018) a la SRAM rápida en chip donde ocurre el cálculo real. El IO de memoria de la decodificación incremental en sí consta de dos componentes: (1) la carga de parámetros del modelo y (2) la carga de caché KV. El componente (1) es constante independientemente de la longitud del contexto m o el tamaño del lote b, donde el componente (2) depende tanto de m como de b y domina el IO de memoria general si m o b son altos, lo que puede convertirse en un cuello de botella significativo para la inferencia. Nuestro trabajo se centra principalmente en la reducción del componente (2).
Este artículo está disponible en arxiv bajo la licencia CC BY 4.0 DEED.