Overview
A transformer is a neural network architecture capable of efficiently forming relationships in sequential data. It does this using a mechanism called attention.
When processing a token, attention allows the model to consider how that token is affected by all other tokens in the model's context.
Language tasks are inherently sequential, yet often depend on relationships between distant pieces of text, making them a strong example of why attention is powerful.
Consider the following examples:
"The capital of France is"
In this simple example, attention increases the weights on "capital", "France", and "is" to provide a higher likelihood of "Paris" being generated.
"def scale(values, factor):
return [v * factor for v in"
In code generation, attention helps connect back to earlier code context: in this case, the relationships between v, the loop structure, and the inputs values and factor.
Before Attention
Before attention, models often relied on long chains of memory to pass information from earlier tokens to later ones. This is the core idea behind recurrent neural networks, where information is carried forward step by step through the sequence using hidden layers (h).
In the unrolled recurrent neural network architecture, information must travel through potentially many hidden states, which can dilute it along the way.
An example that highlights the problem more clearly:
The capital of France, which I visited after spending three weeks traveling through Europe and reading about its history, is
In this case, the model needs to preserve the relevance of France across many intermediate hidden states.
As information is passed step by step, that signal weakens, making it difficult to encode the importance of that token. In turn, the model's prediction becomes less accurate than it would be if that relevance were maintained.
A related issue is that as sequential depth increases, the risk of vanishing gradients also grows, making training more difficult.
When a recurrent neural network is unrolled, as shown in the diagram above, these two issues become more apparent: the hidden state \(h_t\) at timestep \(t\) is a function of all previous states.
Expanding the recurrence itself shows the same hidden-to-hidden matrix \(W_{hh}\) is reused at every timestep.
During the forward pass, information from the initial hidden state \(h_0\) can become progressively diluted as it is transformed through many repeated applications of the recurrence.
During the backwards pass, this same repeated matrix can make training less effective. Since \(W_{hh}\) is reused at every timestep, the backward pass must repeatedly differentiate through it, which can cause the gradient signal to become very small.
Phrased a bit differently: because the same \(W_{hh}\) appears at every step of the recurrence, it also appears repeatedly in the backward pass. If the factors \(\phi'(a_i)W_{hh}\) are often smaller than 1 in magnitude, their product shrinks rapidly, so early timesteps contribute very little to \(\frac{\partial L}{\partial W_{hh}}\). That is the essence of the vanishing gradient problem.
After backpropagation computes the gradient of the loss function with respect to each hidden state \(h_t\), each entry of \(W_{hh}\) is updated with a gradient-descent step. If early timesteps contribute only very small gradients, then the total gradient is dominated by recent timesteps rather than by long-range dependencies. In that case, information from earlier parts of the sequence has very little effect on the update, so those contributions are effectively lost.
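The shrinking product of backward factors can be seen in a tiny numerical sketch. The recurrence below is a scalar toy version with a tanh nonlinearity, and the weight value 0.9 is a made-up illustration, not from the text:

```python
import math

# Scalar toy recurrence: h_t = tanh(w * h_{t-1}).
# The backward factor contributed by each timestep is phi'(a_t) * w,
# where phi = tanh and phi'(a) = 1 - tanh(a)^2.
w = 0.9     # hypothetical recurrent weight
h = 0.5     # initial hidden state
grad = 1.0  # accumulated product of backward factors

for _ in range(50):
    a = w * h
    h = math.tanh(a)
    grad *= (1 - h * h) * w  # one factor of phi'(a) * W_hh per step

# Each factor is below 1 in magnitude, so the product collapses toward
# zero: the earliest timestep barely influences the gradient.
print(f"gradient contribution from 50 steps back: {grad:.2e}")
```

After 50 steps the accumulated factor is already far below one percent, which is the vanishing-gradient effect described above.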
With Attention
With attention, the current position does not need to rely on many sequential recurrent steps and can instead directly read from earlier relevant positions.
So if token \(t\) needs information from token \(j\), it does not need to preserve that signal through \(t-j\) recurrent transitions. It can form a direct weighted connection to that earlier token.
The attention mechanism leads to two improvements: (1) the forward path is shorter and (2) the backward gradient path is shorter. As a result, long-range dependencies are easier to learn and can be used more effectively to generate probable output. This is the key insight behind the transformer architecture introduced in Attention Is All You Need.[1]
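A minimal sketch of that direct weighted connection, using single-head scaled dot-product attention for one query position. All vectors are toy values and the helper names are hypothetical:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attend(query, keys, values):
    d = len(query)
    # One similarity score per earlier position: a direct connection,
    # no matter how far back that position is.
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d)
              for key in keys]
    weights = softmax(scores)
    # The output mixes the values according to those weights.
    return [sum(w * v[i] for w, v in zip(weights, values))
            for i in range(len(values[0]))]

query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [0.2, 0.1]]     # position 0 matches best
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
out = attend(query, keys, values)
```

Because the query aligns most with the first key, the output is pulled mostly toward the first value vector, regardless of how far back that position sits in the sequence.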
Transformer Architecture
The transformer architecture replaces recurrence with stacked attention and feedforward blocks, allowing each token to gather context from the entire sequence in parallel.
Encoder models receive an input and build a representation of it. They are useful when the goal is to understand the input, such as in sentence classification or entity recognition. BERT is a common example.
Decoder models generate a sequence one token at a time using prior context. They are useful for generative tasks such as text generation. GPT is a common example.
Encoder-Decoder models use the encoder to understand the input and the decoder to generate the output. They are useful for tasks such as translation and summarization. T5 and the original transformer are common examples.
Decoder Flow
1. Token Embeddings and Positions
Raw tokens are encoded into N-dimensional vectors, and positional information is encoded as well.
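A minimal sketch of this step, assuming a tiny hypothetical embedding table and learned position vectors (all numbers are placeholders, not real weights):

```python
# Hypothetical 4-dimensional embeddings; values are illustrative only.
embedding = {
    "The":     [0.10, 0.30, -0.20, 0.50],
    "capital": [0.40, -0.10, 0.20, 0.00],
}
# Hypothetical learned position vectors, one per sequence position.
position = [
    [0.01, 0.00, 0.02, -0.01],  # position 0
    [0.00, 0.03, -0.01, 0.02],  # position 1
]

tokens = ["The", "capital"]
# Each token's input vector is its embedding plus its position vector.
inputs = [[e + p for e, p in zip(embedding[tok], position[i])]
          for i, tok in enumerate(tokens)]
```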
2. Masked Self-Attention
After being encoded into a high-dimensional vector of real numbers, the token being processed heads to the attention layers.
There are 3 different projected representations used in attention:
- Query (Q) which determines what kind of information the current position wants to retrieve
- Key (K) which is compared against queries to decide how relevant each token is
- Value (V) which is the information mixed into the output according to those attention weights
For a decoder architecture, the token currently being processed should only have context from previously processed tokens. Future tokens are therefore masked with negative infinity, so their attention weights become zero after softmax.
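A small sketch of causal masking, using a hypothetical 3x3 score matrix for three tokens:

```python
import math

# Entry [i][j] is the raw attention score from token i to token j
# (toy numbers). Future positions (j > i) are masked with -inf.
NEG_INF = float("-inf")
scores = [
    [0.5, NEG_INF, NEG_INF],
    [0.2, 0.9,     NEG_INF],
    [0.1, 0.4,     0.3],
]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) == 0.0
    s = sum(exps)
    return [e / s for e in exps]

# After softmax, every masked (future) position gets exactly zero weight.
weights = [softmax(row) for row in scores]
```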
Now the updated France vector carries not just "France-ness" but a more contextual representation relating France to the concept of a capital city. This helps the decoder make a more accurate prediction for the next token.
In another example:
"The snack-seeking, scruffy dog ran to the"
Attention helps the model treat dog in the context of the full phrase, not just as an isolated word.
Because of that, the continuation can better reflect the meaning of the sentence as a whole, predicting some place a snack-seeking scruffy dog would run to...
3. Feedforward Layer and Next-Token Prediction
The transformer also has feedforward elements. These are no different from a traditional feedforward neural network: they have their own learned weights and extract additional useful structure from each token's representation.
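A minimal sketch of one position-wise feedforward block, with made-up 2x2 weight matrices and a ReLU nonlinearity (real models use much larger matrices, and the numbers here are purely illustrative):

```python
def relu(x):
    return max(0.0, x)

W1 = [[0.5, -0.2], [0.1, 0.3]]   # d_model x d_ff (toy 2 x 2 shapes)
W2 = [[0.4, 0.0], [-0.1, 0.2]]   # d_ff x d_model

def feedforward(h):
    # Two linear maps with a nonlinearity in between, applied
    # independently to each token's vector.
    hidden = [relu(sum(h[i] * W1[i][j] for i in range(2)))
              for j in range(2)]
    return [sum(hidden[i] * W2[i][j] for i in range(2))
            for j in range(2)]

out = feedforward([1.0, 2.0])
```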
By the end of these serial layers, the transformer takes the last hidden vector \(h\) and multiplies it by a vocabulary matrix. This produces one raw score for every possible token in the vocabulary.
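A toy sketch of that projection, with hypothetical shapes and values:

```python
# Final hidden vector and a tiny vocabulary matrix; all numbers and
# shapes (d_model = 3, vocab_size = 4) are made up for illustration.
h = [0.6, -0.2, 0.9]
W_vocab = [
    [2.0, 0.5, 0.1, -1.0],
    [0.0, 1.0, 0.2,  0.3],
    [1.5, 0.0, 0.4, -0.5],
]
vocab_size = 4

# One raw score (logit) per vocabulary entry: logits = h @ W_vocab.
logits = [sum(h[i] * W_vocab[i][j] for i in range(3))
          for j in range(vocab_size)]
```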
These are not probabilities yet. They are just raw scores, called logits, so for a phrase like "The capital of France is", the model might assign larger values to more plausible next tokens:
"Paris" -> 12.3
"London" -> 4.8
"city" -> 3.1
"banana" -> -2.0
Softmax then turns those logits into a probability distribution over the vocabulary (the values below are illustrative):
"Paris" -> 0.82
"London" -> 0.06
"city" -> 0.03
...
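The softmax step can be sketched directly over the example logits. Note that with gaps this large between logits, a plain softmax concentrates nearly all probability on the top token; the probability values quoted above are illustrative:

```python
import math

logits = {"Paris": 12.3, "London": 4.8, "city": 3.1, "banana": -2.0}

# Numerically stable softmax: subtract the max before exponentiating.
m = max(logits.values())
exps = {tok: math.exp(x - m) for tok, x in logits.items()}
total = sum(exps.values())
probs = {tok: e / total for tok, e in exps.items()}
```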
The decoder can then choose the next token from this distribution. It may do this using various heuristics on top of the probabilities. For example, it may manage repetition by adding presence or frequency penalties that help keep the model from repeating itself. Likewise, limits on total output length can steer the model away from generating massive streams.
Following the generation of the next token, that token is added back into the model's input. Effectively, the model shifts forward, incorporating the new token into the context for its next step. At this point, the model has already computed representations for the existing context, so re-feeding the entire input from scratch would be inefficient. To avoid this, caching is often used: the K and V vectors are kept in memory, which significantly reduces computation and speeds up future passes through the model.
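A minimal sketch of K-V caching during decoding. The projection helpers below are hypothetical stand-ins for the model's learned K and V projections:

```python
def project_k(token_vec):
    return [x * 0.5 for x in token_vec]  # stand-in for K = x @ W_k

def project_v(token_vec):
    return [x + 1.0 for x in token_vec]  # stand-in for V = x @ W_v

k_cache, v_cache = [], []

def decode_step(new_token_vec):
    # Only the newest token's K and V are computed and appended;
    # everything already in the cache is reused, not recomputed.
    k_cache.append(project_k(new_token_vec))
    v_cache.append(project_v(new_token_vec))
    # ... attention would now read from the full k_cache / v_cache ...
    return len(k_cache)

for vec in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
    decode_step(vec)
```

The work per decode step stays constant in the number of projections, while without the cache every step would reproject the entire prefix.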
The Trade-off With More Context
More context allows the model to look back at more information that is hopefully relevant, and that information can provide deeper insight for more accurate generation. However, as the context grows, attention layers become more expensive in both compute and memory: Q, K, and V grow with n, while the attention score matrix grows with n^2.
With K-V caching, this directly impacts the amount of memory needed during forward passes of the model.
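A back-of-the-envelope sketch of the two growth rates, using hypothetical model sizes:

```python
# Hypothetical model sizes, purely to illustrate the growth rates.
d_model = 4096       # hidden size
n_layers = 32        # transformer layers
bytes_per_value = 2  # fp16

def score_entries(n):
    # The attention score matrix has one entry per (query, key) pair.
    return n * n

def kv_cache_bytes(n):
    # Keys and values for every position, in every layer.
    return 2 * n_layers * n * d_model * bytes_per_value

# Doubling the context quadruples the score matrix
# but only doubles the KV cache.
quad = score_entries(2048) / score_entries(1024)
double = kv_cache_bytes(2048) / kv_cache_bytes(1024)
```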
Tokenization
Breaking Text into Tokens (todo)
Token IDs and Embeddings
After text is split into tokens, each token is mapped to an integer ID. That ID is then used as an index into a learned embedding table, returning the token's vector representation.
tokens -> ["small", "dog"]
token IDs -> [1834, 481]
lookup -> [E[1834], E[481]]
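A runnable version of that lookup, with a hypothetical two-entry table (the IDs match the example above; the vector values are placeholders):

```python
# Token -> ID mapping and a tiny embedding table keyed by ID.
vocab = {"small": 1834, "dog": 481}
E = {
    1834: [0.2, -0.1, 0.4, 0.0],
    481:  [0.7, 0.1, -0.3, 0.5],
}

tokens = ["small", "dog"]
ids = [vocab[t] for t in tokens]   # token IDs
vectors = [E[i] for i in ids]      # embedding lookup by ID
```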
These vectors are learned numerical representations. Their direction in the embedding space can encode multiple overlapping pieces of meaning at once.
The embedding lookup gives the model an initial vector. Attention layers described earlier can update that vector using context and shift it toward a more specific meaning.
Optimization (todo)
Training-Time Optimization
talk about parallelism and memory/backprop, attention cost during training, encoding vectors into smaller dimensions
Inference-Time Optimization
talk about prefill/decode, kv cache, batching, quantization (encoding vectors)