Building a Large Language Model from scratch, part 2
This blog post is the next part of a series where I work through Sebastian Raschka’s Build a Large Language Model from Scratch, which is a step-by-step guide to… exactly what it sounds like. Here, I walk through the most complex chapter in the book, which incrementally builds an intuition for - and example implementation of - the self-attention mechanism at the heart of a modern LLM’s transformer architecture.
It’s worth mentioning that this guide focuses on OpenAI’s GPT-2 as its reference for GPT-style architecture, so while it may be missing cutting-edge refinements and optimizations in more modern LLMs, it’s a very good illustration of the general principles behind these models.
The previous installment is here.
Chapter 3: Coding attention mechanisms
In short, self-attention is the innovation that allows the model’s representation of a token (e.g. word) to efficiently incorporate information about what other words in the sentence it relates to, and how. Raschka structures this chapter as a progression through four stepping stones of understanding the self-attention mechanism, each one adding more complexity:
- Simplified self-attention
- Self-attention (with trainable weights)
- Causal attention (no future token information leakage)
- Multi-head attention
Framing the problem
The chapter begins with, helpfully, some background on why self-attention is even such a necessary advance for LLMs. Raschka phrases this as “the problem with modeling long sequences”.
Early machine translation efforts that used deep learning techniques ran into a fundamental challenge: grammatical differences between languages mean that it’s impossible to translate sentences word-for-word without knowledge of the larger context. Words in the source language that perform a certain role in the sentence may have their counterparts in a totally different part of the target sentence; their counterparts might be a multi-word phrase (not necessarily consecutive, either; think English “don’t” vs. the French “ne…pas” sandwich); or the counterpart may not even exist explicitly in the target!
The high-level solution to this, which predates but also includes LLMs, is for the architecture to include an encoder module that projects text into latent space, then a decoder module that translates it into destination tokens. That is, assume that we can construct a continuous space that represents the “meaning” of a token or set of tokens irrespective of the original language. Then the task is to find the optimal operations that transform tokens to and from this other representation. This should be very reminiscent of the token embedding discussion from the previous post - it’s the same principle! The difference here is that the latent space representation should not just be for a single token in isolation, but include some information about other relevant tokens. This context information is often focused on, but not limited to, surrounding tokens. The problem of pronoun reference resolution is a concrete example of what this approach is trying to solve.
Prior to the development of self-attention, this was notoriously difficult. The previous state-of-the-art for machine translation was recurrent neural networks (RNNs), which track hidden state (memory) at each next-word-prediction iteration. However, classical RNNs can’t directly access past iterations’ hidden states, which limits how richly token context can be represented. As an RNN iterates over a sequence of tokens, the contribution of distant tokens gets lost as the information is compressed into the single hidden state. Furthermore, there isn’t a great way to feed information from future tokens in the sequence to update the context for previous tokens. This isn’t so relevant for GPT-style LLMs, but it matters more for NLP problems like filling in missing words within a sentence, which models like BERT would eventually address, also with a transformer architecture.
The breakthrough solution to these limitations was the transformer self-attention mechanism, introduced to the public in 2017 with the paper “Attention Is All You Need”. Interestingly, Raschka notes that an earlier innovation from 2014, the Bahdanau attention mechanism for RNNs, directly inspired it - I had totally glossed over the passing reference in “Attention Is All You Need” but it is definitely in the bibliography. Self-attention evaluates, for each input sequence position, the relevance of all other positions when encoding it. The “attention” part of the name refers to how different weights “attend” to different parts of the input sequence, while “self”-attention specifies that this applies within just one sequence. (Contrast this with relating elements in an input sequence explicitly to elements in a target sequence, as was done in other historical approaches to sequence-to-sequence modeling.) The key idea is, for each input element in the sequence, to compute a context vector that is a weighted combination of all elements in the sequence (itself included), with weights that are specific to that element. This context vector can be thought of as an “enriched embedding vector”: for NLP problems, it augments the token-specific embedding with latent space information about the surrounding tokens with respect to that one.
It’s worth noting that in tracking every element’s relevance to every other element, the mechanism computes one attention score per pair of positions, so the compute and the intermediate state (during training and inference) grow quadratically with sequence length. The trainable weight matrices themselves do not depend on sequence length, but this quadratic growth helps explain why long context windows are so expensive!
A simple self-attention mechanism without trainable weights
This is the minimal illustrative example of how self-attention works. It’s not directly usable in an LLM at all, but its purpose is rather to teach the key moving parts of self-attention as simply as possible.
Assume a sequence (say, of token embeddings) with elements $x^{(1)}, \dots, x^{(T)}$. $x^{(i)}$ is the $i$-th input element, vector-valued (because… it’s an embedding). Then let $z^{(i)}$ be the context vector for each query element $x^{(i)}$. That is, $z^{(i)}$ will represent the embedding for element $x^{(i)}$, augmented with the information from all other elements in the sequence. This is the desired output from the self-attention mechanism. Note that $z^{(i)}$ must have the same shape as $x^{(i)}$.
The context vector is computed from an intermediate quantity, the attention scores. These roughly represent the relevance of each sequence element to the query element. In the simplified model, there are no tuned weights for computing the attention score, so assume that any weights are equivalent to 1. The attention scores then are just the dot products of the query token vector with all the input token vectors. Then, normalize these across all the input token vectors. In practice, normalization is done with softmax, so the normalized values sum to 1 and are always positive. Raschka terms the raw dot products “attention scores” and the normalized ones “attention weights”. Expressed another way:
- Attention scores: query $\cdot$ inputs, or $\omega_{ij} = x^{(i)} \cdot x^{(j)}$
- Attention weights: normalized scores vector, or $\alpha_i = \mathrm{softmax}(\omega_i)$
Note that the vector of attention weights $\alpha_i$ has dimension $T$, since the attention score of each element with respect to the query element is a scalar (because dot product), and there are $T$ of them.
The remaining step is to compute the final context vector. This is the product of the attention weights and the sequence input embeddings:
- Context vector: attention weights $\times$ inputs, or $z^{(i)} = \sum_{j=1}^{T} \alpha_{ij} \, x^{(j)}$
Those are for a single query element among all the input elements in the sequence. Generalized for all elements (token embeddings) in the sequence simultaneously, with $X$ the $T \times d$ matrix whose rows are the $x^{(i)}$ and the softmax applied row by row:
$$Z = \mathrm{softmax}\left(X X^\top\right) X$$
The input token embedding vector also has alternative names depending on the role it fills in the attention mechanism.
$$Z = \mathrm{softmax}\left(Q K^\top\right) V$$
where $Q = K = V = X$. These are, respectively, the query, the key, and the value vectors.
Why use these terms, which come from information retrieval jargon? The context vector can be thought of as answering the question, “What are the values (tokens) that are most relevant to the query?” Relevance is determined by comparing the query to a key token, through the attention score dot product, as if we were looking up the value associated with that key. In this case, the key and the value are identical, i.e. $K = V = X$!
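To make the three steps concrete (scores, weights, context vector), here is a minimal sketch in plain Python with no libraries, so every operation is visible. The function names and the three made-up "embeddings" are my own illustration, not the book's code, which uses PyTorch tensors.

```python
import math

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def softmax(scores):
    """Normalize scores into positive weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def simple_self_attention(inputs):
    """Return one context vector per input, with no trainable weights."""
    contexts = []
    for query in inputs:
        # Attention scores: dot product of the query with every input (the "keys")
        scores = [dot(query, key) for key in inputs]
        # Attention weights: softmax-normalized scores
        weights = softmax(scores)
        # Context vector: weighted sum of all inputs (the "values")
        context = [
            sum(w * value[d] for w, value in zip(weights, inputs))
            for d in range(len(query))
        ]
        contexts.append(context)
    return contexts

# Made-up 3-dimensional embeddings for a 3-token sequence
inputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
contexts = simple_self_attention(inputs)
for c in contexts:
    print([round(v, 3) for v in c])
```

Each context vector has the same length as its input, and because the weights are positive and sum to 1, every component is a weighted average of the corresponding input components.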
Self-attention with trainable weights
The next step is to introduce nontrivial weights into the attention score calculations. Raschka also refers to this as “scaled dot-product attention”.
Now define weight matrices $W_q$, $W_k$, $W_v$. These adjust the matrix products above, and importantly, these are trainable parameters. These should be initialized randomly prior to training.
- $W_q$ are the query weights, and apply to the query elements (the first $x$ in the attention score formula).
- $W_k$ are the key weights, and apply to the key elements (the second $x$ in the attention score formula).
- $W_v$ are the value weights, and apply to the value elements (the $x$ in the context vector formula).
Whenever evaluating the context vector, instead first compute the intermediate matrices from the matrix of input embeddings $X$:
- “Keys”: $K = X W_k$
- “Values”: $V = X W_v$
- “Queries”: $Q = X W_q$
(Note that these are themselves embeddings in a new space! The dimension of that space is set by the output dimension of the weight matrices, which is another hyperparameter. The smaller the dimension, the more compressed information will be.)
This leads to revised formulas for the attention scores/weights:
- Attention scores: $\omega = Q K^\top$
- Attention weights: $\alpha = \mathrm{softmax}\left(\omega / \sqrt{d_k}\right)$, where $d_k$ is the embedding dimension of the keys. The rescaling in the normalization is in order to avoid very small gradients during backpropagation: large dot products would otherwise push the softmax toward a nearly one-hot output, where gradients are tiny.
The context vectors are then $Z = \alpha V$.
Again, the motivation here is that during model training, the weights now can be optimized to improve model task performance.
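As a rough sketch of what this looks like in PyTorch, assuming a single unbatched sequence: the class name `SelfAttention` and the sizes below are illustrative choices of mine. Using `nn.Linear` without a bias is one convenient way to hold a randomly initialized, trainable weight matrix.

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention (illustrative sketch)."""

    def __init__(self, d_in, d_out):
        super().__init__()
        # Each bias-free Linear layer stores one trainable weight matrix
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        # x has shape (num_tokens, d_in)
        queries = self.W_query(x)   # (num_tokens, d_out)
        keys = self.W_key(x)        # (num_tokens, d_out)
        values = self.W_value(x)    # (num_tokens, d_out)

        scores = queries @ keys.T   # (num_tokens, num_tokens)
        d_k = keys.shape[-1]
        # Scale by sqrt(d_k) before the row-wise softmax
        weights = torch.softmax(scores / d_k ** 0.5, dim=-1)
        return weights @ values     # (num_tokens, d_out)

torch.manual_seed(0)
x = torch.rand(6, 3)                 # 6 tokens, embedding dimension 3
attn = SelfAttention(d_in=3, d_out=2)
context = attn(x)
print(context.shape)                 # torch.Size([6, 2])
```

Note that the output dimension (2 here) differs from the input dimension (3), which is the "new embedding space" point in action.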
Causal attention
Often we want the LLM to only consider preceding/current (known) tokens in the sequence. This avoids leakage of “future” information when predicting the next token in the sequence, as with GPTs. The next step in developing the self-attention mechanism is to add this requirement.
The general principle here is to create a lower-triangular mask in the shape of the attention weights matrix (whose rows are each normalized to sum to 1), and multiply the weights by it elementwise, zeroing out every entry above the diagonal. Then, renormalize the post-mask row weights to sum to 1 again. With softmax normalization, this can be done more efficiently by setting the above-diagonal attention scores to $-\infty$ before the softmax, so that normalizing to attention weights happens just once. The consequence is that “future” information is ignored in training sequences.
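The two routes to causal weights give the same result, which is easy to check numerically. This is a small sketch with random stand-in scores rather than real query-key products; the variable names and sizes are mine.

```python
import torch

torch.manual_seed(0)
T, d_k = 4, 8
scores = torch.randn(T, T)  # stand-in for the query-key attention scores

# True above the diagonal marks "future" positions to hide
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# Route 1: softmax first, zero out future positions, then renormalize each row
weights = torch.softmax(scores / d_k ** 0.5, dim=-1)
masked = weights.masked_fill(future, 0.0)
route1 = masked / masked.sum(dim=-1, keepdim=True)

# Route 2: set future scores to -inf, then apply softmax once
masked_scores = scores.masked_fill(future, float("-inf"))
route2 = torch.softmax(masked_scores / d_k ** 0.5, dim=-1)

print(route2)
```

Since exp(-inf) is 0, the masked positions contribute nothing to the softmax denominator, so route 2 needs no separate renormalization step.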
An implementation note: the example code in the book uses nn.Module.register_buffer for the causal mask. I was initially confused why a buffer was specifically chosen instead of a new parameter. The reason is that a buffer is a tensor that belongs to the module’s state (it moves with the module across devices and is saved in its state_dict by default) but, unlike a parameter, it has no gradient computed and is not updated by the optimizer. See this thread for a full discussion on registering buffers vs. parameters.
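A quick way to see the difference is to list a module's parameters and buffers side by side. The module below is a hypothetical toy of my own (the name `MaskHolder` and the layer sizes are arbitrary); only the `register_buffer` call mirrors what the book does.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):
    """Toy module with one trainable weight and one causal-mask buffer."""

    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)  # a real trainable parameter
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask)       # module state, but not trainable

m = MaskHolder(context_length=3)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]

print(param_names)                # ['proj.weight']
print(buffer_names)               # ['mask']
print("mask" in m.state_dict())   # True
print(m.mask.requires_grad)       # False
```

The mask is saved and moved along with the module, but an optimizer built from `m.parameters()` never sees it.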
The text also discusses applying dropout at this stage. Empirically, applying this (only during training!) appears to help reduce overfitting. Raschka notes that this is typically done after computing the attention weights (that is, after the softmax), with a dropout rate of around 20% or less in practice. When using torch.nn.Dropout’s implementation, this will also rescale the surviving elements up by a factor of 1/(1 - p) to compensate, so that the expected value of each element is unchanged and the overall scale stays consistent between training and inference.
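The rescaling is easy to see on a toy matrix. This sketch uses a 50% rate purely so the arithmetic is obvious; the uniform 0.25 "weights" are made up for illustration.

```python
import torch

torch.manual_seed(0)
dropout = torch.nn.Dropout(p=0.5)
weights = torch.full((4, 4), 0.25)  # made-up uniform attention weights

dropout.train()                     # training mode: dropout is active
dropped = dropout(weights)          # each entry is 0 or 0.25 / (1 - 0.5) = 0.5

dropout.eval()                      # evaluation mode: dropout is a no-op
unchanged = dropout(weights)

print(dropped)
print(unchanged)
```

In training mode every surviving entry is doubled, and in evaluation mode the input passes through untouched.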
This section also brings up the topic of batching for self-attention. When evaluating over a batch, the dimension of the input sequence matrix becomes $(B, T, d)$ where $B$ is batch size, $T$ is max token length, and $d$ is the input’s embedding dimension. The same formulas apply, but the implementation needs to be conscious of the first, batch, dimension.
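In PyTorch, being "conscious of the batch dimension" mostly means transposing only the last two axes instead of using `.T`. Here is a small sketch using the weight-free simplified attention, with arbitrary sizes of my choosing:

```python
import torch

torch.manual_seed(0)
B, T, d = 2, 6, 3                       # batch size, tokens, embedding dim
x = torch.rand(B, T, d)

# Swap only the token and embedding axes; the batch axis stays in front
scores = x @ x.transpose(1, 2)          # (B, T, T)
weights = torch.softmax(scores, dim=-1)
context = weights @ x                   # (B, T, d)

# The batched result for sequence 0 matches the unbatched computation
single = torch.softmax(x[0] @ x[0].T, dim=-1) @ x[0]

print(scores.shape, context.shape)
```

The matrix multiplications broadcast over the leading batch axis, so each sequence in the batch is processed independently.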
Multi-head attention
Finally, the text arrives at the full self-attention mechanism from the original transformer paper. The causal attention variant above yields one context vector per query token, which is akin to identifying just one way in which other tokens are relevant to the query across the sequence. However, there are many ways in which the rest of the sequence might be relevant to a particular word simultaneously! For example, in the sentence “The old dog chased the quick fox which bit him”, the token for “him” has many different kinds of relationships with the other tokens:
- Pronoun resolution: “him” refers back to “dog”.
- Adjectives: similarly, “old” describes “him” too.
- Subject-verb-object: “him” is the object of the verb “bit”, and is related to “fox” which did the action.
The key idea behind multi-head attention is that just one context vector may not be capable of faithfully representing all of these different relationships. Instead, the mechanism should allow for multiple context vectors! Each pathway for computing a context vector, as above, is termed a head.
A simple but inefficient approach is stacking multiple single-head attention layers. This can be done with multiple independent sets of weight matrices evaluated in series; then concatenate the resulting context vectors. However, parallelizing this by simultaneously evaluating all heads is much more efficient. Recall that the weights matrices for a single head have a projection dimension that’s a hyperparameter. Now, the projection dimension of the multi-head weights matrices should be divisible by the number of heads. In other words, the stored weights matrices as-implemented are the concatenated per-head weights matrices. During the context computation, then, the projected queries, keys, and values are reshaped (.view(), .transpose()) into per-head slices so the matrix multiplication runs in parallel across heads. Finally, an additional output projection linear layer often is included at the end. This gives the model learnable weights for mixing information across heads, which apparently can improve model task performance in practice.
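Putting the pieces together, here is a compact sketch of causal multi-head attention along the lines described above. It follows the general shape of the book's approach, but the exact names, sizes, and the small `split_heads` helper are my own choices, so treat it as an illustration rather than the reference implementation.

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention (illustrative sketch)."""

    def __init__(self, d_in, d_out, context_length, num_heads, dropout=0.0):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Each matrix holds all heads' weights concatenated along d_out
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes information across heads
        self.dropout = nn.Dropout(dropout)
        mask = torch.triu(
            torch.ones(context_length, context_length, dtype=torch.bool), diagonal=1
        )
        self.register_buffer("mask", mask)

    def split_heads(self, proj, b, t):
        # (b, t, d_out) -> (b, num_heads, t, head_dim)
        return proj.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.split_heads(self.W_query(x), b, t)
        k = self.split_heads(self.W_key(x), b, t)
        v = self.split_heads(self.W_value(x), b, t)

        scores = q @ k.transpose(2, 3)                        # (b, heads, t, t)
        scores = scores.masked_fill(self.mask[:t, :t], float("-inf"))
        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        context = (weights @ v).transpose(1, 2)               # (b, t, heads, head_dim)
        context = context.contiguous().view(b, t, -1)         # concatenate heads
        return self.out_proj(context)

torch.manual_seed(0)
x = torch.rand(2, 6, 3)  # batch of 2 sequences, 6 tokens, embedding dim 3
mha = MultiHeadAttention(d_in=3, d_out=4, context_length=6, num_heads=2)
out = mha(x)
print(out.shape)         # torch.Size([2, 6, 4])
```

One sanity check on the causal mask: changing only the last token of the input should leave the outputs for all earlier positions unchanged.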
Parting thoughts
The incremental approach in this chapter, from simpler to richer toy models including numerical examples, was especially effective in helping me understand this fairly complex topic. As I’ve noted before, I think this is a particular strength of Raschka’s book. This is a contrast with The Illustrated Transformer, which is also an excellent resource, but takes a different strategy: it represents the original transformer architecture faithfully but walks through the transformer “stack” in linear order. A resource that combined the incremental approach from Raschka with the excellent visualizations in Alammar’s online guide would be the best of both worlds!
I definitely came away from this chapter with a much stronger understanding of why self-attention works the way it does, and working through the code examples was a critical part of that. It confirms that for me at least, there really is no substitute for hands-on learning through doing.
I still have four more chapters to cover in this blog post series. In the next post, I’ll share my summary of, and reflections on, the material covering up to pretraining a complete GPT model locally!