Right now, AI is eating the world.
And by AI, I mean Transformers. Practically all the big breakthroughs in AI over the last few years are due to Transformers.
Mamba, however, is one of an alternative class of models called State Space Models (SSMs). Importantly, for the first time, Mamba promises similar performance (and, crucially, similar scaling laws) to the Transformer whilst being feasible at long sequence lengths (say 1 million tokens). It achieves this long context by removing the “quadratic bottleneck” in the Attention Mechanism. Mamba also runs fast - like “up to 5x faster than Transformer fast”^{1}.
Gu and Dao, the Mamba authors, write:
Mamba enjoys fast inference and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.
Here we’ll discuss:
We’re very much in the Transformer-era of history. ML used to be about detecting cats and dogs. Now, with Transformers, we’re generating human-like poetry, coding better than the median competitive programmer, and solving the protein folding problem.
But Transformers have one core problem. In a transformer, every token can look back at every previous token when making predictions. For this lookback, we cache detailed information about each token in the so-called KV cache.
This pairwise communication means a forward pass is O(n²) time complexity in training (the dreaded quadratic bottleneck) and each new token generated autoregressively takes O(n) time. That is to say, as the context gets larger, the model gets slower.
To add insult to injury, storing this KV cache requires O(n) space. The fateful CUDA OOM error looms large as the memory footprint balloons. If space were the only issue, we might just add more GPUs, but with latency growing quadratically… perhaps not.
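To make these growth rates concrete, here is a back-of-the-envelope sketch. The model shape is hypothetical: roughly a 7B-parameter decoder with 32 layers and 32 KV heads of size 128, stored in fp16. It also ignores tricks like grouped-query attention that shrink the cache. The function names are illustrative only.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence.

    The leading factor of 2 is for storing both a key and a value vector
    per token, per layer, per KV head. bytes_per_value=2 assumes fp16/bf16.
    """
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value


def causal_attention_scores(seq_len):
    """Query-key dot products per head in one causal forward pass.

    Token t attends to tokens 0..t, so the total is 1 + 2 + ... + n = n(n+1)/2.
    """
    return seq_len * (seq_len + 1) // 2


# Hypothetical 7B-style shape: 32 layers, 32 KV heads of size 128, fp16.
for n in (1_000, 100_000, 1_000_000):
    gib = kv_cache_bytes(32, 32, 128, n) / 2**30
    print(f"{n:>9,} tokens: KV cache {gib:7.1f} GiB, "
          f"{causal_attention_scores(n):.1e} scores per head per layer")
```

Doubling the context doubles the cache, but it roughly quadruples the attention work.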
On the margin, we can mitigate the quadratic bottleneck with techniques like Sliding Window Attention or clever CUDA optimisations like FlashAttention. But ultimately, for super long context windows (like a chatbot which remembers every conversation you’ve shared), we need a different approach.
Fundamentally, all good ML architecture backbones have components for two important operations:
In transformers, this is Attention (communication) and MLPs (computation). We improve transformers by optimising these two operations^{2}.
We would like to replace the Attention component^{3} with some other method for communicating between tokens. Mamba uses the Control Theory-inspired SSM for Communication and keeps MLP-style projections for Computation.
Like a Transformer made up of stacked transformer blocks, Mamba is made up of stacked Mamba blocks as above.
We would like to understand and motivate the choice of the SSM for sequence transformations.
Imagine we’re building a Temple Run agent ^{4}. It chooses if the runner should move left or right at any time.
To successfully pick the correct direction, we need information about our surroundings. Let’s call the collection of relevant information the state.
Here the state likely includes your current position and velocity, the position of the nearest obstacle, weather conditions, etc.
Claim 1: if you know the current state of the world and how the world is evolving, then you can use this to determine the direction to move.
Note that you don’t need to look at the whole screen all the time. You can figure out what will happen to most of the screen by noting that as you run, the obstacles move down the screen. You only need to look at the top of the screen to understand the new information and then simulate the rest.
This lends itself to a natural formulation. Let h be the hidden state, relevant knowledge about the world. Also let x be the input, the observation that you get each time. h’ then represents the derivative of the hidden state, i.e. how the state is evolving. We’re trying to predict y, the optimal next move (right or left).
Now, Claim 1 states that from the hidden state h, h’, and the new observation x, you can figure out y.
More concretely, the state h evolves according to a differential equation (Eq 1a):
\[h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)\]
Knowing h allows you to determine your next move y (Eq 1b):
\[y(t) = \mathbf{C}h(t) + \mathbf{D}x(t)\]
The system evolves as a function of the current state and new observations. A small new observation is enough because we can determine most of the state by applying known state dynamics to the previous state. That is, most of the screen isn’t new, it’s just the natural downward movement of the previous state. Fully knowing the state would allow us to pick the best next move, y.
You can learn a lot about the system dynamics by observing the top of the screen - if it’s moving faster, we can infer that the whole screen is moving faster too and the game is speeding up^{5}. In this way, even if we start off knowing nothing about the game except our limited observation, pretty soon we could understand the whole screen.
Here, state refers to the variables that, when combined with the input variables, fully determine the future system behaviour. In theory, once we have the state, there’s nothing else we need to know about the past to predict the future. With this choice of state, the system is converted to a Markov Decision Process. Ideally, the state is a fairly small amount of information which captures the essential properties of the system. That is, the state is a compression of the past^{6}.
Okay, great! So, given some state and input observation, we have an autoregressive-style system to determine the next action. Amazing!
In practice though, there’s a little snag here. We’re modelling time as continuous. But in real life, we get new inputs and take new actions at discrete time steps ^{7}.
We would like to convert this continuous-time differential equation into a discrete-time difference equation. This conversion process is known as discretisation. Discretisation is a well-studied problem in the literature.
Mamba uses the Zero-Order Hold (ZOH) discretisation^{8}. To give an idea of what’s happening morally, consider a naive first-order approximation^{9}.
From Equation 1a, we have
\[h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)\]
And for small ∆,
\[h'(t) \approx \frac{h(t+\Delta) - h(t)}{\Delta}\]
by the definition of the derivative.
We let:
\[h_t = h(t)\]
and
\[h_{t+1} = h(t + \Delta)\]
and substitute into Equation 1a giving:
\[h_{t+1} - h_t \approx \Delta (\mathbf{A}h_t + \mathbf{B}x_t)\]
\[\Rightarrow h_{t+1} \approx (I + \Delta \mathbf{A})h_t + (\Delta \mathbf{B})x_t\]
Hence, after renaming the coefficients and relabelling indices, we have the discrete representations:
\[h_t = \bar{\mathbf{A}}h_{t-1} + \bar{\mathbf{B}}x_t\]
\[y_t = \mathbf{C}h_t + \mathbf{D}x_t\]
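The first-order derivation above is the Euler scheme: \(\bar{\mathbf{A}} = I + \Delta\mathbf{A}\) and \(\bar{\mathbf{B}} = \Delta\mathbf{B}\). The sketch below compares it with the ZOH formulas given in the Mamba paper, \(\bar{\mathbf{A}} = \exp(\Delta\mathbf{A})\) and \(\bar{\mathbf{B}} = (\Delta\mathbf{A})^{-1}(\exp(\Delta\mathbf{A}) - I)\cdot\Delta\mathbf{B}\). It works on a single channel of a diagonal \(\mathbf{A}\), so every matrix reduces to a scalar. The function names and example values are hypothetical.

```python
import math


def discretise_euler(a, b, delta):
    """First-order (Euler) discretisation of h' = a*h + b*x for one state channel."""
    return 1.0 + delta * a, delta * b


def discretise_zoh(a, b, delta):
    """Zero-Order Hold discretisation for one channel of a diagonal A.

    A_bar = exp(delta*a) and B_bar = (exp(delta*a) - 1) / a * b, the scalar
    form of (delta A)^{-1} (exp(delta A) - I) delta B. Requires a != 0.
    """
    a_bar = math.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar


# For small delta the two schemes agree to first order; for large delta they diverge.
a, b = -0.5, 1.0
for delta in (0.001, 0.1, 1.0):
    print(delta, discretise_euler(a, b, delta), discretise_zoh(a, b, delta))
```

One nice property of ZOH is that, for negative `a`, \(\bar{\mathbf{A}} = e^{\Delta a}\) always lies strictly between 0 and 1, whatever the step size. Euler's \(1 + \Delta a\) can go negative once ∆ is large.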
If you’ve ever looked at an RNN before ^{10} and this feels familiar - trust your instincts:
We have some input x, which is combined with the previous hidden state by some transform to give the new hidden state. Then we use the hidden state to calculate the output at each time step.
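Here is that RNN-style loop as a minimal sketch. It computes \(h_t = \bar{\mathbf{A}}h_{t-1} + \bar{\mathbf{B}}x_t\) and \(y_t = \mathbf{C}h_t + \mathbf{D}x_t\), assuming a single input channel and a diagonal \(\bar{\mathbf{A}}\) (so each state channel updates independently). The function name `ssm_scan` is just an illustrative label.

```python
def ssm_scan(a_bar, b_bar, c, d, xs):
    """Run a discretised SSM with a diagonal state matrix over a 1-D input.

    a_bar, b_bar, c: lists of length N (one entry per state channel).
    d: scalar skip connection. xs: list of scalar inputs x_0..x_{L-1}.
    Implements h_t = A_bar h_{t-1} + B_bar x_t and y_t = C h_t + D x_t.
    """
    h = [0.0] * len(a_bar)  # initial hidden state is all zeros
    ys = []
    for x in xs:
        # New state: decay the old state, then write in the current input.
        h = [ai * hi + bi * x for ai, hi, bi in zip(a_bar, h, b_bar)]
        # Output: read the state out through C, plus the D skip connection.
        ys.append(sum(ci * hi for ci, hi in zip(c, h)) + d * x)
    return ys


print(ssm_scan([0.5], [1.0], [1.0], 0.0, [1.0, 0.0, 0.0]))
```

With \(\bar{\mathbf{A}} = 0.5\), a single impulse fades by half at every step. That fading is the "forgetting" the A matrix controls.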
Now, we can interpret the A, B, C, D matrices more intuitively:
Additionally, ∆ has a nice interpretation - it’s the step size, or what we might call the linger time or the dwell time. For large ∆, you focus more on that token; for small ∆, you skip past the token immediately and don’t include it much in the next state.
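To see the dwell-time reading numerically, here is a one-channel sketch using the ZOH formulas with an assumed \(\mathbf{A} = -1\) and \(\mathbf{B} = 1\). With these values the update becomes \(e^{-\Delta}h + (1 - e^{-\Delta})x\), a weighted average of the old state and the new token. So ∆ directly sets how much the token overwrites the state. The helper name and numbers are hypothetical.

```python
import math


def zoh_update(h, x, delta, a=-1.0, b=1.0):
    """One Zero-Order Hold step for a single state channel.

    With a = -1 and b = 1 this is exp(-delta)*h + (1 - exp(-delta))*x,
    a weighted average of the old state and the new input.
    """
    a_bar = math.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar * h + b_bar * x


old_state, token = 5.0, -3.0
for delta in (0.01, 1.0, 10.0):
    print(f"delta={delta:>5}: new state = {zoh_update(old_state, token, delta):.3f}")
```

A small ∆ leaves the state almost untouched, so the token is skipped. A large ∆ replaces the state with the new input, so the token gets all the focus.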
And that’s it! That’s the SSM, our ~drop-in replacement for Attention (Communication) in the Mamba block. The Computation in the Mamba architecture comes from regular linear projections, non-linearities, and local convolutions - the regular ML building blocks we know and love!
Okay great, that’s the theory - but does this work? Well…
At WWDC ’97, Steve Jobs famously noted that “focusing is about saying no”. Focus is ruthless prioritisation. It’s common to think about Attention positively as choosing what to notice. In the Steve Jobs sense, we might instead frame Attention negatively as choosing what to discard.
There’s a classic intuition pump in Machine Learning known as the Cocktail Party Problem ^{13}. Imagine a party with dozens of simultaneous loud conversations:
Question:
How do we recognise what one person is saying when others are talking at the same time? ^{14}
Answer:
The brain solves this problem by focusing your “attention” on a particular stimulus and hence drowning out all other sounds as much as possible.
Transformers use Dot-Product Attention to focus in on the most relevant tokens. A big reason Attention is so great is that you have the potential to look back at everything that ever happened in its context. This is like photographic memory when done right. ^{15}
Transformers (🤖) are extremely effective. But they aren’t very efficient. They store everything from the past so that they can look back at tokens with theoretically perfect recall.
Traditional RNNs (🔁) are the opposite - they forget a lot, only recalling a small amount in their hidden state and discarding the rest. They are very efficient - their state is small. Yet they are less effective as discarded information cannot be recovered.
We’d like something closer to the Pareto frontier of the effectiveness/efficiency tradeoff. Something that’s more effective than traditional RNNs and more efficient than transformers.
SSMs are as efficient as RNNs, but we might wonder how effective they are. After all, it seems like they would have a hard time discarding only unnecessary information and keeping everything relevant. If each token is being processed the same way, applying the same A and B matrices as if in a factory assembly line for tokens, there is no context-dependence. We would like the forgetting and remembering matrices (A and B respectively) to vary and dynamically adapt to inputs.
Selectivity allows each token to be transformed into the state in a way that is unique to its own needs. Selectivity is what takes us from vanilla SSM models (applying the same A (forgetting) and B (remembering) matrices to every input) to Mamba, the Selective State Space Model.
In regular SSMs, A, B, C and D are learned matrices - that is \(\mathbf{A} = \mathbf{A}_{\theta}\) etc. (where θ represents the learned parameters).
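With the Selection Mechanism in Mamba, the SSM parameters also become functions of x. Strictly, A itself stays a learned parameter, while B, C and the step size ∆ are computed from the input. Since the discretised \(\bar{\mathbf{A}} = \exp(\Delta\mathbf{A})\) depends on ∆, it becomes input-dependent too. That is, \(\mathbf{B} = \mathbf{B}_{\theta(x)}\) etc.; the matrices are context dependent rather than static.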
Making A and B functions of x allows us to get the best of both worlds:
The Mamba paper authors write:
The efficiency vs. effectiveness tradeoff of sequence models is characterized by how well they compress their state: efficient models must have a small state, while effective models must have a state that contains all necessary information from the context. In turn, we propose that a fundamental principle for building sequence models is selectivity: or the context-aware ability to focus on or filter out inputs into a sequential state. In particular, a selection mechanism controls how information propagates or interacts along the sequence dimension.
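To make selectivity concrete, here is a toy selective scan in plain Python. It follows the recipe described in the Mamba paper as I understand it: A is fixed, B, C and ∆ are computed from each input, and ∆ passes through a softplus to stay positive. Several pieces are simplifications. The scalar projections `w_b`, `w_c`, `w_delta` and `bias_delta` are hypothetical stand-ins for learned linear layers, and \(\bar{\mathbf{B}}\) uses the simpler Euler-style \(\Delta\mathbf{B}\) rather than the full ZOH formula.

```python
import math


def softplus(z):
    """Numerically stable softplus, log(1 + e^z); keeps the step size positive."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_scan(xs, a, w_b, w_c, w_delta, bias_delta=0.0):
    """Toy selective SSM for a 1-D input sequence with N diagonal state channels.

    a: list of N negative decay rates (input-independent).
    w_b, w_c: lists of N weights; B_t = w_b * x_t and C_t = w_c * x_t.
    w_delta, bias_delta: scalars; delta_t = softplus(w_delta * x_t + bias_delta).
    All projections here are hypothetical stand-ins for learned linear layers.
    """
    h = [0.0] * len(a)
    ys = []
    for x in xs:
        delta = softplus(w_delta * x + bias_delta)    # input-dependent step size
        b_t = [wb * x for wb in w_b]                  # input-dependent B
        c_t = [wc * x for wc in w_c]                  # input-dependent C
        a_bar = [math.exp(delta * ai) for ai in a]    # ZOH: A_bar = exp(delta * A)
        b_bar = [delta * bi for bi in b_t]            # simplified: B_bar = delta * B
        h = [ab * hi + bb * x for ab, hi, bb in zip(a_bar, h, b_bar)]
        ys.append(sum(ci * hi for ci, hi in zip(c_t, h)))
    return ys


print(selective_scan([1.0, 0.0, 2.0], a=[-1.0, -0.5], w_b=[1.0, 0.5],
                     w_c=[1.0, 1.0], w_delta=1.0))
```

Unlike the fixed-matrix scan, every token now gets its own \(\bar{\mathbf{A}}\), \(\bar{\mathbf{B}}\) and C. That is exactly what lets the model decide, per token, how much to remember or forget.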
Humans (mostly) don’t have photographic memory for everything they experience within a lifetime - or even within a day! There’s just way too much information to retain it all. Subconsciously, we select what to remember by choosing to forget, throwing away most information as we encounter it. Transformers (🤖) decide what to focus on at recall time. Humans (🧑) also decide what to throw away at memory-making time. Humans filter out information early and often.
If we had infinite capacity for memorisation, it’s clear the transformer approach is better than the human approach - it truly is more effective. But it’s less efficient - transformers have to store so much information about the past that might not be relevant. Transformers (🤖) only decide what’s relevant at recall time. The innovation of Mamba (🐍) is allowing the model better ways of forgetting earlier - it’s focusing by choosing what to discard using Selectivity, throwing away less relevant information at memory-making time^{16}.
Applying the Selection Mechanism does have its gotchas though. Non-selective SSMs (i.e. A and B not dependent on x) are fast to compute in training. This is because the component of \(y_t\) which depends on \(x_i\) can be expressed as a linear map, i.e. a single matrix that can be precomputed!
For example (ignoring the D component, the skip connection):
\[y_2 = \mathbf{C}\mathbf{B}x_2 + \mathbf{C}\mathbf{A}\mathbf{B}x_1 + \mathbf{C}\mathbf{A}\mathbf{A}\mathbf{B}x_0\]
If we’re paying attention, we might spot something even better here - this expression can be written as a convolution. Hence we can apply the Fast Fourier Transform and the Convolution Theorem to compute this very efficiently on hardware as in Equation 3 below.
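The sketch below checks the convolutional view numerically for a non-selective SSM with a diagonal A and made-up parameters. The kernel is \(K_k = \mathbf{C}\mathbf{A}^k\mathbf{B}\), so \(y_t = \sum_k K_k x_{t-k}\), which matches the \(y_2\) expansion above term by term. For readability it does the convolution with a plain quadratic loop. A real implementation would use an FFT to get \(O(L \log L)\).

```python
def ssm_kernel(a_bar, b_bar, c, length):
    """Kernel K_k = C A_bar^k B_bar for k = 0..length-1 (diagonal A_bar)."""
    return [
        sum(ci * (ai ** k) * bi for ai, bi, ci in zip(a_bar, b_bar, c))
        for k in range(length)
    ]


def ssm_conv(a_bar, b_bar, c, xs):
    """Compute y_t = sum_k K_k x_{t-k} directly (skip connection D ignored)."""
    k = ssm_kernel(a_bar, b_bar, c, len(xs))
    return [sum(k[j] * xs[t - j] for j in range(t + 1)) for t in range(len(xs))]


def ssm_recurrent(a_bar, b_bar, c, xs):
    """The same SSM computed step by step: h_t = A h_{t-1} + B x_t, y_t = C h_t."""
    h = [0.0] * len(a_bar)
    ys = []
    for x in xs:
        h = [ai * hi + bi * x for ai, hi, bi in zip(a_bar, h, b_bar)]
        ys.append(sum(ci * hi for ci, hi in zip(c, h)))
    return ys


params = ([0.9, 0.5], [1.0, 0.3], [0.2, 1.5])   # made-up A_bar, B_bar, C
xs = [1.0, 2.0, -1.0, 0.5]
print(ssm_conv(*params, xs))
print(ssm_recurrent(*params, xs))
```

The two printed lists agree up to floating-point rounding. The trick only works because the kernel is the same for every position. Once B and A depend on x (selectivity), there is no single kernel to precompute.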
Unfortunately, with the Selection Mechanism, we lose the convolutional form. Much attention is given to making Mamba efficient on modern GPU hardware using similar hardware optimisation tricks to Tri Dao’s Flash Attention ^{17}. With the hardware optimisations, Mamba is able to run faster than comparably sized Transformers.
The Mamba authors write, “the efficiency vs. effectiveness tradeoff of sequence models is characterised by how well they compress their state”. In other words, like in political economy^{18}, the fundamental problem is how to manage the state.
🔁 Traditional RNNs are anarchic
They have a small, minimal state. The size of the state is bounded. The compression of state is poor.
🤖 Transformers are communist
They have a maximally large state. The “state” is just a cache of the entire history with no compression. Every context token is treated equally until recall time.
🐍Mamba has a compressed state
…but it’s selective about what goes in. Mamba says we can get away with a small state if the state is well focused and effective ^{19}.
The upshot is state representation is critical. A smaller state is more efficient; a larger state is more effective. The key is to selectively and dynamically compress data into the state. Mamba’s Selection Mechanism allows for context-dependent reasoning, focusing and ignoring. For both performance and interpretability, understanding the state seems to be very useful.
How do Transformers know anything? At initialisation, a transformer isn’t very smart. It learns in two ways:
Models learn from their training data. This is a kind of lossy compression of input data into the weights. We can think of the effect of pretraining data on the transformer kinda like the effect of your ancestors’ experiences on your genetics - you can’t recall their experiences, you just have vague instincts about them ^{20}.
Transformers use their context as short-term memory, which they can recall with ~perfect fidelity. So we get In-Context Learning, e.g. using induction heads to solve the Indirect Object Identification task, or computing Linear Regression.
Note that Transformers don’t filter their context at all until recall time. So if we have a bunch of information we think might be useful to the Transformer, we filter it outside the Transformer (using Information Retrieval strategies) and then stuff the results into the prompt. This process is known as Retrieval Augmented Generation (RAG). RAG determines relevant information for the context window of a transformer. A human with the internet is kinda like a RAG system - you still have to know what to search but whatever you retrieve is as salient as short-term memory to you.
Training Data acts similarly for Mamba. However, the lines are slightly blurred for in-context data and retrieval. In-context data for Mamba is compressed/filtered similar to retrieval data for transformers. This in-context data is also accessible for look-up like for transformers (although with somewhat lower fidelity).
Transformer context is to Mamba states what short-term is to long-term memory. Mamba doesn’t just have “RAM”, it has a hard drive^{21} ^{22}.
Currently, we often use RAG to give a transformer contextual information.
With Mamba-like models, you could instead imagine having a library of states created by running the model over specialised data. States could be shared kinda like LoRAs for image models.
For example, I could do inference on 20 physics textbooks and, say, 100 physics questions and answers. Then I have a state which I can give to you. Now you don’t need to add any few-shot examples; you just simply ask your question. The in-context learning is in the state.
In other words, you can drag and drop downloaded states into your model, like literal plug-in cartridges. And note that “training” a state doesn’t require any backprop. It’s more like a highly specialised one-pass fixed-size compression algorithm. This is unlimited in-context learning applied at inference time with no extra compute or latency cost.^{23}
The structure of an effective LLM call goes from…
…for Transformers, to simply…
…for Mamba.
This is cheaper and faster than few-shot prompting (as the state is infinitely reusable without inference cost). It’s also MUCH cheaper than finetuning and doesn’t require any gradient updates. We could imagine retrieving states in addition to context.
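Why can a state be shipped around like a cartridge? Because the recurrence only ever looks at the current hidden state, resuming from a saved state gives exactly the same outputs as re-reading the context. The toy sketch below checks this for a small linear SSM with made-up parameters. A real Mamba model would need to save the state of every layer (and, I believe, each layer's short convolution buffer too). So treat this as an illustration of the principle rather than a recipe.

```python
def run_ssm(a_bar, b_bar, c, xs, h0=None):
    """Diagonal SSM that also returns its final hidden state.

    Passing the returned state back in as h0 resumes the computation exactly
    where it left off, which is what makes a saved state reusable.
    """
    h = list(h0) if h0 is not None else [0.0] * len(a_bar)
    ys = []
    for x in xs:
        h = [ai * hi + bi * x for ai, hi, bi in zip(a_bar, h, b_bar)]
        ys.append(sum(ci * hi for ci, hi in zip(c, h)))
    return ys, h


params = ([0.9, 0.5], [1.0, 0.3], [0.2, 1.5])   # toy, made-up parameters
context = [1.0, 2.0, -1.0]                       # stands in for the "physics textbooks"
question = [0.5, 0.25]

_, saved_state = run_ssm(*params, context)                # pay for the context once
answer, _ = run_ssm(*params, question, h0=saved_state)    # reuse the saved state
full, _ = run_ssm(*params, context + question)            # recompute from scratch
print(answer, full[len(context):])
```

The cost of reading the context is paid once. Every later question starts from the fixed-size `saved_state`, however long the original context was.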
Transformer interpretability typically involves:
Most of the ablations that we would like to do for Mamba are still valid, but understanding token communication (1) is now more nuanced. All information moves between tokens via hidden states instead of the Attention Mechanism which can “teleport” information from one sequence position to another.
For understanding in-context learning (ICL) tasks with Mamba, we will look to intervene on the SSM state. A classic in-context learning task is Indirect Object Identification, in which a model has to finish a paragraph like:
Then, Shelby and Emma had a lot of fun at the school. [Shelby/Emma] gave an apple to [BLANK]
The model is expected to fill in the blank with the name that is not repeated in the paragraph. In the chart below we can see that information is passed from the [Shelby/Emma] position to the final position via the hidden state (see the two blue lines in the top chart).
Since it’s hypothesised that much of In-Context Learning in Transformers is downstream of more primitive sequence position operations (like Induction Heads), Mamba being able to complete this task suggests a more general In-Context Learning ability.
Mamba-like models are likely to excel in scenarios requiring extremely long context and long-term memory. Examples include:
An illustrative example is agents with long-term goals.
Suppose you have an agent interacting with the world. Eventually, its experiences become too much for the context window of a transformer. The agent then has to compress or summarise its experiences into some more compact representation.
But how do you decide what information is the most useful as a summary? If the task is language, LLMs are actually fairly good at summaries - okay, yeah, you’ll lose some information, but the most important stuff can be retained.
However, for other disciplines, it might not be clear how to summarise. For example, what’s the best way to summarise a 2-hour movie?^{24} Could the model itself learn to do this naturally, rather than relying on a hacky workaround like trying to describe the aesthetics of the movie in text?
This is what Mamba allows. Actual long-term memory. A real state where the model learns to keep what’s important. Prediction is compression - learning what’s useful to predict what’s coming next inevitably leads to building a useful compression of the previous tokens.
The implications for Assistants are clear:
Your chatbot co-evolves with you. It remembers.
One reason for positive updates in existential risk from AGI is Language Models. Previously, Deep-RL agents trained via self-play looked set to be the first AGIs. Language models are inherently much safer since they aren’t trained with long-term goals. ^{25}
The potential for long-term sequence reasoning here brings back the importance of agent-based AI safety. Few agent worries are relevant to Transformers with an 8k context window. Many are relevant to systems with impressive long-term memories and possible instrumental goals.
The Mamba authors show that there’s value in combining Mamba’s long context with the Transformer’s high fidelity over short sequences. For example, if you’re making long videos, you likely can’t fit a whole movie into a Transformer’s context for attention ^{26}. You could imagine having Attention look at the most recent frames for short-term fluidity and an SSM for long-term narrative consistency ^{27}.
This isn’t the end for Transformers. Their high effectiveness is exactly what’s needed for many tasks. But now Transformers aren’t the only option. Other architectures are genuinely feasible.
So we’re not in the post-Transformer era. But for the first time, we’re living in the post-only-Transformers era^{28}. And this blows the possibilities wide open for sequence modelling with extreme context lengths and native long-term memory.
Two ML researchers, Sasha Rush (HuggingFace, Annotated Transformer, Cornell Professor) and Jonathan Frankle (Lottery Ticket Hypothesis, MosaicML, Harvard Professor), currently have a bet here.
Currently Transformers are far and away in the lead. With 3 years left, there’s now a research direction with a fighting chance.
All that remains to ask is: Is Attention All We Need?
Join the discussion on Hacker News here
Thanks to Gonçalo for reading an early draft, Jaden for the nnsight library used for the Interpretability analysis and Tessa for Mamba patching visualisations.
Also see: Mamba paper, Mamba Python code, Annotated S4, Nathan Labenz podcast
See Figure 8 in the Mamba paper. ↩
And scaling up with massive compute. ↩
More specifically the scaled dot-product Attention popularised by Transformers ↩
For people who don’t see Temple Run as the cultural cornerstone it is 🤣 Temple Run was an iPhone game from 2011 similar to Subway Surfers. ↩
Here we assume the environment is sufficiently smooth. ↩
One pretty important constraint for this to be efficient is that we don’t allow the individual elements of the state vector to interact with each other directly. We’ll use a combination of the state dimensions to determine the output but we don’t e.g. allow the velocity of the runner and the direction of the closest obstacle (or whatever else was in our state) to directly interact. This helps with efficient computation and we achieve this practically by constraining A to be a diagonal matrix. ↩
Concretely consider the case of Language Models - each token is a discrete step ↩
ZOH also has nice properties for the initialisations - we want A_bar to be close to the identity so that the state can be mostly maintained from timestep to timestep if desired. ZOH gives A_bar as an exponential so any diagonal element initialisations close to zero give values close to 1 ↩
This is known as the Euler discretisation in the literature ↩
It’s wild to note that some readers might not have; we’re so far into the age of Attention that RNNs have been forgotten! ↩
B is like the Query (Q) matrix for Transformers. ↩
C is like the Output (O) matrix for Transformers. ↩
Non-alcoholic options also available! ↩
Especially as all voices roughly occupy the same space on the audio frequency spectrum. Intuitively this seems really hard! ↩
Note that photographic memory doesn’t necessarily imply perfect inferences from that memory! ↩
To be clear, if you have a short sequence, then a transformer should theoretically be a better approach. If you can store the whole context, then why not!? If you have enough memory for a high-resolution image, why compress it into a JPEG? But Mamba-style architectures are likely to hugely outperform with long-range sequences. ↩
More details are available for engineers interested in CUDA programming - Tri’s talk, Mamba paper section 3.3.2, and the official CUDA code are good resources for understanding the Hardware-Aware Scan ↩
or in Object Oriented Programming ↩
Implications to actual Political Economy are left to the reader but maybe Gu and Dao accidentally solved politics!? ↩
This isn’t a perfect analogy as human evolution follows a genetic algorithm rather than SGD. ↩
Albeit a pretty weird hard drive at that - it morphs over time rather than being a fixed representation. ↩
As a backronym, I’ve started calling the hidden_state the state space dimension (or selective state dimension) which shortens to SSD, a nice reminder for what this object represents - the long-term memory of the system. ↩
I’m thinking about this similarly to the relationship between harmlessness finetuning and activation steering. State swapping, like activation steering, is an inference time intervention giving comparable results to its train time analogue. ↩
This is a very non-trivial problem! How do human brains represent a movie internally? It’s not a series of the most salient frames, nor is it a text summary of the colours, nor is it a purely vibes-based summary if you can memorise some lines of the film. ↩
They’re also safer since they inherently understand (though don’t necessarily embody) human values. It’s not at all clear how to teach an RL agent human morality. ↩
Note that typically an image (i.e. a single frame) counts as >196 tokens, and movies are typically 24 fps so you’ll fill a 32k context window in 7 seconds 🤯 ↩
Another possibility that I’m excited about is applying optimisation pressure to the state itself as well as the output to have models that respect particular use cases. ↩
This is slightly hyperbolic, the TS-Mixer for time series, Gradient Boosting Trees for tabular data and Graph Neural Networks for weather prediction exist and are currently used, but these aren’t at the core of AI ↩
Since the infamous BitTorrent link launch of Mixtral, Mistral’s Mixture of Expert (MoE) model, there’s been renewed attention^{1} paid to MoE models.
This week, Mistral released the paper accompanying the model. This feels like a great time to dig into the details of the Mixtral model and the impact that it’s having on the MoE and LLM communities so far.
We discussed the intuition behind MoE models in An Analogy for Understanding Mixture of Expert Models:
In Sparse Mixture of Experts (MoEs), we swap out the MLP layers of the vanilla transformer for an Expert Layer. The Expert Layer is made up of multiple MLPs called “Experts”. For each input we select one expert to send that input to. In this way, each token has different parameters applied to it. A dynamic routing mechanism decides how to map tokens to Experts.
This approach gives models more parameters ^{2} without requiring more compute or latency for each forward pass. MoE models also typically have better sample efficiency - that is, their performance improves much faster than dense transformers in training, when given the same amount of compute. This isn’t quite a free lunch because it requires more memory to store the model for inference, but, if you have enough memory, it’s pretty great.
Mixtral 8x7B has the backbone of Mistral-7B (their previous model). As with Mistral-7B, Mixtral uses Grouped-Query Attention and Sliding Window Attention. The main changes are a 32k context window out of the box and replacing the Feed Forward Networks (FFNs) with Mixture of Expert (MoE) layers.
Mixtral opts for an MoE layer with 8 FFN experts which are sparsely activated by choosing the top 2 at each layer.^{3}
Having 8 experts means that where the original Mistral had a single FFN per transformer block, Mixtral has 8 separate FFNs.^{4}
Rather than each token being processed by all the parameters, a routing network dynamically selects the top 2 experts for each token depending on the content of the token itself. Hence, though the total parameter count is 47B, the “active” parameter count (i.e. the number of parameters used for each forward pass) comes in at 13B.^{5}
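Where do 47B and 13B come from? Below is a rough count. It assumes the hyperparameters listed in the Mixtral paper's architecture table: model dimension 4096, 32 layers, 32 query heads and 8 KV heads of size 128, expert hidden size 14336 and a 32000-token vocabulary. It also assumes Mistral-style three-matrix (SwiGLU) expert FFNs and ignores the tiny normalisation weights. The function is a hypothetical helper, not part of any library.

```python
def mixtral_param_counts(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8,
                         head_dim=128, hidden_dim=14336, vocab=32000,
                         n_experts=8, top_k=2):
    """Rough total vs. active parameter counts for a Mixtral-style model."""
    attn = 2 * dim * n_heads * head_dim + 2 * dim * n_kv_heads * head_dim  # Wq, Wo, Wk, Wv
    expert = 3 * dim * hidden_dim      # gate, up and down projections of one expert FFN
    router = dim * n_experts           # gating network: one logit per expert
    embed = 2 * vocab * dim            # input embedding + output head
    total = n_layers * (attn + router + n_experts * expert) + embed
    active = n_layers * (attn + router + top_k * expert) + embed
    return total, active


total, active = mixtral_param_counts()
print(f"total ≈ {total / 1e9:.1f}B, active ≈ {active / 1e9:.1f}B")
```

The experts hold almost all the weights. Attention and embeddings are shared by every token, which is why the active count is a bit more than 2/8 of the total.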
Succinctly, an MoE layer is given as:
\[\sum_{i=0}^{n-1}G(x)_i \cdot E_i(x),\]
where G is a gating function which is 0 everywhere except at 2 indices and where each \(E_i\) is a single expert FFN. Note that in the above formula, since most of the terms of the sum are zero (as \(G(x)_i\) is zero for most i), we only have to compute some of the \(E_i\) rather than all of them. This is where MoEs have computational efficiency advantages over using bigger models or using an ensemble of models.
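Here is a minimal sketch of that formula. It assumes, as the Mixtral paper describes, that the gate is a softmax taken over only the top-2 router logits, \(G(x) = \mathrm{Softmax}(\mathrm{TopK}(x \cdot W_g))\). The router weights and the "experts" (simple scalings) are made up purely for illustration.

```python
import math


def top_k_gating(logits, k=2):
    """G(x) = Softmax(TopK(logits)): zero everywhere except the k largest entries."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)                   # subtract max for stability
    exps = {i: math.exp(logits[i] - m) for i in top}
    z = sum(exps.values())
    return [exps[i] / z if i in exps else 0.0 for i in range(len(logits))]


def moe_layer(x, router_weights, experts, k=2):
    """Sum_i G(x)_i * E_i(x), evaluating only experts with a non-zero gate.

    x: list of floats; router_weights: one weight vector per expert;
    experts: callables mapping a vector to a vector of the same size.
    """
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in router_weights]
    gates = top_k_gating(logits, k)
    out = [0.0] * len(x)
    for g, expert in zip(gates, experts):
        if g == 0.0:
            continue                                  # unselected expert: no compute spent
        out = [o + g * yi for o, yi in zip(out, expert(x))]
    return out


# Toy experts that just scale their input; toy router with one row per expert.
experts = [lambda v, s=s: [s * vi for vi in v] for s in (1.0, 2.0, 3.0, 4.0)]
router = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, -1.0]]
print(moe_layer([1.0, 0.5], router, experts))
```

Only two of the four toy experts are ever called for a given input. In a real model this skipping is where the compute savings come from.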
There is an MoE layer in each of the transformer blocks (32 in this case) and hence we do this routing procedure 32 times for each forward pass. In a traditional ensemble model, N (8 in this case) models have their predictions averaged, so there are 8 token paths. We can compare the number of possible paths that each token could take in an MoE to these ensemble methods:
At each layer we choose 2 of the 8 experts to process our token. There are \(\binom{8}{2}\) = 28 ways to do this. And this happens at each of the 32 layers giving \(28^{32}\) possible paths overall, which is huge^{6}! 🤯 The variety of possible paths here points towards increasingly Adaptive Computation in models. In Adaptive Computation, we consider models which handle different tokens with different parameters and different amounts of compute.
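The path count can be checked in a couple of lines. This treats the two chosen experts as an unordered pair, as the \(\binom{8}{2}\) above does.

```python
import math

experts, chosen, layers = 8, 2, 32

paths_per_layer = math.comb(experts, chosen)   # unordered choice of 2 experts from 8
total_paths = paths_per_layer ** layers        # an independent choice at every layer

print(paths_per_layer)                         # the 28 quoted above
print(f"{total_paths:.2e}")
```

For comparison, the 8-model ensemble has just 8 paths.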
Up until now there have been two barriers to truly performant and stably trainable MoEs:
MoE models have an inherently discrete step, the hard routing, and this typically harms the gradient flow. Typically we want fully differentiable functions for backprop and MoEs aren’t even continuous! Considering mathematically plausible approximations to the true gradient can hugely improve MoE training. Recent approaches like Sparse Backpropagation and Soft MoE for encoders provide better gradient flow and hence more performant models.
Compared to their FLOP-class, MoEs are larger models. Their size means that there are real benefits to effective parallelisation and minimising communication costs. Many frameworks such as DeepSpeed MoE now support MoE training in a fairly hardware efficient way.
Having overcome both of these issues, we’re now ready to use MoEs more in practice.
The Mixtral base model outperforms popular (and larger) models like Llama 2 70B, Gemini Pro and GPT-3.5 on most benchmarks. Note that these models are not only larger in total parameter count but are also larger in active parameter count too!
At the time of writing, Mixtral is the best open-source model and the 3rd best Chat model, only beaten by GPT-4 and Claude 2.0 variants.
Mixtral shows impressive use of its whole 32k context window. The model has relatively good recall even for mid-context tokens.
Along with the base model, Mistral also released Instruction Fine-Tuned Chat and Assistant models. For alignment, they opted for Direct Preference Optimisation (DPO) which is proving to be a powerful and less finicky alternative to the traditional RLHF. ^{7}
One hypothesis about MoEs is that some experts might specialise in a particular domain (e.g. mathematics, biology, code, poetry etc.). This hypothesis is an old one which has consistently been shown to be mistaken in the literature. Here, the authors confirm, as in previous MoE papers, there is little difference in the distribution of experts used for different domains (although they report being surprised by this finding!). Often experts seem to specialise syntactically (e.g. an expert for punctuation or whitespace), rather than semantically (an expert for neuroscience). ^{8}
Although the distribution of experts is fairly uniform overall, interestingly two adjacent tokens are much more likely to be processed by the same expert, than we might naively predict. In other words, once an expert sees one token, it’s quite likely to also see the next one - experts like to alley-oop themselves! 🏀
This recent paper details ways to exploit this alley-oop property by caching the recently used expert weights in fast memory.
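I haven't reproduced that paper's exact caching policy. As an assumption, here is a generic least-recently-used (LRU) cache simulation showing why "sticky" routing helps: when consecutive tokens reuse an expert, its weights are already in fast memory. The function name and routing sequences are hypothetical.

```python
from collections import OrderedDict


def simulate_expert_cache(expert_sequence, capacity):
    """Count hits/misses for an LRU cache holding `capacity` experts' weights.

    expert_sequence: the expert id chosen for each successive token.
    A hit means the weights were already in fast memory; a miss means they
    would have to be loaded from slower memory first.
    """
    cache = OrderedDict()
    hits = misses = 0
    for expert in expert_sequence:
        if expert in cache:
            hits += 1
            cache.move_to_end(expert)          # mark as most recently used
        else:
            misses += 1
            cache[expert] = True
            if len(cache) > capacity:
                cache.popitem(last=False)      # evict the least recently used expert
    return hits, misses


sticky = [0, 0, 0, 1, 1, 1]        # adjacent tokens reuse the same expert
alternating = [0, 1, 0, 1, 0, 1]   # no reuse between neighbours
print(simulate_expert_cache(sticky, capacity=1),
      simulate_expert_cache(alternating, capacity=1))
```

Both sequences use each expert equally often. Only the sticky one benefits from even a one-expert cache, which is the alley-oop effect in miniature.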
As I’ve noted previously, I’m excited about the explicit modularity in MoE models for increased interpretability.
There’s little information in the paper about expert balancing techniques. Many different auxiliary losses have been proposed for expert balancing and it would be cool to see which loss function Mistral found to work well at this scale.
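For reference, one well-known example is the auxiliary loss from the Switch Transformer paper, which multiplies the fraction of tokens sent to each expert by the mean router probability for that expert. Below is a minimal PyTorch sketch of that loss for top-1 routing; we don't know whether Mistral used it or something else entirely.

```python
import torch
import torch.nn.functional as F

def switch_load_balancing_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Switch Transformer-style auxiliary loss for top-1 routing (before scaling by alpha).

    router_logits: (n_tokens, n_experts). Returns n_experts * sum_i f_i * P_i, where
    f_i is the fraction of tokens whose top-1 expert is i and P_i is the mean router
    probability for expert i. It is about 1.0 for balanced routing and approaches
    n_experts when every token goes to the same expert.
    """
    n_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)
    top1 = probs.argmax(dim=-1)
    f = F.one_hot(top1, n_experts).float().mean(dim=0)  # dispatch fractions (no gradient)
    P = probs.mean(dim=0)                               # mean probabilities (carry the gradient)
    return n_experts * torch.sum(f * P)

n_experts = 8
balanced = 10.0 * torch.eye(n_experts).repeat(4, 1)  # 32 tokens, each preferring a different expert in turn
collapsed = torch.zeros(32, n_experts)
collapsed[:, 0] = 10.0                               # every token prefers expert 0

balanced_loss = switch_load_balancing_loss(balanced)
collapsed_loss = switch_load_balancing_loss(collapsed)
print(balanced_loss.item(), collapsed_loss.item())  # about 1.0 and about 8.0
```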
The authors are also quite tight-lipped about the pretraining, instruction or feedback datasets used to train the model. Given the impressive performance, it’s quite likely that there’s some secret sauce in the dataset compilation and filtering. It seems increasingly likely that data will be a moat for Foundation Model providers ^{9}.
MoEs win by having increased performance with faster inference. Founder Sharif Shameem writes, “The Mixtral MoE model genuinely feels like an inflection point — a true GPT-3.5 level model that can run at 30 tokens/sec on an M1 MacBook Pro. Imagine all the products now possible when inference is 100% free and your data stays on your device!”
Indeed since the launch of Mixtral, it’s been used in many applications from the enterprise to local chatbots to DIY Home Assistants à la Siri.
As many people use MoE models on-device for the first time, I expect that we will start to see more methods which speed up MoE inference. The Fast MoE Inference paper and MoE-specific quantisation like QMoE are both great steps in this direction.
In particular, quantisation can be thought of as storing a model compressed, like we do for audio in MP3s. We degrade the model’s quality slightly and get massive decreases in the memory that it requires. We can typically quantise MoEs even more aggressively than dense models and retain strong performance.
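Here is the simplest version of the idea: symmetric 8-bit "absmax" quantisation of a single weight matrix in NumPy. It only illustrates the memory/quality trade-off; QMoE and other practical schemes are considerably more sophisticated (finer-grained scales, much lower bit-widths, calibration data).

```python
import numpy as np

def quantise_int8(w: np.ndarray):
    """Symmetric per-tensor absmax quantisation to int8."""
    scale = np.abs(w).max() / 127.0  # the largest-magnitude weight maps to +/-127
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantise(q: np.ndarray, scale) -> np.ndarray:
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(1024, 1024)).astype(np.float32)
q, scale = quantise_int8(w)
w_hat = dequantise(q, scale)

print(w.nbytes // q.nbytes)                         # 4: int8 is a quarter of float32
print(np.abs(w - w_hat).max() <= 0.5005 * scale)   # True: error within about half a step
```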
Mistral was only started a matter of months ago with a super lean team and is already SOTA for Open Source models. This is impressive from their team but it may also suggest that Foundation Models are being commodified real quick.
Originally Mistral were offering Mixtral behind their API for $1.96 per million tokens. Considering GPT-4 is $10-30 per million tokens at the time of writing, this seemed fair for a hosted API. Within days, different inference providers undercut Mistral significantly:
Last week @MistralAI launched pricing for the Mixtral MoE: $2.00~ / 1M tokens.
Hours later @togethercompute took the weights and dropped pricing by 70% to $0.60 / 1M.
Days later @abacusai cut 50% deeper to $0.30 / 1M.
Yesterday @DeepInfra went to $0.27 / 1M.
Who’s next ??? 📉
— JJ — oss/acc (@JosephJacks_) December 15, 2023
There was even one provider who was giving away tokens for free. I know a race to the bottom when I see one…
The consumer/developer is truly winning here, but it reiterates the point that Foundation Model companies should expect the value of tokens to fall dramatically. Competition is for Losers, as Peter Thiel might say; it’s very possible to compete profits all the way down to zero^{10}. It increasingly looks like most of the value captured from an LLM business perspective will be in the application layer (e.g. Perplexity, Copilot) and the infrastructure layer (e.g. AWS/Azure).
Mixtral is a huge win for the scientific and interpretability communities. We now finally have a model which is comfortably better than GPT-3.5 and whose weights are freely available to researchers.
In addition, given Mixtral shares the same backbone as the previous Mistral 7B, it seems plausible that some weights were re-used as initialisations for Mixtral. This approach is known in the literature as Sparse Upcycling. If Sparse Upcycling works, this suggests that the compute required to make great MoE models might be much less than previously thought. Researchers can take advantage of existing models like Llama 2 rather than having to pretrain entirely from scratch, which completely changes which projects are feasible for academics and the GPU-poor.
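In its simplest form, Sparse Upcycling initialises every expert as a copy of the dense checkpoint's MLP and adds a freshly initialised router. Here is a minimal PyTorch sketch with toy layer sizes chosen for illustration; we don't know how (or whether) Mistral actually did this.

```python
import copy
import torch
import torch.nn as nn

def upcycle_mlp(dense_mlp: nn.Module, n_experts: int, d_model: int):
    """Sparse-upcycling-style initialisation: each expert starts as a copy of the
    dense model's MLP, and a new router is initialised from scratch."""
    experts = nn.ModuleList([copy.deepcopy(dense_mlp) for _ in range(n_experts)])
    router = nn.Linear(d_model, n_experts, bias=False)
    return experts, router

d_model, d_hidden = 64, 256
dense_mlp = nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
experts, router = upcycle_mlp(dense_mlp, n_experts=8, d_model=d_model)

x = torch.randn(5, d_model)
# At initialisation every expert computes exactly the dense MLP's function, so the
# upcycled layer starts from the dense model's behaviour and diverges during training.
print(all(torch.allclose(e(x), dense_mlp(x)) for e in experts))  # True
```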
“In 2012 we were detecting cats and dogs and in 2022 we were writing human-like poetry, generating beautiful and novel imagery, solving the protein folding problem and writing code. Why is that?”
Arthur Mensch, Mistral co-founder, suggests most of the reason is “the free flow of information. You had academic labs [and] very big industry labs communicating all the time about their results and building on top of others’ results. That’s the way we [significantly improved] the architecture and training techniques. We made everything work as a community”.
We’re not at the end of the ML story just yet. There’s still science to be done and inventions to be discovered so we still need the free flow of information.
In this house we love Open Source models and papers. 🤗
Expect MoEs to become even more important for 2024. The age of Adaptive Computation is here.
If I may ↩
which is more knowledge in some sense ↩
This is a slight reversal of recent work as many papers had followed the Switch Transformer in only choosing the top 1 expert per layer. We expect that choosing 2 experts allows for more expressivity, more stable training and better gradient flow which is traded off against increased parallel computation in each forward pass. ↩
Early MoEs like the Switch Transformer were using 100s of experts per layer. This always seemed a little excessive. From working on these models, a good heuristic for choosing the expert number hyperparameter is either the number of experts that will fit into your single GPU memory if you’re mostly doing single-batch inference, or the number of GPUs that you could do expert parallelism on at inference time if you’re running a high-bandwidth API. With this in mind, 8 experts seems like a nice middle ground right now for users of an open-source product. ↩
This is slightly less than 8 × 7B = 56B total parameters because attention and embedding parameters are not duplicated. ↩
In fact this is quite the understatement: all of these paths are weighted according to the router logits, so there’s even more nuance than this in the possible paths that tokens can take. ↩
It may also be the case that experts do indeed specialise semantically, but that their natural semantic specialisation is not very clear to human researchers. ↩
at least for companies that don’t produce applications built on top of the models. ↩
Sam Altman has colourfully referred to this as the marginal cost of intelligence going to zero ↩
Machine learning is built on matrix algebra. Einstein summation notation (or einsum for short) makes matrix operations more intuitive and readable.
As you may know, the matrix multiplication that you learned in high school can be written algebraically as:
\[A_{ik} = \sum_j B_{ij} C_{jk}\]
In other words, to get the (1,2) element of A we calculate:
\[A_{1,2} = \sum_j B_{1j} C_{j2}\]
i.e. take the dot product of the 1st row of B with the 2nd column of C.
In einsum notation, to avoid having so many sigmas (\(\sum\)) flying around, we adopt the convention that any index that appears more than once is summed over. Hence:
\[A_{ik} = \sum_j B_{ij} C_{jk}\]
can be written more simply as…
\[A_{ik} = B_{ij} C_{jk}\]
Both torch and numpy provide einsum functions (torch.einsum and numpy.einsum) that let you use einsum notation for matrix operations. For example, we can write the above matrix multiplication in torch as:
import torch as t
B, C = t.randn(2, 3), t.randn(3, 4)  # example matrices: B is 2x3, C is 3x4
A = t.einsum("ij,jk->ik", B, C)      # A is 2x4
The convention is that if a dimension only appears on the left side of the einsum then it’s summed over. So in the above we’re summing over the j dimension and keeping the i and k dimensions. That’s our classic matrix multiplication written in torch einsum notation^{1}.
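A quick sanity check of that convention, plus a few other operations written in the same notation (the tensor shapes are arbitrary):

```python
import torch as t

B = t.randn(2, 3)
C = t.randn(3, 4)

# Matrix multiplication: j appears only on the left, so it is summed over.
A = t.einsum("ij,jk->ik", B, C)
print(t.allclose(A, B @ C))  # True

M = t.randn(3, 3)
trace = t.einsum("ii->", M)                  # repeated i, nothing kept: sum of the diagonal
transpose = t.einsum("ij->ji", B)            # nothing summed, axes just reordered
outer = t.einsum("i,j->ij", B[0], C[0])      # no shared index, so no sum: outer product
batched = t.einsum("bij,bjk->bik", t.randn(5, 2, 3), t.randn(5, 3, 4))  # b kept, j summed

print(t.allclose(trace, t.trace(M)), transpose.shape, outer.shape, batched.shape)
```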
Great!
One issue when using torch.einsum though is that it’s not necessarily super clear what each letter means:
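For example, in this hypothetical attention-style snippet the reader has to guess what `b`, `h`, `q`, `k` and `d` stand for:

```python
import torch as t

# Without a glossary, what are b, h, q, k and d?
Q = t.randn(2, 4, 10, 16)
K = t.randn(2, 10, 16)
scores = t.einsum("bhqd,bkd->bhqk", Q, K)
print(scores.shape)  # torch.Size([2, 4, 10, 10])
```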
To get around this ambiguity, it’s common to see PyTorch code where each of the letters is defined in the docstring. This isn’t a very natural pattern - it’s as if all of your variable names in code had to be single letters and you had another file which acted as a dictionary for what each letter actually meant! shudders.
One of the most useful lines of the Zen of Python is “Explicit is better than implicit.” Following this principle, we would like to be able to write the variable names in the einsum string itself. Without this, the code is harder to read and you’re always looking back at the glossary when trying to understand or debug it.
Einops is a tensor manipulation package that can be used with PyTorch, NumPy, TensorFlow and JAX. It offers a nice API, but we’ll focus on einsums, which we can now use with full variable names rather than single letters! It makes your ML code so much clearer instantly.
For example let’s write the multi-query attention operation.
import torch as t
from einops import einsum

def multi_query_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    # Q: (batch, head, seq, head_dim); K and V: (batch, seq, head_dim), shared across heads
    _, _, head_dim = K.shape
    attn_scores = einsum(Q, K,
        "batch head seq1 head_dim, batch seq2 head_dim -> batch head seq1 seq2")
    # Normalise over the key positions (seq2)
    attn_matrix = t.softmax(attn_scores / head_dim ** 0.5, dim=-1)
    out = einsum(attn_matrix, V,
        "batch head seq1 seq2, batch seq2 head_dim -> batch head seq1 head_dim")
    return out
One catch here is that we want to have the sequence length represented twice for \(QK^T\) but we don’t want to sum over it. To solve this we give the two copies different names: seq1 and seq2.
The nice thing about this is that we didn’t need to write a glossary for what the arbitrary letters b or h were supposed to mean; we can just read the dimensions off the string.
Also note that typically when computing attention, we need to calculate \(QK^T\). Here we didn’t need to worry about how exactly to take the transpose - we just give the dimension names and the correct transposes are done for the multiplication to make sense!
Einops also offers great functions for rearranging, reducing and repeating tensors which are also very useful.
That’s all! Just trying to make those inscrutable matrix multiplications a little more scrutable.
I feel like a fancy chef here. For our appetiser we have Matrix Multiplication Done Four Ways and so on… ↩
Given a task that we don’t know how to solve directly in code (e.g. recognising a cat or writing a unique sonnet), we often write programs which, in turn (via SGD), write second-order programs. These second-order programs (i.e. neural network weights) can solve the task, given lots of data.
Suppose we have some neural network weights which describe how to do a task. We might want to know how the network solved the task. This is useful either to (1) understand the algorithm better for ourselves or (2) check if the algorithm follows some guidelines we might like e.g. not being deceptive, not invoking harmful bias etc.
The field of Mechanistic Interpretability aims to do just that: given a neural network^{1}, return a correct, parsimonious, faithful and human-understandable explanation of the inner workings of the network when solving a given task. This is analogous to the problem of reverse engineering software from machine code or the problem of a neuroscientist trying to understand the human brain.
How are we to translate giant inscrutable matrices into neat explanations and high-level stories?
In order for a neural network to make some prediction, it uses internal neuron activations as “variables”. The neuron activations build up high-level, semantically rich concepts in later layers using lower-level concepts in earlier layers.
A dream of Mechanistic Interpretability would be this:
Suppose we had some idea that each neuron corresponded to a single feature. For example, we could point to one neuron and say “if that neuron activates (or “fires”) then the network is thinking about cats!”. Then we point to another and say “the network is thinking about the colour blue”. Now we could give a neural network some inputs, look at which internal neurons activate (or “fire”) and use this to piece together a story about how the network came up with its eventual prediction. This story would involve knowing the concepts (“features”) the network was “thinking” about together with the weights (“circuits”) which connect them.
This would be great! Unfortunately, there are a couple of problems here…
Firstly, neural networks are freaking huge. There can be literally billions of weights and activations relevant for processing a single sentence in a language model. So with the naive approach above, it would be an incredibly difficult practical undertaking to actually tell a good story about the network’s internal workings^{2}.
But, secondly, and more importantly, when we look at the neurons of a neural network we don’t see the concepts that it sees. We see a huge mess of concepts all enmeshed together because it’s more efficient for the network to process information in this way. Neurons that don’t activate on a single concept but instead activate on many distinct concepts are known as polysemantic neurons. It turns out that basically all neurons are highly polysemantic^{3}.
In essence, neural networks have lots of features, which are the fundamental units (“variables”) in neural networks. We might think of features as directions in neuron space corresponding to the concepts. And neurons are linear combinations of these features in a way that makes sense to the network but looks very entangled to us - we can’t just read off the features from looking at the activations.
So we’re given a network and we know that all the neurons are linear combinations of the underlying features, but we don’t know what the features are. That is, we hypothesise that there is some linear map g from feature space to neuron space. Generally, feature space is much bigger than neuron space. That is to say, there are more useful concepts in language than the number of neurons that a network has. So our map g is a very wide rectangular matrix: it takes in a large feature vector and outputs a smaller vector with one dimension per neuron.
We want to recover the features. To do this we could try to find a linear function which can map from neuron space → feature space and acts as the inverse of g. We go to our Linear Algebra textbook (or ask ChatGPT) how to invert a long rectangular matrix and it says… oh wait, yeah this actually isn’t possible^{4}. A general linear map from feature space → neuron space loses information and so cannot be inverted - we can’t recover the features given only the neurons.
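We can see the information loss numerically. For a random wide matrix standing in for g (the sizes here are arbitrary), any direction in its null space can be added to the features without changing the neurons at all:

```python
import numpy as np

rng = np.random.default_rng(0)
n_neurons, n_features = 4, 8
g = rng.normal(size=(n_neurons, n_features))  # stand-in map: features -> neurons

# A wide matrix has a null space of dimension at least n_features - n_neurons.
_, _, Vt = np.linalg.svd(g)
null_vector = Vt[-1]  # a direction that g sends (numerically) to zero

features_a = rng.normal(size=n_features)
features_b = features_a + 3.0 * null_vector  # genuinely different features...

print(np.allclose(g @ features_a, g @ features_b))  # True: ...but identical neurons
print(np.allclose(features_a, features_b))          # False
```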
This seems bad but let’s soldier on. Instead of giving up, we ask, “okay well we can’t invert a general linear map g, but what constraints could we put on g such that it might be invertible?” As it turns out, if most of the numbers in the matrix corresponding to g are 0 (that is, if g is sufficiently sparse) then we can invert g.^{5}
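The Lasso analogy from the footnote can be made concrete with a standard compressed-sensing experiment. Note that in this textbook formulation the sparsity assumption sits on the feature vector being recovered (only a few features active at once) rather than on the entries of g. The sketch uses ISTA (iterative soft-thresholding) to solve the L1-regularised problem; the dimensions, sparsity level and penalty `lam` are illustrative choices.

```python
import numpy as np

def ista(A: np.ndarray, y: np.ndarray, lam: float, n_iters: int = 5000) -> np.ndarray:
    """Minimise 0.5 * ||A x - y||^2 + lam * ||x||_1 by iterative soft-thresholding."""
    step = 1.0 / np.linalg.norm(A, ord=2) ** 2  # 1 / Lipschitz constant of the gradient
    x = np.zeros(A.shape[1])
    for _ in range(n_iters):
        z = x - step * A.T @ (A @ x - y)                          # gradient step on the squared error
        x = np.sign(z) * np.maximum(np.abs(z) - step * lam, 0.0)  # soft-threshold (L1 proximal step)
    return x

rng = np.random.default_rng(0)
n_neurons, n_features, n_active = 50, 100, 4
A = rng.normal(size=(n_neurons, n_features)) / np.sqrt(n_neurons)  # wide stand-in for g

x_true = np.zeros(n_features)
support = rng.choice(n_features, size=n_active, replace=False)
x_true[support] = rng.choice([-1.0, 1.0], size=n_active) * (1.0 + rng.random(n_active))
y = A @ x_true  # observed "neuron activations"

x_sparse = ista(A, y, lam=0.01)
x_min_norm = np.linalg.pinv(A) @ y  # least-squares answer that ignores sparsity

def rel_err(x):
    return np.linalg.norm(x - x_true) / np.linalg.norm(x_true)

print(f"sparse recovery error: {rel_err(x_sparse):.3f}")    # small
print(f"min-norm error:        {rel_err(x_min_norm):.3f}")  # much larger
```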
Q: Hold on, is this reasonable? Why might we expect g to be (approximately) sparse?
In predicting the next token there will be some relevant features of the previous tokens which are useful. If the neural network has tens of thousands of features per layer (or perhaps even more), then we would expect some of them to be useful for each prediction. But if the prediction function uses all of the features it would be super complex; most features should be irrelevant for each prediction.
As an example, consider if you’re deciding whether a picture of an animal is a dog - you might ask “does it have 4 legs?”, so four-leggedness is a useful feature. The texture of its fur is also relevant. The question “would a rider sit within or on top?” is probably not relevant, though it might be relevant in other situations, for example distinguishing a motorbike from a car. In this way, not all of the features are needed at once^{6}.
To recap, so far we’ve said:
- Language models use features in order to predict the next token.
- There are potentially a lot more features than there are neurons.
- If the linear map g: features → neurons was sparse then we might be able to find an inverse.
- Sparse maps are relatively good approximations to the real linear map g.
Sparse Dictionary Learning is a method which exploits these facts to numerically find the inverse of g. Intuitively what we have is a lookup table (or a “dictionary”) which tells us how much of each feature goes into each neuron. And if these features look monosemantic and human-understandable then we’re getting very close to the dream of Mechanistic Interpretability outlined above. We could run a model, read off the features it used for the prediction and build a story of how it works!
We’ll focus here on Anthropic’s set-up.
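We start with a small 1-Layer transformer which has an embedding dimension of 128. Here the MLP hidden dimension is 512.^{7} The MLP contains an up-projection from 128 to 512 dimensions, a ReLU non-linearity and a down-projection back to 128 dimensions.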
We capture the MLP neuron activations and send those through our sparse autoencoder which has N dimensions for some N ≥ 512.
An AutoEncoder is a model which tries to reconstruct some data after putting it through a bottleneck. In traditional autoencoders, the bottleneck might be mapping to a smaller dimensional space or including noise that the representation should be robust to. AutoEncoders aim to recreate the original data as closely as possible despite the bottleneck. To achieve the reconstruction, we use a reconstruction loss which penalises outputs by how much they differ from the MLP activations (the inputs to the AutoEncoder).
In the Sparse AutoEncoder setting, our “bottleneck” is actually a higher dimensional space than neuron space (N ≥ 512), but the constraint is that the autoencoder features are sparse. That is, for any given set of MLP neuron activations, only a small fraction of the features should be activated.
In order to make the hidden feature activations sparse, we add an L1 loss over the feature activations to the reconstruction loss in the AutoEncoder’s loss function. Since the L1 norm is the sum of the absolute values of a vector’s entries, minimising the L1 loss pushes as many of the feature activations as possible towards zero (whilst still reconstructing the MLP neurons well enough to get a low reconstruction loss).
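To make this concrete, here is a minimal PyTorch sketch of a sparse autoencoder trained with a reconstruction loss plus an L1 penalty on the feature activations. The sizes, `l1_coeff` and the random stand-in activations are assumptions for illustration, and the sketch omits refinements used in practice (for example, constraining the norm of the decoder's dictionary directions).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoEncoder(nn.Module):
    """Minimal sparse autoencoder over MLP activations.

    n_neurons: width of the MLP layer being studied (512 in the set-up above).
    n_features: dictionary size N, with N >= n_neurons.
    """

    def __init__(self, n_neurons: int, n_features: int):
        super().__init__()
        self.encoder = nn.Linear(n_neurons, n_features)  # weights + bias, followed by ReLU in forward
        self.decoder = nn.Linear(n_features, n_neurons)  # each weight column is one dictionary direction

    def forward(self, x: torch.Tensor):
        features = F.relu(self.encoder(x))  # non-negative feature activations
        reconstruction = self.decoder(features)
        return reconstruction, features

def sae_loss(x, reconstruction, features, l1_coeff: float = 1e-3):
    reconstruction_loss = F.mse_loss(reconstruction, x)
    sparsity_loss = features.abs().sum(dim=-1).mean()  # L1 norm per example, averaged over the batch
    return reconstruction_loss + l1_coeff * sparsity_loss

torch.manual_seed(0)
activations = torch.randn(128, 512)  # random stand-in for captured MLP activations
sae = SparseAutoEncoder(n_neurons=512, n_features=1024)
optimiser = torch.optim.Adam(sae.parameters(), lr=1e-3)

losses = []
for step in range(100):
    reconstruction, feats = sae(activations)
    loss = sae_loss(activations, reconstruction, feats)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```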
To recap:
- The input of the AutoEncoder is the MLP activations.
- The goal is for the output of the AutoEncoder to be as close to the input as possible - the reconstruction loss penalises outputs by how much they differ from the MLP activation inputs.
- The bottleneck is the sparsity in the hidden layer which is induced by pressure from the L1 loss to minimise feature activations.
In summary, the set-up Anthropic uses is:
The most surprising thing about this approach is that it works so well. Like really well.
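There are, broadly, two ways to think about features: Features as Results (which inputs make a feature activate) and Features as Actions (what the feature does to the model’s downstream predictions).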
Anthropic find many features which activate strongly in a specific context (say Arabic script or DNA base pairs) and also (mostly) only activate when that context is present. In other words, the features have high precision and recall. This suggests that these are ~monosemantic features! In terms of Features as Results, this captures what we would hope for - the features that appear are mostly human-understandable.
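Treating "the feature fires" as a classifier for "the context is present" makes precision and recall straightforward to compute. A small sketch, with made-up activations and context labels:

```python
import numpy as np

def feature_precision_recall(activations: np.ndarray, context: np.ndarray, threshold: float = 0.0):
    """Score a feature as a detector for a context.

    activations: (n_tokens,) feature activations.
    context: (n_tokens,) booleans, True where the context (e.g. Arabic script) is present.
    """
    fires = activations > threshold
    true_positives = np.sum(fires & context)
    precision = true_positives / max(np.sum(fires), 1)   # when it fires, is the context there?
    recall = true_positives / max(np.sum(context), 1)    # when the context is there, does it fire?
    return float(precision), float(recall)

# Toy example with hypothetical labels and activations.
context = np.array([True, True, True, False, False, False, True, False])
acts = np.array([2.1, 1.7, 0.0, 0.0, 0.3, 0.0, 3.0, 0.0])
print(feature_precision_recall(acts, context))  # (0.75, 0.75)
```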
The authors also find that once a feature is activated, the result is an increase in the probability of next tokens that are plausible given that context. In particular, to demonstrate this counterfactually, we can add a large amount of a given feature to the neuron activations. Theoretically, this should “steer” the model into thinking that the context was present in the input, even if it wasn’t. This is a great test for Features as Actions.
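Mechanically, "adding a large amount of a feature" can be sketched as adding a scaled copy of that feature's dictionary direction (a column of the autoencoder's decoder weights) to the MLP activations. The decoder below is random and stands in for a trained one, and wiring this into a real model (e.g. via a forward hook on the MLP) is left out.

```python
import torch

def steer(mlp_acts: torch.Tensor, decoder_weight: torch.Tensor, feature_idx: int, alpha: float) -> torch.Tensor:
    """Add alpha units of one dictionary feature's direction to MLP activations.

    decoder_weight: (n_neurons, n_features), e.g. the weight of an autoencoder's
    decoder nn.Linear, whose columns are feature directions in neuron space.
    """
    direction = decoder_weight[:, feature_idx]
    direction = direction / direction.norm()  # unit-length feature direction
    return mlp_acts + alpha * direction       # broadcasts over batch and sequence

torch.manual_seed(0)
n_neurons, n_features = 512, 2048
decoder_weight = torch.randn(n_neurons, n_features)  # random stand-in for a trained decoder
acts = torch.randn(2, 10, n_neurons)                 # (batch, seq, neurons)
steered = steer(acts, decoder_weight, feature_idx=42, alpha=10.0)

d = decoder_weight[:, 42] / decoder_weight[:, 42].norm()
projection = (steered - acts) @ d  # change along the feature direction at each position
print(torch.allclose(projection, torch.full((2, 10), 10.0), atol=1e-4))  # True
```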
Additionally, if we fully replace the MLP activations with the output of our autoencoder^{8}, we get a model which explicitly uses our feature dictionary instead of the learned MLP neurons. Here the resulting “dictionary model” is able to get 95% of the performance of the regular model. The dictionary model achieves this despite the features being extremely sparse (in the case of large autoencoders). This performance is a great sign for Features as Actions; it suggests that the sparse features capture most of the information that the model is using for its prediction task! It also suggests that our assumption that features are approximately sparse is a fairly good one^{9}.
They also note some other smaller results:
Are we done, then? Certainly not. Although this approach is a breakthrough in approaching features and converting regular networks into less polysemantic ones, some problems remain:
Large models are still, well … large. Dictionary learning mitigates the problem since we don’t have to deal with polysemantic neurons anymore. But there’s still a lot that could happen between doing this on a small 1-Layer model and a large model. In particular, since there are many more features than neurons, Sparse AutoEncoders for large models could be absolutely gigantic and may take as much compute to train as the model’s pre-training. We will very likely need ways to improve the efficiency of Sparse AutoEncoder training.
In Machine Learning, as in Physics, More Is Different. That is, there may be qualitatively different behaviours for large models as compared to smaller ones. One clear way this could occur is when features are composed of many sub-features across different layers and form complex interactions. This is an open problem to be explored.
The Universality Hypothesis from Chris Olah states that sufficiently capable neural networks with different architectures, trained on different data, will learn the same high-level features and concepts.
The authors show that when two models are trained with the same architecture but different random initialisations, they learn similar features. This is certainly a step towards universality but doesn’t show the whole thesis. A strong form of Universality would suggest that there are some high-level “natural” features/concepts which lots of different architectures for predictors (silicon and human brains) all converge on. We’re quite a way from showing this in the general case.
Though there are some proxy measures for interpretability, currently the best metric that we have is for a human to check and say “yes I can interpret this feature” or “no I can’t”. This seems hard to operationalise at scale as a concrete metric.
To bridge this gap, large models such as GPT-4 and Claude can also help with interpretability. In a process known as AutoInterpret, LLMs are given prompts together with how strongly a feature activates on them. They then attempt to interpret the feature. This works kinda okay at the moment but it feels like there should be a cleaner, more principled approach.
The authors show that by adding more of a given feature vector in activation space, you can influence a model’s behaviour. When, whether, and how steering works reliably and efficiently are all useful open questions. We might wish to use steering as a surgical needle to complement the coarser tool that is RLHF. In the future, this may also be useful to reduce harmful behaviour in increasingly powerful models.
As mentioned above, there would be an embarrassingly large number of features for a model like GPT-4 and so it looks like it will be difficult to create succinct compelling stories which involve so many moving parts. In some sense, this is the lowest level of interpretability. It’s analogous to trying to understand a very complex computer program by looking through it character by character, if the words were all jumbled up.
What we would like is some slightly higher-level concepts, composed of multiple features, with which we can think. Splitting up the network into macro-modules rather than micro-level features seems like a promising path forward.
Anthropic are very positive about this approach and finish their blog post with the line:
For the first time, we feel that the next primary obstacle to interpreting large language models is engineering rather than science.
This development is genuinely exciting. We might ask whether the work ahead is purely scaling up. As we outlined in the problems for future work above, I do believe there are still some Science of Deep Learning problems which Mechanistic Interpretability can sink its teeth into. Only now, we also have a new tool which is incredibly powerful to help us along the way.
In light of the other problems that still remain to be solved, we might add the final sentences of Turing’s 1950 paper, as an addendum:
We can only see a short distance ahead, but we can see plenty there that needs to be done.
Thanks to Derik and Joe for comments on a draft of this post.
With both its weights and its activations on a series of input examples, say ↩
Of course, if it’s just a practical undertaking perhaps we would grit our teeth and try to do this - it appears we at least have the tools to give it a shot, even if it’s painfully slow. We have completed huge practical undertakings before as a scientific community e.g. deciphering the human genome or getting man to the moon. As we will see there is another concern as well. ↩
One theory of exactly how that might come about is found in the Superposition Hypothesis. ↩
Thanks to Robert Huben for this useful framing ↩
The proof of this and the convergence properties are analogous to how you can use fewer data points for linear regression if you know that the linear map you’re trying to find is sparse e.g. with Lasso methods for sparse linear regression. For this to work precisely, we add a bias and a ReLU non-linearity. ↩
This is similar to the intuition of the MoEification paper - MLPs naturally learn some sparse/modular structure, which we might hope to exploit. ↩
With the convention from GPT-2 that MLP_dim = 4 * embedding_dim ↩
which, we recall, is trying to reconstruct the MLP activations through the sparse bottleneck ↩
To the extent that we don’t get 100% of the performance, there are a few hypotheses. Firstly, we might not have the optimal autoencoder architecture yet, or the autoencoder might not be trained to saturation. Secondly, altering the L1 loss coefficient hyperparameter adjusts how sparse we want to make our features and there may be some tuning to do there. Thirdly, the network might just not be fully sparse. This seems likely: there are some early results showing that as the size of the model increases (from the toy model we have to a large frontier model), we might expect more sparsity, which suggests that Dictionary Learning may get better with scale. The later Cookbook Features paper also suggests this. ↩
Foundation models aim to solve a wide range of tasks. In the days of yore, we would build a supervised model for every individual use case; foundation models promise a single unified solution.
There are challenges with this however. When two tasks need different skills, trying to learn both can make you learn neither as well as if you had focused on one^{1}. Storing information for many tasks can also be a challenge, even for large models.
Moreover, we might wonder whether it makes sense to use the same parameters for computing the answer to a logic puzzle and for finding the perfect adjective to describe the love interest in a romance fanfic.
We would like our models to have modular functions. We could then select and even combine abilities when needed.
Scaling up models offers various advantages. There are three main quantities to scale: the number of model parameters, the amount of data and the amount of compute applied at train time. With regular transformers, to scale up the number of parameters, we must likewise scale the amount of compute applied. Intuitively, more parameters mean more knowledge, and more compute represents additional intelligence^{2}.
There are some use cases where having more knowledge can be traded off with being more cognitively able. For example, you may choose to memorise rather than re-derive the laws of physics to use them in a specific problem. Similarly we can trade off the opposite way as well - if you know you’ll have access to a textbook or Wikipedia then you might not want to memorise certain historical facts. All you need to know is when and how to look up the facts you need.
So, depending on whether we need more knowledge or more cognitive ability, we also want to scale parameters and compute separately^{3}.
In a vanilla transformer, each Transformer Block contains an attention layer for communication between tokens and an MLP layer for computation within tokens. The MLP layer contains most of the parameters of a large transformer and transforms the individual tokens.
In Sparse Mixture of Experts (MoEs), we swap out the MLP layers of the vanilla transformer for an Expert Layer. The Expert Layer is made up of multiple MLPs called “Experts”. For each input we select one expert to send that input to. In this way, each token has different parameters applied to it. A dynamic routing mechanism decides how to map tokens to Experts^{4}.
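Here is a minimal PyTorch sketch of such an Expert Layer with a linear router and top-1 routing. The layer sizes are arbitrary, and it deliberately omits practical details such as expert capacity limits, token dropping and load-balancing losses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Top1MoE(nn.Module):
    """Minimal sparse MoE layer: a linear router sends each token to one expert MLP."""

    def __init__(self, d_model: int, d_hidden: int, n_experts: int):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
            for _ in range(n_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_tokens, d_model); flatten batch and sequence dimensions beforehand.
        probs = F.softmax(self.router(x), dim=-1)  # (n_tokens, n_experts)
        gate, expert_idx = probs.max(dim=-1)       # top-1 expert per token and its probability
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            mask = expert_idx == e
            if mask.any():
                # Only the routed tokens are processed by this expert. Scaling by the
                # gate lets gradients reach the router despite the discrete choice.
                out[mask] = gate[mask].unsqueeze(-1) * expert(x[mask])
        return out

torch.manual_seed(0)
moe = Top1MoE(d_model=64, d_hidden=256, n_experts=8)
tokens = torch.randn(10, 64)
print(moe(tokens).shape)  # torch.Size([10, 64])
```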
Sparse MoEs solve the problems we noted earlier:
Imagine you’re feeling fatigued and you have no idea what’s causing this. Suppose the problem is with your eyes but you don’t know this yet. Since your friend is a cardiologist (a doctor specialising in the heart), you ask them for advice, which they freely give. You might ask yourself if you should follow their advice blindly or if you should:
Approach 1: Get a second opinion from another cardiologist.
Averaging over multiple doctors who were trained in the same way increases robustness by reducing variance (maybe the first doctor was tired that day or something). But it doesn’t help with bias^{7} - all the cardiologists are likely to be wrong in the same way, if they are wrong at all.
Approach 2: Go to a generalist doctor that has no specialism.
It’s not clear whether this is better than asking another cardiologist. Sure they might have different knowledge to the cardiologist which might be useful if your problem isn’t about the heart. But there’s an awful lot of medical knowledge out there and we can’t reasonably expect this one generalist to know everything about all of them. They probably have cursory knowledge at best. We need someone who specialises in the area that we’re struggling with. Problem is we don’t know which area of specialism we need!
Approach 3: Ask multiple doctors who all specialise in different areas and do the thing most of them suggest.
This is much better. If you have a problem with your eyes, you know that the eye doctor is being consulted so you have a much better chance of getting the right treatment. But there are downsides here. Most notably, asking multiple doctors is probably pretty inefficient. Now we have to see 50 specialists for every problem even though most of them have no idea about our problem. What we would prefer is to know which one specialist (or possibly couple of specialists) we should see and only get advice from them.
Approach 4: Go to your GP, tell them about your ailment and ask them which doctor you should go and see.
Here we get the benefits of getting advice from the most relevant specialised doctor without having to ask every other doctor. This is both more accurate and time-efficient.
In approach 4, the GP is the routing function. They know the strengths of the different doctors and send you to one of them depending on your problem.
The Doctors are the Experts. We allow them to specialise knowing that the GP can route us to the correct doctor for our problem.
The GP-doctor system is exactly a Mixture of Experts layer.
Viewed this way, we see that Mixture of Experts models will be effective whenever we want a model to have access to large amounts of information - more than a single Expert could hope to learn alone. Another use case is when our task can be decomposed into a number of sub-tasks, with each input belonging to one of them.
In general, we might imagine MoEs which, when faced with more difficult problems, can send the input to a more powerful expert with access to more resources. This starts to move us increasingly towards Adaptive Computation.
This phenomenon is known as negative interference in learning: Jack of All Trades, Master of None. For other tasks, however, we can see positive interference, also known as Transfer Learning. ↩
For some vague definitions of “intelligence” and “knowledge”. This intuition is courtesy of Noam Shazeer. ↩
In reality both knowledge and cognitive ability are hard to separate this cleanly but hopefully the intuition still remains useful. ↩
The experts “compete” to process the tokens and as in Natural Selection and Economics, competition for niches makes them specialise. ↩
In actuality, Experts might not specialise strictly by task. It might be beneficial for an expert to specialise in syntactic rather than semantic features, or to combine two tasks which are different enough not to interfere with each other. ↩
This approach also has good biological precedent. Humans don’t use every part of their brain for every stimulus they receive - when they receive a visual stimulus, for example, they use only their visual cortex to process it. ↩
In the statistical sense ↩
In traditional Sparse MoEs, we swap out the MLP layers of the vanilla transformer for an Expert Layer. The Expert Layer is made up of multiple MLPs referred to as Experts. For each input, a single Expert is selected to process it, and a dynamic routing mechanism decides how to map tokens to Experts. Importantly, though this is less often mentioned, MoEs are more modular and hence more naturally interpretable than vanilla transformers.
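To make the routing concrete, here is a minimal pure-Python sketch of an Expert Layer with top-1 routing, in the style of the Switch Transformer. Everything here is an illustrative assumption rather than the post's implementation: the helper names (`make_expert`, `sparse_moe_layer`), the toy sizes, and the choice to scale each Expert's output by its gate probability.

```python
import math
import random

def matvec(W, x):
    # W is a list of rows; returns W @ x
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(z):
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def make_expert(d_model, d_hidden, rng):
    # Each Expert is a small two-layer ReLU MLP with its own random weights
    W1 = [[rng.gauss(0, 0.5) for _ in range(d_model)] for _ in range(d_hidden)]
    W2 = [[rng.gauss(0, 0.5) for _ in range(d_hidden)] for _ in range(d_model)]
    def expert(x):
        hidden = [max(0.0, h) for h in matvec(W1, x)]
        return matvec(W2, hidden)
    return expert

def sparse_moe_layer(tokens, router_W, experts):
    # Top-1 routing: each token is processed by exactly one Expert
    outputs, choices = [], []
    for x in tokens:
        gates = softmax(matvec(router_W, x))              # one probability per Expert
        k = max(range(len(experts)), key=lambda i: gates[i])
        y = experts[k](x)                                 # only Expert k runs for this token
        outputs.append([gates[k] * v for v in y])         # scale by the gate probability
        choices.append(k)
    return outputs, choices

rng = random.Random(0)
d_model, n_experts = 4, 3
experts = [make_expert(d_model, 8, rng) for _ in range(n_experts)]
router_W = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(n_experts)]
tokens = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(5)]
outputs, choices = sparse_moe_layer(tokens, router_W, experts)
```

The `argmax` in `sparse_moe_layer` is the discrete routing step: cheap, since only one Expert runs per token, but not differentiable with respect to the choice of Expert.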
The Soft MoE paradigm was introduced by Google researchers in the paper From Sparse To Soft Mixtures of Experts. Unlike Sparse MoEs, Soft MoEs don’t send a subset of the input tokens to experts. Instead, each expert receives a linear combination of all the input tokens. The weights for these combinations are determined by the same dynamic routing mechanism as in Sparse MoEs.
The discrete routing that makes Sparse MoEs so effective also means they are not fully differentiable, which can cause training issues. Soft MoEs solve these issues, are better suited to GPU hardware and, in general, outperform Sparse MoEs.
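The dispatch-and-combine idea can be sketched in a few lines of pure Python. This is a simplified illustration, not the paper's code: `Phi` stands in for learnable per-slot parameters (fixed, hypothetical values here), the Experts are trivial hypothetical scaling functions, and training and other details from the paper are omitted. Each slot receives a softmax-weighted average over all tokens, and each token's output is a softmax-weighted average over all slot outputs, so every step is smooth.

```python
import math

def softmax(z):
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def soft_moe_layer(X, Phi, experts, slots_per_expert):
    # X: n tokens of width d; Phi: d x n_slots slot parameters
    n, d, n_slots = len(X), len(X[0]), len(Phi[0])
    logits = [[sum(X[i][k] * Phi[k][j] for k in range(d)) for j in range(n_slots)]
              for i in range(n)]
    # Dispatch: for each slot, a softmax over ALL tokens (no token is dropped)
    dispatch = [softmax([logits[i][j] for i in range(n)]) for j in range(n_slots)]
    slot_inputs = [[sum(dispatch[j][i] * X[i][k] for i in range(n)) for k in range(d)]
                   for j in range(n_slots)]
    # Each Expert processes only its own slots, however many tokens there are
    slot_outputs = [experts[j // slots_per_expert](slot_inputs[j]) for j in range(n_slots)]
    # Combine: for each token, a softmax over ALL slots
    combine = [softmax(logits[i]) for i in range(n)]
    Y = [[sum(combine[i][j] * slot_outputs[j][k] for j in range(n_slots)) for k in range(d)]
         for i in range(n)]
    return Y, dispatch, combine

# Hypothetical toy sizes: 6 tokens of width 3, 2 Experts with 2 slots each
X = [[math.sin(i + k) for k in range(3)] for i in range(6)]
Phi = [[math.cos(4 * k + j) for j in range(4)] for k in range(3)]
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0)]  # stand-in Experts
Y, dispatch, combine = soft_moe_layer(X, Phi, experts, slots_per_expert=2)
```

Because there is no `argmax`, gradients flow through both the dispatch and the combine weights, and no token is ever dropped.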
The paper abstract reads:
Sparse mixture of expert architectures (MoEs) scale model capacity without large increases in training or inference costs. Despite their success, MoEs suffer from a number of issues: training instability, token dropping, inability to scale the number of experts, or ineffective finetuning. In this work, we propose Soft MoE, a fully-differentiable sparse Transformer that addresses these challenges, while maintaining the benefits of MoEs. Soft MoE performs an implicit soft assignment by passing different weighted combinations of all input tokens to each expert. As in other MoE works, experts in Soft MoE only process a subset of the (combined) tokens, enabling larger model capacity at lower inference cost. In the context of visual recognition, Soft MoE greatly outperforms standard Transformers (ViTs) and popular MoE variants (Tokens Choice and Experts Choice). For example, Soft MoE-Base/16 requires 10.5× lower inference cost (5.7× lower wall-clock time) than ViT-Huge/14 while matching its performance after similar training. Soft MoE also scales well: Soft MoE Huge/14 with 128 experts in 16 MoE layers has over 40× more parameters than ViT Huge/14, while inference time cost grows by only 2%, and it performs substantially better.
I recently gave a talk at EleutherAI, the open-source AI research lab, about Soft MoEs.
You can watch the talk back on YouTube here ^{2} or view the slides here.
I’m very excited about research on extending the Soft MoE paradigm to autoregressive (GPT-style) models, which is currently an open problem described in the above talk. Feel free to reach out if you’re interested in, or currently researching, this area.
For more details on MoE models see the Awesome Adaptive Computation repo. ↩
Unfortunately the video’s audio quality isn’t as great as it could be; I may look at cleaning this up. ↩
In the literature and the public conversation around Natural Language Processing, lots has been made of the results of scaling up data, compute and model size. For example we have the original and updated transformer scaling laws.
One sometimes overlooked point is the vital role of engineering breakthroughs in enabling large models to be trained and served on current hardware.
This post is about the engineering tricks that bring the research to life.
Note: This post assumes some basic familiarity with PyTorch/Tensorflow and transformers. If you’ve never used these before, check out the PyTorch docs and the Illustrated Transformer. Some background on how backpropagation works will also be useful - check out this video if you want a refresher!
DeepSpeed has four main use cases: enabling large training runs, decreasing inference latency, model compression and enabling ML science.
This post covers training optimizations.
Training large models (e.g. LLMs) on huge datasets can be prohibitively slow, expensive, or even impossible with available hardware.
In particular, very large models generally do not fit into the memory of a single GPU/TPU node. Compared to CPUs, GPUs generally have higher throughput but lower memory capacity. (A typical GPU may have 32GB of memory, versus 1TB+ for CPUs.)
Our aims are:
DeepSpeed reduces compute and time to train by >100x for large models.
If you just want to see how to implement DeepSpeed in your code, see the Using DeepSpeed section below.
Without any data parallelism, we get this sorry sight:
We’ve spent a lot of money on GPU cores for them all to sit there idle apart from one! Unless you’re single-handedly trying to prop up the NVIDIA share price, this is a terrible idea!
One thing that we might try is splitting up the data, parallelising across devices. Here we copy the entire model onto each worker, each of which processes a different subset of the training dataset.
Each device computes its own gradients, and then we average the gradients across all the nodes with all_reduce to update our parameters. This approach is pretty straightforward to implement and works for any model type.
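A small simulation helps show why averaging is the right thing to do. This is a sketch under assumptions: a one-parameter linear model, a toy dataset, four hypothetical workers in a single process, and `all_reduce_mean` as a stand-in for a real collective. With equal-sized shards, the average of the per-worker mean gradients equals the full-batch gradient, and since every replica applies the same averaged gradient, the replicas never drift apart.

```python
import math

def local_gradient(w, shard):
    # d/dw of the mean squared error of y ≈ w * x over this worker's shard
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    # Simulated all_reduce: every worker ends up holding the same average
    avg = sum(values) / len(values)
    return [avg] * len(values)

# Toy dataset split into equal-sized shards for 4 hypothetical workers
data = [(float(x), 3.0 * x + 1.0) for x in range(8)]
n_workers = 4
shards = [data[i::n_workers] for i in range(n_workers)]

full_batch = local_gradient(0.0, data)
averaged = all_reduce_mean([local_gradient(0.0, s) for s in shards])[0]
print(full_batch, averaged)  # -112.0 -112.0

# Each worker keeps a full replica of the (one-parameter) model
replicas = [0.0] * n_workers
for step in range(50):
    grads = [local_gradient(w, s) for w, s in zip(replicas, shards)]
    synced = all_reduce_mean(grads)
    replicas = [w - 0.01 * g for w, g in zip(replicas, synced)]
```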
We’ve turned more GPUs into more speed - great!
In addition, we increase the effective batch size, which reduces the number of costly parameter updates. Since larger batches carry more signal in each gradient update, this also improves convergence (up to a point).
Unfortunately the memory bottleneck still remains. For Data Parallelism to work, the entire model has to fit on every device, which just isn’t going to happen for large models.
Another thing we could try is splitting up the computation of the model itself, putting different layers (transformer blocks) on different devices. With this model parallelism approach we aren’t limited by the memory of a single GPU, but instead by the combined memory of all the GPUs that we have.
However, two problems remain. Firstly, how to split up a model efficiently is very dependent on the specific model architecture (for example the number of layers and attention heads). And secondly, communicating between nodes now bottlenecks training.
Since each layer requires the output of the previous layer in each pass, workers spend most of their time waiting. What a waste of GPU time! It looks as though training takes as long as it would on a single GPU big enough to hold the whole model, but it’s even worse: the communication overhead of getting data between nodes makes it slower than a single GPU.
Can we do better than this?
Data Parallelism gave speedups but couldn’t handle models too large for a single machine. Model Parallelism allowed us to train large models but it’s slow.
We really want a marriage of the ideas of both data and model parallelism - speed and scale together.
We don’t always get what we want, but in this case we do. With DeepSpeed, Microsoft packaged up a bag of tricks to allow ML engineers to train larger models more efficiently. All in, DeepSpeed enables >100x lower training time and cost with minimal code changes - just 4 changed lines of PyTorch code. Let’s walk through how.
Seven Weird Tricks to Train Large Models:
Ordinarily, mathematical operations are performed with 32-bit floats (fp32). Using half precision (fp16) instead of full precision (fp32) halves memory and speeds up computation.
We do the forward and backward passes in fp16 for speed, while keeping fp32 copies of the weights and optimizer states (e.g. Adam's momentum and variance estimates) for accuracy. The extra precision of fp32 means that we can still represent very slight updates, which would otherwise be rounded away.
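The "very slight updates" problem is easy to reproduce with Python's standard `struct` module, whose `"e"` and `"f"` formats round a value to IEEE half and single precision respectively. This is a simplified sketch: the update size of `1e-4` is a hypothetical learning-rate-scaled gradient, and it ignores the rest of a real mixed-precision setup (such as loss scaling).

```python
import struct

def to_fp16(x):
    # Round a Python float to IEEE half precision (struct format "e")
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x):
    # Round a Python float to IEEE single precision (struct format "f")
    return struct.unpack("<f", struct.pack("<f", x))[0]

print(struct.calcsize("<e"), struct.calcsize("<f"))  # 2 4 (bytes per value)

update = 1e-4      # hypothetical small update
w_fp16 = 1.0       # weight stored only in fp16
w_fp32 = 1.0       # fp32 "master" copy of the same weight
for _ in range(1000):
    w_fp16 = to_fp16(w_fp16 + update)  # fp16 spacing near 1.0 is ~0.001, so this rounds back to 1.0
    w_fp32 = to_fp32(w_fp32 + update)  # fp32 keeps accumulating the small updates

print(w_fp16)            # 1.0
print(round(w_fp32, 3))  # 1.1
```

The fp16 weight never moves, while the fp32 copy ends up where a thousand updates should take it, which is why the optimizer keeps its own fp32 state.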
A simple training loop might contain something like:
# pseudocode: gradients() and distributed.all_reduce() stand in for the framework's calls
for i, batch in enumerate(train_loader):
    average_gradients = 0
    for j, minibatch in enumerate(batch):
        loss = model(minibatch)
        local_gradients = gradients(loss / batch_size)
        average_gradients += distributed.all_reduce(local_gradients)  # reduce INSIDE inner loop
    optimizer.step(average_gradients)
Note that on every iteration of the inner loop we’re not only calculating the local gradients but also synchronizing them, which requires communicating with all the other nodes.
Delaying synchronization until the end of each batch (accumulating gradients locally in the meantime) improves throughput, e.g.:
for i, batch in enumerate(train_loader):
    local_gradients = 0
    for j, minibatch in enumerate(batch):
        loss = model(minibatch)
        local_gradients += gradients(loss / batch_size)  # accumulate locally, no communication
    average_gradients = distributed.all_reduce(local_gradients)  # reduce OUTSIDE inner loop
    optimizer.step(average_gradients)
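The two loops compute the same averaged gradient; they differ only in how often they communicate. A minimal sketch with plain floats as hypothetical per-minibatch gradients, and an average as a stand-in for all_reduce, makes the difference in communication count explicit:

```python
def sync_every_minibatch(grads):
    # grads[w][j]: hypothetical gradient of worker w on minibatch j
    n_workers, n_micro = len(grads), len(grads[0])
    total, all_reduce_calls = 0.0, 0
    for j in range(n_micro):
        total += sum(g[j] for g in grads) / n_workers  # one all_reduce per minibatch
        all_reduce_calls += 1
    return total / n_micro, all_reduce_calls

def sync_once_per_batch(grads):
    n_workers, n_micro = len(grads), len(grads[0])
    local = [sum(g) / n_micro for g in grads]         # accumulate locally, no communication
    return sum(local) / n_workers, 1                   # a single all_reduce at the end

grads = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # 2 workers, 4 minibatches each
print(sync_every_minibatch(grads))  # (4.5, 4)
print(sync_once_per_batch(grads))   # (4.5, 1)
```

Same update, a quarter of the communication: with k minibatches per batch, delaying synchronization cuts the number of all_reduce calls by a factor of k.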
Suppose we have a GPU with 50GB of memory and our model weights take up 10GB. That’s all great, right?
For inference we feed in our input data and get out activations at each step. Then once we pass each layer, we can throw away activations from prior layers. Our model fits on the single GPU.
For training however, it’s a different story. Each GPU needs its intermediate activations, gradients and the fp32 optimiser states for backpropagation. Pretty soon we’re overflowing the GPU with our model’s memory footprint 😞
The biggest drain on our memory is the optimiser states.
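Some back-of-the-envelope accounting shows why. The sketch below assumes mixed-precision Adam with the per-parameter byte counts used in the ZeRO paper (fp16 weights and gradients, plus an fp32 weight copy and two fp32 Adam moments, all three counted as optimiser state). It ignores activations, and the 7.5B-parameter model is hypothetical.

```python
# Bytes per parameter for mixed-precision Adam training (activations excluded)
BYTES_PER_PARAM = {
    "fp16 parameters": 2,
    "fp16 gradients": 2,
    "fp32 parameter copy": 4,   # counted as optimiser state
    "fp32 Adam momentum": 4,
    "fp32 Adam variance": 4,
}
OPTIMISER_STATE = ("fp32 parameter copy", "fp32 Adam momentum", "fp32 Adam variance")

def training_memory_gb(n_params):
    # Memory in GB for each component of the model state
    return {name: n_params * b / 1e9 for name, b in BYTES_PER_PARAM.items()}

breakdown = training_memory_gb(7.5e9)   # a hypothetical 7.5B-parameter model
total = sum(breakdown.values())
optimiser = sum(breakdown[k] for k in OPTIMISER_STATE)
print(total, optimiser)  # 120.0 90.0
```

Three quarters of the model state is optimiser state, so it is the natural first thing to stop replicating.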
We know that we’re going to need to get multiple GPUs and do some model parallelisation here. Eventually we want to partition the whole model but a good first move would be to at least remove optimisation state redundancy.
For ZeRO stage 1, in the backward pass, each device calculates the (first-order) gradients for the final section of the model. The final device gathers all of these gradients, averages them and then computes the Adam update using its optimiser states. It then broadcasts the new parameters for the final section of the model back to all devices. The penultimate device then does the same, and so on until we reach the first device.
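We can think of this as a 5-step process:
1. Each device computes its gradients for the current section of the model.
2. The device that owns that section gathers them (reduce).
3. It averages the gradients.
4. It computes the Adam update using its optimiser states for that section.
5. It broadcasts the updated parameters to all of the nodes.
ZeRO stage 1 typically reduces our memory footprint by ~4x.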
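The key point of stage 1 is that partitioning changes where the optimiser states live, not what the update computes. A single-process simulation illustrates this under stated assumptions: the "devices" are just list indices, reduce and broadcast are a sum and a list copy, the parameter vector is split into contiguous sections, and the gradients are hypothetical values. Plain data parallelism and the ZeRO-1 style update produce identical parameters, while each device stores Adam states for only its own section.

```python
import math

def adam_update(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # One Adam step for a single scalar parameter; t counts from 1
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

def shard_ranges(n_params, n_devices):
    # Contiguous "sections" of the parameter vector, one per device
    size = math.ceil(n_params / n_devices)
    return [range(d * size, min((d + 1) * size, n_params)) for d in range(n_devices)]

def baseline_step(params, worker_grads, m, v, t):
    # Plain data parallelism: every device holds m and v for ALL parameters
    n_dev, new_params = len(worker_grads), []
    for i, p in enumerate(params):
        g = sum(wg[i] for wg in worker_grads) / n_dev      # all_reduce (mean)
        p, m[i], v[i] = adam_update(p, g, m[i], v[i], t)
        new_params.append(p)
    return new_params

def zero1_step(params, worker_grads, states, shards, t):
    # ZeRO stage 1: device d holds m and v only for its own section
    n_dev, new_params = len(worker_grads), list(params)
    for d, section in enumerate(shards):
        st = states[d]
        for j, i in enumerate(section):
            g = sum(wg[i] for wg in worker_grads) / n_dev  # reduce onto owner d
            new_params[i], st["m"][j], st["v"][j] = adam_update(
                params[i], g, st["m"][j], st["v"][j], t)
    return new_params                                      # broadcast to every device

n_params, n_dev = 8, 4
shards = shard_ranges(n_params, n_dev)
params_dp = [0.1 * i for i in range(n_params)]
params_z1 = list(params_dp)
m, v = [0.0] * n_params, [0.0] * n_params
states = [{"m": [0.0] * len(s), "v": [0.0] * len(s)} for s in shards]
for t in range(1, 4):
    # Hypothetical per-device gradients: each device saw different data
    worker_grads = [[math.sin(t + d + i) for i in range(n_params)] for d in range(n_dev)]
    params_dp = baseline_step(params_dp, worker_grads, m, v, t)
    params_z1 = zero1_step(params_z1, worker_grads, states, shards, t)
```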
🔄 Fun Fact: The name DeepSpeed is a palindrome! How cute 🤗
We can take the partitioning idea further and do it for parameters and gradients as well as optimisation states.
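For each section of the model in turn:
1. The device that owns the section broadcasts its parameters, starting with the parameters for the first section of the model.
2. In the backward pass, each device broadcasts its section gradients.
3. The owning device gathers them (reduce), calculates the gradient update with its optimiser and then broadcasts the results, which can be used for the next section.
If we have N cores, we now have an Nx memory footprint reduction from ZeRO.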
That was the most complex part, so feel free to check out these resources to make sure you understand what’s going on:
It’s all downhill from here!
Overall, ZeRO removes the redundancy across data parallel processes by partitioning optimizer states, gradients and parameters across nodes. Look at how much memory footprint we’ve saved!
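To put rough numbers on the savings, here is a sketch of per-device model-state memory at each ZeRO stage. It assumes the ZeRO paper's accounting for mixed-precision Adam (2 bytes of fp16 parameters, 2 bytes of fp16 gradients and 12 bytes of fp32 optimiser state per parameter), ignores activations, and uses a hypothetical 7.5B-parameter model on 64 devices.

```python
def zero_memory_gb(n_params, n_devices, stage):
    # Per-device model-state memory (no activations) for mixed-precision Adam
    params, grads, opt = 2 * n_params, 2 * n_params, 12 * n_params
    if stage >= 1:
        opt /= n_devices      # stage 1: partition optimiser states
    if stage >= 2:
        grads /= n_devices    # stage 2: ... and gradients
    if stage >= 3:
        params /= n_devices   # stage 3: ... and parameters
    return (params + grads + opt) / 1e9

for stage in range(4):
    print(stage, zero_memory_gb(7.5e9, 64, stage))
# 0 120.0
# 1 31.40625
# 2 16.640625
# 3 1.875
```

Stage 1 gives a reduction of roughly 4x (120 / 31.4 ≈ 3.8), while stage 3 divides the whole footprint by the number of devices.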
One surprising thing about this approach is that it scales superlinearly. That is, when we double the number of GPUs that we’re using, we more than double the throughput of the system! In splitting up the model across more GPUs, we leave more space per node for activations which allows for higher batch sizes.
Most of the operations in a large ML model are matrix multiplications followed by non-linearities. Matrix multiplication can be thought of as dot products between pairs of matrix rows and columns. So we can compute independent dot products on different GPUs and then combine the results afterwards.
Another way to think about this is that if we want to parallelise matrix multiplication across GPUs, we can slice up huge tensors into smaller ones and then combine the results at the end.
For matrices \(X = \begin{bmatrix} X_1 & X_2 \end{bmatrix}\) and \(A = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix}\), we note that:
\[XA = \begin{bmatrix} X_1 & X_2 \end{bmatrix} \begin{bmatrix} A_1 \\ A_2 \end{bmatrix} = X_1 A_1 + X_2 A_2\]
For example:
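A worked numeric check of the identity, in pure Python with small integer matrices (the values and the two-"device" framing are hypothetical): one device computes \(X_1 A_1\), the other \(X_2 A_2\), and summing the partial products (an all_reduce) recovers \(XA\) exactly.

```python
def matmul(X, A):
    # Naive matrix product: dot products of rows of X with columns of A
    return [[sum(X[i][k] * A[k][j] for k in range(len(A))) for j in range(len(A[0]))]
            for i in range(len(X))]

def add(P, Q):
    return [[p + q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]         # 2x4 data
A = [[1, 0], [0, 1], [2, 1], [1, 3]]     # 4x2 weights
X1, X2 = [row[:2] for row in X], [row[2:] for row in X]   # split X by columns
A1, A2 = A[:2], A[2:]                                     # split A by rows

# "Device 1" computes X1 @ A1, "device 2" computes X2 @ A2; summing them is the all_reduce
partial_sum = add(matmul(X1, A1), matmul(X2, A2))
print(partial_sum)  # [[11, 17], [27, 37]]
```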
However, if there is a non-linear map after the multiplication, e.g. if \(Y = \text{ReLU}(XA)\), this slicing isn’t going to work: in general \(\text{ReLU}(X_1A_1 + X_2A_2) \neq \text{ReLU}(X_1A_1) + \text{ReLU}(X_2A_2)\), because ReLU is non-linear. So we should instead split up A by columns, \(A = \begin{bmatrix} A_1 & A_2 \end{bmatrix}\), and duplicate X across both nodes, such that we have:
\[Y = [Y_1, Y_2] = [\text{ReLU}(X A_1), \text{ReLU}(X A_2)] = \text{ReLU}(XA)\]
For example:
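A small pure-Python check of both claims, with hypothetical values: splitting A by columns (with X duplicated on both devices) and concatenating the ReLU outputs reproduces \(\text{ReLU}(XA)\) exactly, whereas the row split that worked for the purely linear case gives a different answer once ReLU is applied.

```python
def matmul(X, A):
    # Naive matrix product: dot products of rows of X with columns of A
    return [[sum(X[i][k] * A[k][j] for k in range(len(A))) for j in range(len(A[0]))]
            for i in range(len(X))]

def relu(M):
    return [[max(0.0, v) for v in row] for row in M]

def add(P, Q):
    return [[p + q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

X = [[1.0, -2.0], [3.0, 0.5]]
A = [[1.0, -1.0, 2.0, 0.0], [0.5, 1.0, -1.0, 3.0]]

# Column split of A; each device holds all of X and applies ReLU locally
A1, A2 = [row[:2] for row in A], [row[2:] for row in A]
Y = [r1 + r2 for r1, r2 in zip(relu(matmul(X, A1)), relu(matmul(X, A2)))]  # concatenate columns

# Row split (fine for a purely linear map) breaks once ReLU is applied before summing
X1, X2 = [row[:1] for row in X], [row[1:] for row in X]
row_split = add(relu(matmul(X1, A[:1])), relu(matmul(X2, A[1:])))
```

The column split needs no communication until the outputs are gathered, which is why it is the natural choice before a non-linearity.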
Note: normally we think of A acting on X by left multiplication. In this case X is our data and A is the weights which we want to parallelise. Through taking transposes we can swap the order of the geometric interpretation so we can think of the above as linear map A acting on our data X and still retain the slicing.
In our description of ZeRO, each core cached (held in memory) the activations for its part of the model.
Suppose we had extremely limited memory but were flush with compute. An alternative to storing all the activations would be to simply recompute them when we need them in the backward pass. We can always recompute the activations by running the same input data through a forward pass.
This recomputing approach saves lots of memory but is quite compute-wasteful, incurring m extra forward passes for an m-layer transformer.
A middle-ground approach to trading off compute and memory is gradient checkpointing (sometimes known as activation checkpointing). Here we store only some of the intermediate activations (the checkpoints), using \(O(\sqrt m)\) memory at the cost of one extra forward pass.
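The trade-off can be sketched on a toy chain of scalar layers. This is an illustration under assumptions, not DeepSpeed's implementation: each hypothetical layer is \(y = \tanh(wx)\), checkpoints are kept every \(\lfloor\sqrt m\rfloor\) layers, and each segment's activations are recomputed from its checkpoint during the backward pass (frameworks do the same on tensors, e.g. PyTorch's `torch.utils.checkpoint`). The gradient is identical to ordinary backprop, the peak number of stored activations drops, and the extra work stays below one additional forward pass.

```python
import math

def layer(w, x):
    # One hypothetical scalar "layer": y = tanh(w * x)
    return math.tanh(w * x)

def layer_grad(w, x):
    # Local derivative dy/dx = w * (1 - tanh(w x)^2); needs the layer's INPUT x
    y = math.tanh(w * x)
    return w * (1.0 - y * y)

def grad_store_everything(ws, x0):
    # Standard backprop: keep all m + 1 activations from the forward pass
    acts = [x0]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    g = 1.0
    for i in reversed(range(len(ws))):
        g *= layer_grad(ws[i], acts[i])
    return g, len(acts)

def grad_checkpointed(ws, x0):
    # Keep every k-th activation (k ≈ sqrt(m)); recompute the rest segment by segment
    m = len(ws)
    k = max(1, math.isqrt(m))
    checkpoints, x, forward_calls = {0: x0}, x0, 0
    for i, w in enumerate(ws):
        x = layer(w, x)
        forward_calls += 1
        if (i + 1) % k == 0:
            checkpoints[i + 1] = x
    g, peak = 1.0, len(checkpoints)
    for start in reversed(range(0, m, k)):
        end = min(start + k, m)
        seg = [checkpoints[start]]               # recompute this segment's inputs
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
            forward_calls += 1
        peak = max(peak, len(checkpoints) + len(seg))  # checkpoints are never freed here
        for i in reversed(range(start, end)):    # backprop through the segment
            g *= layer_grad(ws[i], seg[i - start])
    return g, forward_calls, peak

ws = [0.5 + 0.05 * i for i in range(16)]         # m = 16 hypothetical layers
g_full, stored_full = grad_store_everything(ws, 0.3)
g_ckpt, calls, peak = grad_checkpointed(ws, 0.3)
print(stored_full, peak, calls)  # 17 9 28
```

Here 16 layers need 17 stored activations with ordinary backprop but at most 9 with checkpointing, for 12 extra layer evaluations (less than one extra forward pass of 16).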
While these aren’t strictly speed or memory optimisations, DeepSpeed also provides developer-friendly features like convenient profiling and monitoring to track latency and performance. We also have model checkpointing, so you can recover a model from different points in training. Developer happiness matters almost as much as loss!
Check out the docs for more info!
Animated Video from Microsoft: warning, it’s a little slow.
The full DeepSpeed library, with all the hardware level optimisations, is open-sourced. See the core library, the docs and examples.
For an annotated and easier to follow implementation see Lab ML’s version.
DeepSpeed integrates with PyTorch to optimise training.
In PyTorch we only need to change 4 lines of code to apply DeepSpeed such that our code is optimised for training on a single GPU machine, a single machine with multiple GPUs, or on multiple machines in a distributed fashion.
First we swap out:
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
with initialising DeepSpeed by writing:
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": batch_size,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": True
    },
    "zero_optimization": {
        "stage": 1,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
model_engine, *_ = deepspeed.initialize(model=model,
                                        model_parameters=model.parameters(),
                                        config=ds_config)
Then in our training loop we change out the original PyTorch…
for step, batch in enumerate(data_loader):
    inputs, target = batch
    # Calculate loss using model e.g.
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
for:
for step, batch in enumerate(data_loader):
    # Forward propagation method to get loss
    loss = ...
    # Runs backpropagation
    model_engine.backward(loss)
    # Weights update
    model_engine.step()
That’s all it takes! In addition, DeepSpeed’s backend has also been integrated with HuggingFace via the Accelerate library.
There are a lot of clever improvements that go into the special sauce for training large models. And for users, with just a few simple code changes, DeepSpeed works its magic to unleash the power of all your hardware for fast, efficient model training.
Happy training!