Autors:
(1) Ben Athiwaratkun, AWS AI Labs;
(2) Sujan Kumar Gonugondla, AWS AI Labs;
(3) Sanjay Krishna Gouda, AWS AI Labs;
(4) Haifeng Qian, AWS AI Labs;
(5) Sanjay Krishna Gouda, AWS AI Labs;
(6) Hantian Ding, AWS AI Labs;
(7) Qing Sun, AWS AI Labs;
(8) Jun Wang, AWS AI Labs;
(9) Jiacheng Guo, AWS AI Labs;
(10 Liangfu Chen, AWS AI Labs;
(11) Parminder Bhatia, GE HealthCare (treball realitzat a AWS);
(12) Ramesh Nallapati, Amazon AGI (treball realitzat a AWS);
(13) Sudipta Sengupta, AWS AI Labs;
(14) Bing Xiang, Goldman Sachs (treball realitzat a AWS).
Taula d'enllaços
3.1. Notació i 3.2. Inferència del model lingüístic
3.3. Multi-consulta, multi-capçalera i l'atenció multi-consulta generalitzada
4. Atenció bifurcada conscient del context i 4.1. Motivació
4.2. Formulació i 4.3. Complexitat d'IO de memòria
5.1. Comparació de les capacitats d'atenció multicap, consulta i atenció multigrup
5.2. Latències de Capacitats-Models Equivalents
E. Atenció bifurcada conscient del context
F. Aplicacions: Resultats addicionals
G. Compatibilitat amb tècniques de descodificació especulativa i de descodificació ràpida
Resum
En el nostre estudi, presentem l'atenció bifurcada , un mètode desenvolupat per a la inferència de models de llenguatge en contextos de mostreig per lots d'un sol context. Aquest enfocament pretén reduir els costos d'E/S de memòria redundants, un factor important de latència per a grans mides de lots i llargues longituds de context. L'atenció bifurcada ho aconsegueix dividint el mecanisme d'atenció durant la descodificació incremental en dues operacions GEMM diferents, centrant-se en la memòria cau KV del preemplenament i el procés de descodificació. Aquest mètode garanteix un càlcul precís i manté la càrrega computacional habitual (FLOP) dels mecanismes d'atenció estàndard, però amb una memòria IO reduïda. L'atenció bifurcada també és compatible amb el mecanisme d'atenció de consulta múltiple conegut per la reducció d'IO de memòria per a la memòria cau KV, permetent encara més una mida de lot i una longitud de context més grans. L'eficiència resultant condueix a una latència més baixa, millorant la idoneïtat per a aplicacions en temps real, per exemple, permetent la generació de respostes massivament paral·leles sense augmentar substancialment la latència, millorant el rendiment quan s'integra amb tècniques de postprocessament com ara el reclassificació.
1. Introducció
L'arribada dels grans models de llenguatge (LLM) ha donat lloc a una nova era d'aprenentatge automàtic, mostrant un rendiment notable en una àmplia gamma de tasques (Brown et al., 2020; OpenAI, 2023; Chowdhery et al., 2022; Touvron et al., 2023; Chen et al., 2021; Hoffmann et al., 2021; al., 2022; Amazon, 2022; Nijkamp et al., 2023). Malgrat les seves impressionants capacitats, el desplegament d'aquests models a gran escala en aplicacions pràctiques planteja reptes importants, especialment en termes de latència i eficiència d'inferència. Millorar aquests aspectes és fonamental, ja que influeixen directament en els recursos computacionals necessaris per generar prediccions i permetre la implementació pràctica d'aquests models avançats en diverses indústries.
Un escenari d'inferència especialment exigent és el mostreig per lots d'un sol context, on l'objectiu és generar múltiples finalitzacions a partir d'un sol context. Aquesta tasca es troba habitualment en nombroses aplicacions, com ara les eines IDE d'edició de codi que proporcionen múltiples recomanacions, o en els casos en què es necessita una classificació entre moltes generacions per obtenir un rendiment òptim (mitjançant mètriques de classificació com la probabilitat mitjana del registre, la votació majoritària, etc.). La descodificació incremental d'aquest escenari de mostreig requereix una intensitat d'E/S de memòria, que es converteix en un coll d'ampolla de latència per a lots elevats i longituds de context.
En aquest estudi, investiguem dues estratègies compatibles per abordar els reptes d'IO de memòria en la inferència de transformadors: (1) una investigació de la consulta múltiple i les seves compensacions, i (2) una nova tècnica anomenada atenció bifurcada conscient del context.
La nostra investigació comença amb una anàlisi de l'atenció multiconsulta generalitzada (Ainslie et al., 2023), que inclou la consulta múltiple (Shazeer, 2019), així com el mecanisme d'atenció multicaps establert (Vaswani et al., 2017) per al rendiment i la latència. Els nostres resultats mostren una escala de rendiment suau amb l'augment de la mida del model per a un valor fix del nombre de grups g per a una consulta múltiple generalitzada [1]. La baixada de g provoca un desplaçament a l'alça de la pèrdua de validació respecte a les corbes d'escala de la mida del model. La relació coherent entre la compressió de la memòria cau, la mida del model i la pèrdua de validació ens permet compensar l'eficiència de la inferència amb la mida del model, és a dir, ens permet seleccionar una compressió més alta per a casos d'ús que requereixen una alta eficiència, alhora que igualem el rendiment de l'atenció multi-capçal compensant amb una mida de model més gran.
En segon lloc, introduïm l'atenció bifurcada conscient del context, una tècnica que bifurca qualsevol atenció de la família de consultes múltiples generalitzades en components de context i descodificació durant la descodificació incremental. Aquesta bifurcació implica el mateix nombre de FLOP i produeix resultats idèntics en comparació amb l'atenció original, però pot reduir significativament el cost d'IO de memòria i, per tant, la latència en escenaris de llargada de lot i context elevats. Aquest enfocament permet la generació de múltiples finalitzacions en temps real sense incórrer en molts costos de latència addicionals, o permet mides de lots molt més grans que permeten un millor rendiment de classificació. Per exemple, per al model de capçal múltiple CodeGen 16B (Nijkamp et al., 2022) amb una longitud de context de 2k, podem augmentar la mida del lot a 128 amb atenció bifurcada, en comparació amb la mida del lot de només 5 sense, donant lloc a que el pass@k (Chen et al., 2021) augmenti des del 59.0% fins al 59. passant del 55,2% al 58,1%.
Aquest document està disponible a arxiv sota la llicència CC BY 4.0 DEED.
[1] Els valors més baixos dels grups d'atenció g condueixen a una compressió més alta dels tensors de valors clau, com en el cas de consulta múltiple on g = 1, per tant, millora l'eficiència i la latència d'inferència a causa de la reducció de la memòria cau de KV en comparació amb el cas de múltiples capçals on g = h, el nombre de capçals d'atenció de la consulta.