Autores:
(1) Ben Athiwaratkun, Laboratorios de inteligencia artificial de AWS;
(2) Sujan Kumar Gonugondla, Laboratorios de IA de AWS;
(3) Sanjay Krishna Gouda, Laboratorios de inteligencia artificial de AWS;
(4) Haifeng Qian, laboratorios de IA de AWS;
(5) Sanjay Krishna Gouda, Laboratorios de inteligencia artificial de AWS;
(6) Hantian Ding, Laboratorios de inteligencia artificial de AWS;
(7) Qing Sun, Laboratorios de inteligencia artificial de AWS;
(8) Jun Wang, Laboratorios de inteligencia artificial de AWS;
(9) Jiacheng Guo, laboratorios de IA de AWS;
(10 Liangfu Chen, laboratorios de IA de AWS;
(11) Parminder Bhatia, GE HealthCare (trabajo realizado en AWS);
(12) Ramesh Nallapati, Amazon AGI (trabajo realizado en AWS);
(13) Sudipta Sengupta, Laboratorios de inteligencia artificial de AWS;
(14) Bing Xiang, Goldman Sachs (trabajo realizado en AWS).
Tabla de enlaces
3.1. Notación y 3.2. Inferencia del modelo de lenguaje
3.3. Multi-Consulta, Multi-Encabezado y la Atención Multi-Consulta Generalizada
4. Atención bifurcada consciente del contexto y 4.1. Motivación
4.2. Formulación y 4.3. Complejidad de E/S de memoria
5.1. Comparación de las capacidades de atención multi-cabezal, multi-consulta y multi-grupo
5.2. Latencias de los modelos equivalentes a capacidades
D. Atención Multigrupal Familiar
E. Atención bifurcada consciente del contexto
F. Aplicaciones: Resultados adicionales
G. Compatibilidad con técnicas de decodificación especulativa y de decodificación rápida
Abstracto
En nuestro estudio, presentamos la atención bifurcada , un método desarrollado para la inferencia de modelos de lenguaje en contextos de muestreo por lotes de contexto único. Este enfoque tiene como objetivo reducir los costos de E/S de memoria redundante, un factor significativo en la latencia para tamaños de lote altos y longitudes de contexto largas. La atención bifurcada logra esto dividiendo el mecanismo de atención durante la decodificación incremental en dos operaciones GEMM distintas, centrándose en la caché KV del prellenado y el proceso de decodificación. Este método asegura un cálculo preciso y mantiene la carga computacional habitual (FLOP) de los mecanismos de atención estándar, pero con E/S de memoria reducida. La atención bifurcada también es compatible con el mecanismo de atención de múltiples consultas conocido por la E/S de memoria reducida para la caché KV, lo que permite además un mayor tamaño de lote y longitud de contexto. La eficiencia resultante conduce a una latencia más baja, lo que mejora la idoneidad para aplicaciones en tiempo real, por ejemplo, permitiendo la generación de respuestas masivamente paralelas sin aumentar sustancialmente la latencia, mejorando el rendimiento cuando se integra con técnicas de posprocesamiento como la reclasificación.
1. Introducción
La aparición de los modelos de lenguaje de gran escala (LLM, por sus siglas en inglés) ha marcado el comienzo de una nueva era en el aprendizaje automático, que muestra un rendimiento notable en una amplia gama de tareas (Brown et al., 2020; OpenAI, 2023; Chowdhery et al., 2022; Touvron et al., 2023; Chen et al., 2021; Hoffmann et al., 2022; Li et al., 2022; Microsoft; Amazon, 2022; Nijkamp et al., 2023). A pesar de sus impresionantes capacidades, la implementación de estos modelos a gran escala en aplicaciones prácticas plantea desafíos importantes, en particular en términos de latencia y eficiencia de inferencia. Mejorar estos aspectos es fundamental, ya que influyen directamente en los recursos computacionales necesarios para generar predicciones y permitir la implementación práctica de estos modelos avanzados en diversas industrias.
Un escenario de inferencia particularmente exigente es el muestreo por lotes de un solo contexto, donde el objetivo es generar múltiples finalizaciones a partir de un único contexto. Esta tarea se encuentra comúnmente en numerosas aplicaciones, como herramientas IDE de edición de código que brindan múltiples recomendaciones, o en casos donde se necesita una clasificación entre muchas generaciones para un rendimiento óptimo (a través de métricas de clasificación como probabilidad logarítmica media, votación por mayoría, etc.). La decodificación incremental de dicho escenario de muestreo requiere un uso intensivo de la memoria E/S, lo que se convierte en un cuello de botella de latencia para lotes y longitudes de contexto elevadas.
En este estudio, investigamos dos estrategias compatibles para abordar los desafíos de la E/S de memoria en la inferencia de transformadores: (1) una investigación de consultas múltiples y sus compensaciones, y (2) una técnica novedosa llamada atención bifurcada consciente del contexto.
Nuestra investigación comienza con un análisis de la atención generalizada de múltiples consultas (Ainslie et al., 2023), que incluye múltiples consultas (Shazeer, 2019), así como el mecanismo de atención de múltiples cabezas establecido (Vaswani et al., 2017) para el equilibrio entre rendimiento y latencia. Nuestros hallazgos muestran un escalamiento suave del rendimiento con el aumento del tamaño del modelo para un valor fijo del número de grupos g para múltiples consultas generalizadas[1]. La reducción de g da como resultado un desplazamiento hacia arriba de las curvas de escalamiento de pérdida de validación frente al tamaño del modelo. La relación consistente entre la compresión de caché, el tamaño del modelo y la pérdida de validación nos permite equilibrar la eficiencia de la inferencia con el tamaño del modelo, es decir, nos permite seleccionar una mayor compresión para los casos de uso que requieren alta eficiencia, al mismo tiempo que igualamos el rendimiento de la atención de múltiples cabezas al compensar con un tamaño de modelo más grande.
En segundo lugar, introducimos la atención bifurcada consciente del contexto, una técnica que bifurca cualquier atención en la familia de consultas múltiples generalizadas en contexto y componentes de decodificación durante la decodificación incremental. Dicha bifurcación implica la misma cantidad de FLOP y produce resultados idénticos en comparación con la atención original, pero puede reducir significativamente el costo de E/S de memoria y, por lo tanto, la latencia en escenarios de lotes y longitudes de contexto elevadas. Este enfoque permite la generación de múltiples finalizaciones en tiempo real sin incurrir en muchos costos de latencia adicionales, o permite tamaños de lote mucho más altos que conducen a un mejor rendimiento de clasificación. Por ejemplo, para el modelo multicabezal CodeGen 16B (Nijkamp et al., 2022) con una longitud de contexto de 2k, podemos aumentar el tamaño del lote a 128 con atención bifurcada, en comparación con el tamaño del lote de solo 5 sin ella, lo que da como resultado que el pass@k (Chen et al., 2021) aumente del 59,0 % al 84,6 %, o el pass@top3 a través del log-p medio aumente del 55,2 % al 58,1 %.
Este artículo está disponible en arxiv bajo la licencia CC BY 4.0 DEED.
[1] Los valores más bajos de los grupos de atención g conducen a una mayor compresión de los tensores clave-valor, como en el caso de múltiples consultas donde g = 1, mejorando así la eficiencia de la inferencia y la latencia debido a la caché KV reducida en comparación con el caso de múltiples cabezales donde g = h, el número de cabezales de atención de la consulta.