Authors:
(1) Opher Lieber (equal contribution); (2) Barak Lenz (equal contribution); (3) Hofit Bata; (4) Gal Cohen; (5) Jhonathan Osin; (6) Itay Dalmedigos; (7) Erez Safahi; (8) Shaked Meirom; (9) Yonatan Belinkov; (10) Shai Shalev-Shwartz; (11) Omri Abend; (12) Raz Alon; (13) Tomer Asida; (14) Amir Bergman; (15) Roman Glozman; (16) Michael Gokhman; (17) Avashalom Manevich; (18) Nir Ratner; (19) Noam Rozen; (20) Erez Shwartz; (21) Mor Zusman; (22) Yoav Shoham.
Abstract
We present Jamba, a new base large language model based on a novel hybrid Transformer-Mamba mixture-of-experts (MoE) architecture. Specifically, Jamba interleaves blocks of Transformer and Mamba layers, enjoying the benefits of both model families. MoE is added in some of these layers to increase model capacity while keeping active parameter usage manageable. This flexible architecture allows resource- and objective-specific configurations. In the particular configuration we have implemented, we end up with a powerful model that fits in a single 80GB GPU. Built at large scale, Jamba provides high throughput and small memory footprint compared to vanilla Transformers, and at the same time state-of-the-art performance on standard language model benchmarks and long-context evaluations. Remarkably, the model presents strong results for up to 256K tokens context length. We study various architectural decisions, such as how to combine Transformer and Mamba layers, and how to mix experts, and show that some of them are crucial in large scale modeling. We also describe several interesting properties of these architectures which the training and evaluation of Jamba have revealed, and plan to release checkpoints from various ablation runs, to encourage further exploration of this novel architecture. We make the weights of our implementation of Jamba publicly available under a permissive license.
Model: https://huggingface.co/ai21labs/Jamba-v0.1
1. Introduction
We introduce Jamba, a new publicly available large language model. Jamba is based on a novel hybrid architecture, which combines Transformer layers [46] with Mamba layers [16], a recent state-space model [17, 18], as well as a mixture-of-experts (MoE) component [13, 41]. Jamba thus combines two orthogonal architectural designs that together give it improved performance and higher throughput, while maintaining a manageable memory footprint. The 7B-based Jamba model (12B active parameters, 52B total available parameters) we are releasing was designed to fit in a single 80GB GPU, but the Jamba architecture supports other design choices, depending on one’s hardware and performance requirements.
The fundamental novelty of Jamba is its hybrid Transformer-Mamba architecture (though see mention below of recent related efforts). Despite the immense popularity of the Transformer as the predominant architecture for language models, it suffers from two main drawbacks. First, its high memory and compute requirements hinder the processing of long contexts, where the key-value (KV) cache size becomes a limiting factor. Second, its lack of a single summary state entails slow inference and low throughput, since each generated token performs a computation on the entire context. In contrast, older recurrent neural network (RNN) models, which summarize an arbitrarily long context in a single hidden state, do not suffer from these limitations. RNN models have their own shortcomings, however. They are costly to train since training cannot be parallelized across time steps. And they struggle with long-distance relationships, which the hidden state captures to only a limited extent.
Recent state space models (SSMs) like Mamba are more efficient to train than RNNs and are more capable at handling long distance relationships, but still lag behind the performance of comparably sized Transformer language models. Taking advantage of both model families, Jamba combines Transformer and Mamba layers, at a certain ratio. Varying the ratio of Transformer/Mamba layers allows balancing memory usage, efficient training, and long context capabilities.
A few other recent attempts to combine Attention and SSM modules are worth noting. [50] mixes an S4 layer [17] with a local attention layer, followed by a sequence of local attention layers; it shows experiments with small models and simple tasks. [16] reports that interleaving Mamba and attention layers is only slightly better than pure Mamba in terms of perplexity, with models up to 1.3B parameters. [33] starts with an SSM layer followed by chunk-based Transformers, with models up to 1.3B showing improved perplexity. [12] adds an SSM layer before the self-attention in a Transformer layer, while [38] adds the SSM after the self-attention, both showing improvements on speech recognition. [32] replaces the MLP layers in the Transformer by Mamba layers, and shows benefits in simple tasks. These efforts are different from Jamba both in the particular way in which the SSM component is mixed with the attention one, and in the scale of implementation. Closest are perhaps H3 [14], a specially designed SSM that enables induction capabilities, and a generalization called Hyena [35]. The former proposed a hybrid architecture that replaces the second and middle layers with self-attention, and was implemented with up to 2.7B parameters and 400B training tokens. However, as shown in [16], its performance lags that of pure Mamba. Based on Hyena, StripedHyena [36] interleaves attention and SSM layers in a 7B parameter model. However, it lags behind the Attention-only Mistral-7B [22]. All of this renders Jamba the first production-grade Attention-SSM hybrid model. Scaling the hybrid Jamba architecture required overcoming several obstacles, which we discuss in Section 6.
Jamba also includes MoE layers [13, 41], which allow increasing the model capacity (total number of available parameters) without increasing compute requirements (number of active parameters). MoE is a flexible approach that enables training extremely large models with strong performance [23]. In Jamba, MoE is applied to some of the MLP layers. The more MoE layers, and the more experts in each MoE layer, the larger the total number of model parameters. In contrast, the more experts we use at each forward pass, the larger the number of active parameters as well as the compute requirement. In our implementation of Jamba, we apply MoE at every other layer, with 16 experts and the top-2 experts used at each token (a more detailed discussion of the model architecture is provided below).
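To make the top-2 routing concrete, below is a minimal sketch of a top-K MoE MLP layer in PyTorch-style Python. It is not Jamba's implementation: the class and parameter names are hypothetical, the experts are plain GELU MLPs rather than the SwiGLU MLPs used in Jamba, and no load balancing is shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEMLP(nn.Module):
    """Illustrative top-K mixture-of-experts MLP (hypothetical, not Jamba's code)."""

    def __init__(self, d_model: int, d_ff: int, n_experts: int = 16, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
             for _ in range(n_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (tokens, d_model). The router scores all n_experts per token, but only
        # the top_k experts run, so active parameters (and compute) stay well below
        # the total parameter count.
        scores = self.router(x)                              # (tokens, n_experts)
        top_vals, top_idx = scores.topk(self.top_k, dim=-1)  # (tokens, top_k)
        weights = F.softmax(top_vals, dim=-1)
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = top_idx[:, slot] == e
                if mask.any():
                    out[mask] += weights[mask, slot].unsqueeze(-1) * expert(x[mask])
        return out
```

In the configuration described above, such a module would replace the MLP in every other layer, with n = 16 experts and K = 2 active experts per token.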
We evaluated our implementation of Jamba on a wide range of benchmarks and found it performs comparably to Mixtral-8x7B [23], which has a similar number of parameters, and also to the larger Llama-2 70B [45]. In addition, our model supports a context length of 256K tokens – the longest supported context length for production-grade publicly available models. On long-context evaluations, Jamba outperforms Mixtral on most of the evaluated datasets. At the same time, Jamba is extremely efficient; for example, its throughput is 3x that of Mixtral-8x7B for long contexts. Moreover, our model fits in a single GPU (with 8bit weights) even with contexts of over 128K tokens, which is impossible with similar-size attention-only models such as Mixtral-8x7B.
Somewhat unusually for a new architecture, we release Jamba (12B active parameters, 52B total available parameters) under Apache 2.0 license: https://huggingface.co/ai21labs/Jamba-v0.1. We do so since we feel that the novel architecture of Jamba calls for further study, experimentation, and optimization by the community. Our design was based on various ablation experiments we conducted to explore the effect of different tradeoffs and design choices, and insights gleaned from those. These ablations were performed at scales of up to 7B parameters, and training runs of up to 250B tokens. We plan to release model checkpoints from these runs.
Important notice: The Jamba model released is a pretrained base model, which did not go through alignment or instruction tuning, and does not have moderation mechanisms. It should not be used in production environments or with end users without additional adaptation.
2. Model Architecture
Jamba is a hybrid decoder architecture that mixes Transformer layers [46] with Mamba layers [16], a recent state-space model (SSM) [17, 18], in addition to a mixture-of-experts (MoE) module [13, 41]. We call the combination of these three elements a Jamba block. See Figure 1 for an illustration.
Combining Transformer, Mamba, and MoE elements allows flexibility in balancing among the sometimes conflicting objectives of low memory usage, high throughput, and high quality. In terms of memory usage, note that comparing the total size of the model parameters can be misleading. In an MoE model, the number of active parameters that participate in any given forward step may be much smaller than the total number of parameters. Another important consideration is the KV cache – the memory required to store the attention keys and values in the context. When scaling Transformer models to long contexts, the KV cache becomes a limiting factor. Trading off attention layers for Mamba layers reduces the total size of the KV cache. Our architecture aims to provide not only a small number of active parameters but also an 8x smaller KV cache compared to a vanilla Transformer. Table 1 compares Jamba with recent publicly available models, showing its advantage in maintaining a small KV cache even with 256K token contexts.
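As a rough back-of-the-envelope illustration of why trading attention layers for Mamba layers shrinks the KV cache, the sketch below compares cache sizes for a purely hypothetical configuration. The layer count, KV head count, head dimension, and the 1-in-8 attention ratio are illustrative assumptions (the ratio is chosen only to be consistent with the 8x figure above), not Jamba's actual numbers.

```python
def kv_cache_bytes(n_attention_layers: int, n_kv_heads: int, head_dim: int,
                   context_len: int, bytes_per_elem: int = 2) -> int:
    """KV cache = 2 (keys and values) * attention layers * KV heads * head dim * context."""
    return 2 * n_attention_layers * n_kv_heads * head_dim * context_len * bytes_per_elem

# Hypothetical configuration, for illustration only.
total_layers, n_kv_heads, head_dim, context = 32, 8, 128, 256_000

full_attention = kv_cache_bytes(total_layers, n_kv_heads, head_dim, context)
hybrid = kv_cache_bytes(total_layers // 8, n_kv_heads, head_dim, context)  # 1 in 8 layers uses attention

print(f"all-attention KV cache: {full_attention / 2**30:.1f} GiB")  # ~31 GiB
print(f"hybrid KV cache:        {hybrid / 2**30:.1f} GiB")          # ~3.9 GiB, i.e. 8x smaller
```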
In terms of throughput, with short sequences, attention operations take up a small fraction of the inference and training FLOPS [6]. However, with long sequences, attention hogs most of the compute. In contrast, Mamba layers are more compute-efficient. Thus, increasing the ratio of Mamba layers improves throughput especially for long sequences.
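Stated as an approximate scaling relation (a standard estimate, not a formula from the paper), an attention layer's cost grows quadratically in the sequence length n, while a Mamba layer's cost grows linearly:

```latex
\mathrm{FLOPs}_{\text{attention layer}} \;\approx\; \mathcal{O}(n^{2}\,d),
\qquad
\mathrm{FLOPs}_{\text{Mamba layer}} \;\approx\; \mathcal{O}(n\,d),
```

where d is the model dimension; hence, for long sequences, the attention layers dominate total compute, and shifting the layer mix toward Mamba raises throughput.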
Here is a description of the main configuration, which provides improved performance and efficiency. Section 6 contains results from ablation experiments supporting the design choices.
The basic component is a Jamba block, which may be repeated in sequence. Each Jamba block is a combination of attention and Mamba layers. Each such layer contains either an attention or a Mamba module, followed by a multi-layer perceptron (MLP). The different possible types of layers are shown in Figure 1(b).[2] A Jamba block contains l layers, which are mixed at a ratio of a : m, meaning a attention layers for every m Mamba layers.
In Jamba, some of the MLPs may be replaced by MoE layers, which helps increase the model capacity while keeping the active number of parameters, and thus the compute, small. The MoE module may be applied to MLPs every e layers. When using MoE, there are n possible experts per layer, with a router choosing the top K experts at each token. In summary, the different degrees of freedom in the Jamba architecture are:
• l: The number of layers.
• a : m: the ratio of attention layers to Mamba layers.
• e: how often to use MoE instead of a single MLP.
• n: total number of experts per layer.
• K: number of top experts used at each token.
Given this design space, Jamba provides flexibility in preferring certain properties over others. For example, increasing m and decreasing a, that is, increasing the ratio of Mamba layers at the expense of attention layers, reduces the required memory for storing the key-value cache. This reduces the overall memory footprint, which is especially important for processing long sequences. Increasing the ratio of Mamba layers also improves throughput, especially at long sequences. However, decreasing a might lower the model’s capabilities.
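As a minimal sketch of how these degrees of freedom determine a block's layout, the helper below enumerates which layers use an attention vs. a Mamba module and which MLPs are replaced by MoE. The function name, the exact placement of the attention and MoE layers within the block, and the example values for l and a : m are illustrative assumptions, not taken from the paper.

```python
def jamba_block_layout(l: int, a: int, m: int, e: int) -> list[str]:
    """List the layer types in one block: each layer is (attention|mamba) + (mlp|moe).

    One in every (a + m) layers uses attention, and every e-th MLP is replaced by
    an MoE module. The exact placement chosen here is an illustrative convention.
    """
    assert l % (a + m) == 0, "l should be a multiple of a + m"
    layout = []
    for i in range(l):
        mixer = "attention" if i % (a + m) == (a + m) - 1 else "mamba"
        mlp = "moe" if (i + 1) % e == 0 else "mlp"
        layout.append(f"{mixer}+{mlp}")
    return layout

# Illustrative example: an 8-layer block with a 1:7 attention-to-Mamba ratio
# and MoE replacing the MLP in every other layer (e = 2, as described above).
print(jamba_block_layout(l=8, a=1, m=7, e=2))
```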
Additionally, balancing n, K, and e affects the relationship between active parameters and total available parameters. A larger n increases the model capacity at the expense of memory footprint, while a larger K increases the active parameter usage and the compute requirement. In contrast, a larger e decreases the model capacity, while decreasing both compute (when K>1) and memory requirements, and allowing for fewer communication dependencies (decreasing memory transfers as well as inter-GPU communication during expert-parallel training and inference).
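The total-vs-active distinction can be made concrete with a small hedged calculation over the expert MLPs alone (the helper and the per-expert parameter count below are purely illustrative, not Jamba's figures):

```python
def mlp_params(l: int, e: int, n: int, k: int, params_per_expert: int) -> tuple[int, int]:
    """Total vs. active parameters contributed by the MLP/MoE sub-layers of a block.

    l // e layers use MoE with n experts each (a plain MLP counts as one expert);
    only the top-k experts of an MoE layer are active for any given token.
    """
    moe_layers = l // e
    plain_layers = l - moe_layers
    total = plain_layers * params_per_expert + moe_layers * n * params_per_expert
    active = plain_layers * params_per_expert + moe_layers * k * params_per_expert
    return total, active

# Illustrative only: an 8-layer block, MoE every other layer, 16 experts, top-2 routing,
# with a made-up 100M parameters per expert MLP.
total, active = mlp_params(l=8, e=2, n=16, k=2, params_per_expert=100_000_000)
print(f"total MLP params: {total / 1e9:.1f}B, active MLP params: {active / 1e9:.1f}B")
```

In this toy setting, raising n grows only the total, raising K grows the active share, and raising e shrinks both (for K > 1).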
Jamba’s implementation of Mamba layers incorporates several normalizations that help stabilize training at large model scales. In particular, we apply RMSNorm [48] in the Mamba layers.
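RMSNorm itself is the standard formulation from [48]; for reference, a minimal PyTorch-style implementation is sketched below (this is generic RMSNorm, not Jamba's code, and the exact placement inside the Mamba layer is not shown here):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: rescale x by 1/RMS(x), then by a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms_inv = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms_inv * self.weight
```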
We found that with the Mamba layer, positional embeddings or mechanisms like RoPE [42] are not necessary, and so we do not use any explicit positional information.
Other architecture details are standard, including grouped-query attention (GQA), SwiGLU activation function [6, 40, 45], and load balancing for the MoE [13]. The vocabulary size is 64K. The tokenizer is trained with BPE [15, 29, 39] and each digit is a separate token [6]. We also remove the dummy space used in Llama and Mistral tokenizers for more consistent and reversible tokenization.
This paper is available on arXiv under the CC BY-SA 4.0 DEED license.
[2] The figure shows a potential Attention MoE layer, which our architecture does not use, but future variants could.