Introduction

What is the LLM Playbook?

Given the explosive growth and diverse range of methodologies in the field of large language models (LLMs), there's an inherent need for structured and clear communication. That's where this playbook comes in. Unlike more technical blogs or exhaustive resources that delve deep into mathematical rigor, this playbook serves a unique purpose. It is primarily a platform where I can systematize and articulate my understanding of the rapid advancements in LLM training, optimization, and deployment. In collating my observations and insights, I aim to bring clarity to an area that is complex and ever-changing.

While this playbook is invaluable for my own cognitive structuring, it is also intended to be a resource for others. Whether you are a newcomer looking for a guided introduction or a seasoned practitioner seeking up-to-date insights, this document aims to provide a curated view of the key developments shaping the future of large language models. Given my background as a medical doctor, you won't find an abundance of math-heavy equations or theoretical proofs here. Instead, the approach is designed to be intuitive, aiming to make the subject matter accessible to a broader audience. That said, I do assume that you have a basic understanding of Python and deep learning, as (relatively unoptimized) code snippets and examples will frequently be used to illustrate points.

Thank you for joining this educational journey, and I hope you find the playbook as enlightening as I find the process of maintaining it.

About Me

My name is Cyril Zakka, I'm a medical doctor and postdoctoral fellow in the Hiesinger Lab in the Department of Cardiothoracic Surgery at Stanford University. My research interests primarily involve building and deploying large multimodal networks for medical imaging and autonomous robotic surgery.

If you have any feedback, comments, or questions, please don't hesitate to reach out.

Positional Embeddings

The transformer architecture has revolutionized the field of natural language processing, but it comes with a peculiar limitation: it lacks an intrinsic mechanism to account for the position or sequence order of elements in an input. In plain terms, a transformer model would produce the same output for two different permutations of the same input sequence. This is problematic because the sequence in which words or tokens appear carries significant meaning in language and other types of data.

This limitation arises because the architecture relies on self-attention mechanisms, which, by their very design, are permutation-invariant—they treat all positions equally and thus are indifferent to the arrangement of elements in the sequence. Consequently, while transformers excel at recognizing patterns and relationships between elements, they are blind to the "where" and "when" of those elements within the sequence.

To address this shortcoming and make transformers aware of element positions, we use a specialized form of embeddings known as positional embeddings. These embeddings work alongside the standard word embeddings to grant transformers the capability to understand sequence order. By doing so, they complete the picture, allowing the model to interpret data in a way that respects both content and sequence.

As with all aspects of machine learning, the choice of position encoding typically involves tradeoffs between simplicity, flexibility, and efficiency. Here we explore a few of the most popular methods:

Absolute Positional Encoding

Absolute position encodings are computed in the input layer and summed with the input token embeddings. Vaswani et al. (2017) proposed this scheme for the Transformer, and it has remained a popular choice in follow-up work (Radford et al., 2018; Devlin et al., 2018). There are two common variations of absolute position encodings: fixed and learned.

Relative Positional Encoding

One drawback of absolute position encodings is that they tie the model to a fixed maximum sequence length and do not directly capture the relative positions between words. Several relative position schemes have been proposed to address these problems.

Fixed Positional Embeddings

As previously mentioned, in the Transformer architecture positional encodings serve as a critical component for giving the model an understanding of the order of tokens in a sequence. Unlike recurrent networks, which inherently understand sequence order, the multi-head attention mechanism in the Transformer is non-recurrent and processes the entire sequence in parallel. Consequently, it lacks an innate sense of order among the data points.

To remedy this, the concept of positional encoding is employed. Specifically, a tensor that matches the shape of the input sequence is added to the input, and this tensor is designed such that the difference in values between any two positions reflects their distance in the sequence. This allows the model to understand the relative positions of tokens and treat them accordingly.

To this end, several methods for positional encoding have been proposed. Before diving into the more advanced ones, let's first examine the shortcomings of some seemingly intuitive solutions. First, one might think of normalizing time-step values to the range [0, 1] and using them as positional information:

time_step_normalized = np.linspace(0, 1, num_tokens)

Though tempting, this approach is inherently flawed: the normalized values are dependent on sequence length, making it problematic for the model to handle sequences of varying lengths - a positional encoding value of 0.4 means something entirely different to a sequence of length 4 than to a sequence of length 80.

Similarly, one might advocate for a linear numbering scheme such as:

time_step_linear = np.arange(1, num_tokens + 1)

Simple? Yes. Effective? Not quite. As sequence length inflates, positional values escalate, undermining the model's ability to generalize to sequences longer than those in the training set, while potentially leading to training instabilities (e.g. exploding gradients).

Sinusoidal Positional Encoding

Among the various approaches proposed over time, the most widely used form of fixed positional embeddings is sinusoidal positional encoding. In this method, each position in the sequence is uniquely represented by a combination of sine and cosine functions at different frequencies. These sinusoidal embeddings are added to the input embeddings to supplement them with positional context.

import numpy as np

def sinusoidal_positional_encoding(position, d_model):
    angle_rads = 1 / np.power(10000, 2 * (np.arange(d_model) // 2) / np.float32(d_model)) # 1
    angle_rads = position * angle_rads # 2
    pos_encoding = np.zeros(d_model) # 3
    pos_encoding[0::2] = np.sin(angle_rads[0::2]) # 4
    pos_encoding[1::2] = np.cos(angle_rads[1::2]) # 4
    return pos_encoding

Here, the function takes two arguments: position representing the position of a token in the sequence, and d_model being the dimension of the model's input embeddings.

  1. Initialize Angle Rates: We start by creating an array of per-dimension angle rates (inverse frequencies) for the sine and cosine functions. Each pair of embedding dimensions shares a rate of 1/10000^(2i/d_model), so different dimensions oscillate at different wavelengths, letting the encoding capture both fine-grained and coarse positional differences.
  2. Position-Based Scaling: The next step is to multiply these pre-calculated angle values by the position of the token in the sequence. This ensures that each token position will have a unique set of angles.
  3. Initialize Encoding Array: An array of zeros is then initialized. This array will hold the final positional encodings and has the same size as the embedding dimension of the model.
  4. Populate Sine and Cosine Values: Finally, we populate this zero array with sine and cosine values based on the angle values we've computed. The sine values go into the even-indexed positions, and the cosine values go into the odd-indexed positions. The end result is that each position in the sequence gets a unique pattern of sine and cosine values, making it distinguishable from other positions.

How exactly does this approach convey positional information? Imagine a series of pendulums aligned in a straight line. Each pendulum is swinging at a different frequency, starting from the leftmost one, which swings the slowest, to the rightmost one, which swings the fastest. Now, imagine taking a snapshot of the pendulums at a certain time t, where t corresponds to the token's position in the sequence.

In this snapshot, pendulums on the left have moved very little due to their slower frequencies, while those on the right have moved considerably. If you were to calculate the dot product (read: similarity) of their positions at this moment, the slow-swinging pendulums would be aligned closely and contribute positively to the dot product. In contrast, the fast-swinging pendulums would be out of phase and contribute noise around zero to the dot product.

As time (or position) progresses, the snapshot would capture more pendulums being out of phase, causing the dot product value to gradually converge to zero. This mirrors the behavior of the sinusoidal positional encoding: the dot product between the positional encodings of tokens that are close in sequence will be high, while the value will smoothly decrease for tokens that are further apart.

By mapping each token's position in the sequence to a unique combination of sinusoidal values, we effectively capture the relative positions and relationships between tokens. The encoded values at different positions can then be visualized, showing a high value for nearby tokens and a smoothly decreasing value as the distance between tokens increases.
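
To make this concrete, here is a small illustrative snippet (the sequence length and model dimension are arbitrary) that stacks the function above into a full positional-encoding matrix, one row per position, which is what gets added to the token embeddings or plotted as a heatmap:

seq_len, d_model = 50, 128
pos_encodings = np.stack([
    sinusoidal_positional_encoding(pos, d_model) for pos in range(seq_len)
])  # shape: (seq_len, d_model), one unique sinusoidal pattern per position
# Nearby positions have similar encodings; the similarity decays with distance
print(np.dot(pos_encodings[0], pos_encodings[1]), np.dot(pos_encodings[0], pos_encodings[40]))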

Learned Positional Embeddings

In contrast to fixed positional embeddings like sinusoidal encoding, another popular approach is learned positional embeddings. Here, instead of hard-coding the logic for computing positional encodings, we make the model learn the best possible representation for sequence position during the training phase. Like other parameters in the model, these learned positional embeddings get fine-tuned through backpropagation.

Learned positional embeddings offer the model flexibility and adaptability. They can be designed to have the same shape as the input sequence, making them directly addable to the token embeddings. In its simplest form, a learned positional embedding can be defined as:

pos_emb_shape = (1, seq_len, d_model) # 1
pos_embedding = np.random.randn(*pos_emb_shape) # 2
x += pos_embedding # 3
  1. Initialize Embedding Shape: The first line of code sets up the shape for the positional embedding array. The shape (1, seq_len, d_model) indicates that we'll have: 1 to denote it's a single tensor that will be broadcasted across multiple batches, seq_len as the length of the sequence to which the positional embedding will be added, and d_model as the dimensions of the model, which should match the dimension of the input sequence embeddings. This ensures that we can add the positional embedding directly to the token embeddings.
  2. Random Initialization: In the second line, we initialize the positional embedding array with random values from a normal distribution. This serves as a starting point for what the model will later refine during training. These embeddings are considered parameters and are fine-tuned during the backpropagation process.
  3. Add to Input Sequence: Finally, we add the positional embedding array to the input sequence x. This is done element-wise and serves to encode the position information within each token's embedding. This combined representation is then passed through the model for further processing.
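
Since plain NumPy arrays aren't trainable, in practice this idea is expressed as a learnable parameter in a deep learning framework. A minimal PyTorch sketch (the module and its names are illustrative, not taken from a particular library):

import torch
from torch import nn

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, max_seq_len: int, d_model: int) -> None:
        super().__init__()
        # Trainable positional table, updated by backpropagation like any other weight
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); slice to the current length and broadcast over the batch
        return x + self.pos_embedding[:, : x.size(1)]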

Rotary Positional Embeddings

Rotary Positional Embeddings aim to overcome limitations tied to both fixed and learned positional embeddings. While fixed sinusoidal embeddings generalize to arbitrary sequence lengths in principle, in practice models have been found to underperform when encountering sequences substantially longer than those seen during training. Enter rotary positional embeddings.

Rotary Positional Embeddings provide a flexible mechanism to include positional context into tokens, without modifying the original embeddings. The core principle revolves around rotating the queries and keys in the attention mechanism, where each position in the sequence receives a unique rotation. This way, the dot product between queries and keys gradually diminishes for tokens that are distant from one another in the sequence, providing an effective way to encode relative positions.

This approach tends to maintain more of the original token information while still providing the model with an effective way to understand sequence positions. Their implementation would look something like:

def rotary_positional_embedding(position, d_model):
    freqs = np.exp(np.linspace(0., -1., d_model // 2) * np.log(10000.)) # 1
    angles = position * freqs # 2
    rotary_matrix = np.stack([np.sin(angles), np.cos(angles)], axis=-1) # 3
    return rotary_matrix.reshape(-1, d_model) # 4
  1. Initialize Frequency Array: Similar to the sinusoidal approach, we initialize an array of frequencies. The key difference here is the use of exponential scaling to generate frequencies, which will serve as rotation factors.
  2. Position-Based Scaling: Next, we scale the positions by these frequencies. Unlike in sinusoidal encodings where the scaled positions would be added to the embeddings, here they are used for rotating the embeddings.
  3. Construct Rotary Matrix: Using the scaled angles, a rotary matrix is created by stacking the sine and cosine of the angles. This matrix will serve to rotate the original embeddings.
  4. Reshape Rotary Matrix: Finally, the rotary matrix is reshaped to match the model's embedding dimension. These sine and cosine factors are then used to rotate the query and key vectors, combining pairs of dimensions multiplicatively rather than being added to the embeddings, as sketched a little further below.

Simple enough! Let's conceptualize rotary positional embeddings by imagining a clock with multiple hands. Each hand rotates at a different speed, representing different frequencies. Every token in your sequence corresponds to a specific clock hand.

  • Variable Rotation Speed: Just like in a real clock where the second, minute, and hour hands rotate at distinct speeds, different dimensions in the query/key embeddings are rotated differently. This can be thought of as each dimension having its own "frequency," determining how fast it rotates based on its position in the sequence.
  • Dot Product Significance: When two clock hands point in the same or similar direction (i.e., their angles are close), they can be considered "similar" or "close" in sequence context. In the same vein, the dot product between rotated queries and keys would be higher for positions that are close in the sequence, and lower for positions that are farther apart. As time progresses (or as you traverse through the sequence), each clock hand rotates based on its speed (frequency). When you look at the clock at any given "time" (or position in the sequence), the angles of the clock hands with respect to a fixed starting point provide a snapshot of the tokens' positions.
  • Invariance to Sequence Length: Much like how the hands of a clock keep rotating indefinitely regardless of the 12-hour clock face, Rotary Positional Embeddings aren't restricted by the length of the sequence. This means they can adapt to sequences of varying lengths, offering a level of flexibility.
  • Impact on Attention: Just as you could determine the elapsed time between different events by observing the relative angles between clock hands, rotary positional embeddings influence the attention mechanism. They help it focus on tokens that are contextually relevant to each other based on their positional relationships in the sequence. By simply looking at how much each hand has rotated, you can figure out its relative position in the sequence. In this way, the rotational information captures the essence of each token's position within the overall sequence while leaving the actual token embeddings largely untouched.

In rotary positional embeddings, the same principle applies: each token's embedding gets "rotated" based on its position in the sequence. This rotational change encodes the positional information while retaining the original embedding, thus allowing the model to understand the tokens' relative positions effectively.
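
To ground the analogy, here is a rough NumPy sketch of applying such a rotation to a single query or key vector (the function name apply_rotary and the pairing of even/odd dimensions are illustrative choices rather than a reference implementation):

def apply_rotary(x, position, base=10000.0):
    # x: a single query or key vector of shape (d_model,); position: integer token position
    d_model = x.shape[-1]
    freqs = np.exp(np.linspace(0.0, -1.0, d_model // 2) * np.log(base))  # one frequency per dimension pair
    angles = position * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]            # treat (even, odd) dimensions as 2D coordinates
    rotated = np.empty_like(x)
    rotated[0::2] = x_even * cos - x_odd * sin  # standard 2D rotation applied pairwise
    rotated[1::2] = x_even * sin + x_odd * cos
    return rotated

Because only the rotation angle (position times frequency) changes from token to token, the dot product between two rotated vectors depends on their relative offset, which is exactly the property the attention scores exploit.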

Attention

Imagine you're reading a dense academic paper, thrilling novel or LLM playbook. Your brain doesn't weigh each word or sentence equally. Some portions are scrutinized carefully, while others may be skimmed over. You naturally pay 'attention' to specific parts based on their relevance to your current focus or understanding. This selective focus allows you to better comprehend the text and keeps you from being overwhelmed by unnecessary details.

In the realm of machine learning, particularly in sequence-to-sequence tasks like machine translation or text summarization, a similar mechanism is invaluable. Early models like RNNs and LSTMs process sequences step-by-step, but they can struggle with long sequences and often lose track of important earlier tokens when focused on the more recent ones. The attention mechanism was introduced to combat these limitations. It essentially allows the model to "focus" on different parts of the input sequence when producing an output, much like how you would focus on certain parts of a text while reading.

The Attention Mechanism Explained

Within this framework, three major components come into play: Query (Q), Key (K), and Value (V).

  • Query (Q): Picture this as the search term you'd type into a search engine—like asking your brain, "Hey, what should I focus on right now?" It's a vector representing your current area of focus.
  • Key (K): The Key vectors are akin to the titles of Wikipedia articles. They serve as guideposts, each corresponding to a specific token in the input sequence, hinting where you might find relevant information.
  • Value (V): These vectors are the meat of the matter—the article content, if you will. They offer the detailed information each token carries.

How do they all come together?

import torch
from torch import nn, Tensor

class Attention(nn.Module):
    def __init__(self, word_size:int=512, embed_dim:int=64) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.dim_K = torch.tensor(embed_dim, dtype=torch.float32)
        self.query = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
        self.key = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
        self.value = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)

    def self_attention(self, Q:Tensor, K:Tensor, V:Tensor) -> Tensor:
        K_T = torch.transpose(K, 0, 1)
        score = torch.matmul(Q, K_T)  / torch.sqrt(self.dim_K)
        score = torch.softmax(score, dim=-1)
        Z = torch.matmul(score, V)
        return Z

    def forward(self, x:Tensor) -> Tensor:
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        Z = self.self_attention(Q, K, V)
        return Z
  1. Score Generation: Every Query is matched against all Keys via a dot product to calculate a score. It's like asking, "How relevant is this part of the text to my current focus?"
  2. Softmax Scaling: These scores undergo Softmax normalization, turning these scores into probabilities that sum up to 1. Picture this as divvying up your concentration across various parts.
  3. Weighted Sum: Finally, these attention weights are applied to the Values (you guessed it, another dot product), summing them up to form a single output vector. You can think of this as gathering all the most valuable sentences to form a cohesive summary of what you need to focus on.

Self-Attention

Self-attention is a specific type of attention mechanism where the Query (Q), Key (K), and Value (V) all come from the same place, usually the output of the previous layer in your network. In layman's terms, self-attention enables tokens in the input sequence to look at other tokens in the same sequence to gather contextual clues.

Self-attention becomes particularly interesting when employed in autoregressive models like GPT (Generative Pre-trained Transformer). In such models, the generation of each new token is dependent only on the preceding tokens. Causal self-attention restricts the scope of attention in a way that each token only looks at those that precede it, and not the ones that follow. This is crucial for maintaining the sequence structure and generating sensible outputs. Imagine it like reading a book where you only consider the chapters or sentences you've already read to make a prediction about what comes next. You don't peek ahead; you stick with what you know so far.
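
To illustrate the causal variant, here is a minimal sketch (assuming unbatched (seq_len, embed_dim) tensors, as in the Attention class above) that masks out future positions before the softmax:

def causal_self_attention(Q: Tensor, K: Tensor, V: Tensor) -> Tensor:
    seq_len, d_k = Q.shape
    scores = torch.matmul(Q, K.transpose(0, 1)) / torch.sqrt(torch.tensor(float(d_k)))
    # Upper-triangular mask: position i may not attend to positions j > i
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, V)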

Limitations and Challenges

While attention mechanisms have revolutionized sequence-to-sequence tasks, they're not without their challenges:

  • Computational Cost: Attention mechanisms can be computationally expensive, especially for very long sequences. As such there exists a rich body of literature extending and refining it in various ways. We'll explore a few of those approaches in the subsequent sections.
  • Interpretability: While attention weights can give some insight into what the model is "focusing" on, this doesn't necessarily mean the model "understands" the text in the way humans do.

Multi-Headed Attention (MHA)

In a single attention mechanism, each token gets a chance to focus on other parts of the sequence. However, there's a limit to what it can capture this way. Multi-headed attention solves this by running not one but multiple attention layers in parallel, essentially allowing the model to pay attention to different parts of the input for different reasons. This can naively be implemented in the following way:

class MultiheadAttention(nn.Module):
    r"""
    https://arxiv.org/abs/1706.03762
    """
    def __init__(self, word_size: int = 512, embed_dim: int = 64, n_head:int=8) -> None:
        super().__init__()
        self.n_head = n_head
        self.embed_dim = embed_dim
        self.dim_K = torch.tensor(embed_dim)
        self.proj = nn.Parameter(torch.empty(embed_dim * n_head, embed_dim))
        nn.init.xavier_uniform_(self.proj)
        self.multihead = nn.ModuleList([
            Attention(word_size, embed_dim) for _ in range(n_head)
        ])

    def forward(self, x: Tensor) -> Tensor:
        Z_s = torch.cat([head(x) for head in self.multihead], dim=1)
        Z = torch.matmul(Z_s, self.proj)
        return Z
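
To sanity-check the shapes, here is a quick usage example (the sizes are arbitrary, and the input is unbatched as in the sketches above):

x = torch.randn(10, 512)                                  # (seq_len, word_size)
mha = MultiheadAttention(word_size=512, embed_dim=64, n_head=8)
print(mha(x).shape)                                       # torch.Size([10, 64])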

Why is this useful?

  • Diverse Representations: Having multiple heads allows the model to recognize various types of relationships between tokens, which can be critical for understanding complex structures like sentences.
  • Increased Capacity: Multi-headed attention increases the model's capacity to learn, as each head can potentially learn different aspects of the data. Think of it like having multiple detectives on the case instead of just one.
  • Parallelism: Multiple heads can be processed in parallel, providing a computational advantage. Imagine splitting the detective work, where each detective specializes in a different type of evidence.

Keep in mind that the number of heads and their dimensions are hyperparameters that you'll have to fine-tune based on your specific application. More heads are not always better; it's about striking the right balance between model complexity and performance.

Multi-Query Attention (MQA)

Multi-Query Attention (MQA) is a refined version of the Multi-Head Attention (MHA) algorithm that improves computational efficiency without sacrificing much in terms of model accuracy. In standard MHA, separate linear transformations are applied to the Query (Q), Key (K), and Value (V) for each attention head. MQA diverges from this by using a single shared set of Keys (K) and Values (V) across all heads, while allowing individual transformations for each Query (Q). Although this approach was first introduced in 2019, it has only recently been popularized by models such as PaLM and Falcon. This is illustrated below:

class MultiQueryAttention(Attention):
    r"""
    https://arxiv.org/pdf/1911.02150.pdf
    """
    def __init__(self, word_size: int = 512, embed_dim: int = 64, n_query:int=8) -> None:
        super().__init__(word_size, embed_dim)
        self.n_query = n_query
        self.proj = nn.Parameter(torch.empty(embed_dim * n_query, embed_dim))
        nn.init.xavier_normal_(self.proj)
        delattr(self, 'query')
        self.querys = nn.ModuleList([
            nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
            for _ in range(n_query)
        ])
        self.key = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)
        self.value = nn.Linear(in_features=word_size, out_features=embed_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        K = self.key(x)
        V = self.value(x)
        Z_s = torch.cat([
            self.self_attention(query(x), K, V) for query in self.querys
        ], dim=1)
        Z = torch.matmul(Z_s, self.proj)
        return Z

with improvements in:

  • Memory Space: Sharing K and V across all heads dramatically reduces the memory footprint. This is critical for handling long sequences without choking your hardware.
  • Memory Bandwidth: With fewer unique transformations, the computational cost in terms of memory bandwidth also drops.

Grouped-Query Attention (GQA)

Grouped Query Attention (GQA) extends the concept of Multi-Head Attention (MHA) and Multi-Query Attention (MQA) by providing a flexible trade-off between computational efficiency and model expressiveness. In GQA, query heads are divided into G groups, where each group shares a common key (K) and value (V) projection. This configuration enables three notable variations:

  • GQA-1: A single group, which equates to Multi-Query Attention (MQA).
  • GQA-H: Groups equal to the number of heads, essentially the same as Multi-Head Attention (MHA).
  • GQA-G: An intermediate configuration with G groups, balancing between efficiency and expressiveness.

The use of G groups allows GQA to mitigate the memory overhead associated with storing keys and values for each head, especially in scenarios with large context windows or batch sizes. At the same time, it offers a nuanced control over the model's quality and efficiency.

In its simplest form, GQA can be implemented as follows:

class GroupedQueryAttention(Attention):
    r"""
    https://arxiv.org/pdf/2305.13245.pdf
    """
    def __init__(self, word_size: int = 512, embed_dim: int = 64,
                 n_grouped: int = 4, n_query_each_group:int=2) -> None:
        super().__init__(word_size, embed_dim)
        delattr(self, 'query')
        delattr(self, 'key')
        delattr(self, 'value')

        self.grouped = nn.ModuleList([
            MultiQueryAttention(word_size, embed_dim, n_query=n_query_each_group)
            for _ in range(n_grouped)
        ])
        self.proj = nn.Parameter(torch.empty(embed_dim * n_grouped, embed_dim))
        nn.init.xavier_uniform_(self.proj)

    def forward(self, x: Tensor) -> Tensor:
        Z_s = torch.cat([head(x) for head in self.grouped], dim=1)
        Z = torch.matmul(Z_s, self.proj)
        return Z
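
As with the multi-head version, a quick shape check under the same unbatched convention (the numbers are arbitrary):

x = torch.randn(10, 512)                                  # (seq_len, word_size)
gqa = GroupedQueryAttention(word_size=512, embed_dim=64,
                            n_grouped=4, n_query_each_group=2)
print(gqa(x).shape)                                       # torch.Size([10, 64])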

Sliding-Window Attention

The original Transformer's self-attention component has a computational complexity of O(n^2) with n being the input sequence length. In other words, if the input sequence size doubles, the time taken to compute self-attention quadruples. This inefficiency becomes a roadblock when handling extensive input sequences, making it impractical for large-scale tasks. Sliding Window Attention (SWA) addresses this by employing a fixed-size window around each token, reducing the computational overhead while retaining the ability to consider local context.

How does it work?

Sliding Window Attention uses a fixed-size window w around each token in the sequence. Specifically, each token attends to w/2 tokens on either side of itself, localizing the attention span and reducing the time complexity to O(n*w), which is linear in the sequence length n. In autoregressive contexts, each token instead attends to the w tokens before it.

Since a single layer of SWA can only capture local context within its fixed window, multiple layers are stacked on top of each other, effectively increasing the receptive field without incurring a quadratic increase in computational cost. The receptive field size is l*w, where l is the number of layers. Optionally, w can be varied across layers to find a suitable trade-off between computational efficiency and model expressiveness.
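
As a rough sketch, the causal sliding-window pattern can be expressed as a boolean mask applied to the attention scores before the softmax (the helper name sliding_window_mask is illustrative):

def sliding_window_mask(seq_len, window):
    # True where attention is allowed: token i attends to tokens i-window+1 .. i
    idx = torch.arange(seq_len)
    rel = idx[None, :] - idx[:, None]          # rel[i, j] = j - i
    return (rel <= 0) & (rel > -window)

scores = torch.randn(8, 8)                     # toy (seq_len, seq_len) attention scores
masked = scores.masked_fill(~sliding_window_mask(8, window=3), float('-inf'))
weights = torch.softmax(masked, dim=-1)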

Attention Sink

Attention Sinks address a critical issue observed in the use of window attention in autoregressive language models. When window attention is applied, these models often exhibit a sudden decline in fluency as soon as the first token leaves the context window. The underlying reason for this decline lies in an intriguing aspect of LLMs: an overwhelming majority of attention is allocated to the first few tokens of the sequence, termed "attention sinks." These tokens soak up a disproportionate share of the attention score even when they are not semantically relevant.

Why does this happen?

The model relies heavily on these "sink" tokens because the softmax operation in the attention mechanism enforces a sum-to-one constraint. In the absence of relevant tokens to match with the next generated token, the model compensates by dumping attention scores into these first few tokens. When window attention is employed and the first token exits the window, the model loses its default 'sink' to offload the attention. This leads to the attention scores being dispersed across all remaining tokens. Consequently, tokens that should not necessarily have high attention scores end up getting them, causing the model to "collapse" and lose fluency.

The Solution

To mitigate this issue, the authors of the StreamingLLM paper propose an adaptation of window attention: the revised model always keeps the initial four tokens, i.e. the attention sink tokens, within the window. Moreover, instead of using positions from the original text, they use the positions within the cache to add positional information to the tokens. This keeps the sink tokens spatially close to the rest, so they can keep serving as attention offloading points.
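
Conceptually (setting aside the cache-relative re-positioning), this amounts to combining the sliding window with a handful of always-visible positions. A toy mask sketch, with the helper name and the n_sink parameter as illustrative assumptions:

def sink_plus_window_mask(seq_len, window, n_sink=4):
    idx = torch.arange(seq_len)
    rel = idx[None, :] - idx[:, None]                    # rel[i, j] = j - i
    recent = (rel <= 0) & (rel > -window)                # causal sliding window
    sinks = (idx[None, :] < n_sink) & (rel <= 0)         # the first n_sink tokens stay visible (causally)
    return recent | sinks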

KV Cache

Key-Value (KV) caching is a technique used to accelerate the inference process in machine learning models, particularly in autoregressive models. In these models, generating tokens one by one is a common practice, but it can be computationally expensive because it repeats certain calculations at each step. To address this, KV caching comes into play. It involves caching the previously computed keys and values so that they don't need to be recomputed for each new token: at every step, only the new token's key and value are computed and appended, and attention is computed between the new query and the cached keys and values. This significantly reduces the per-step computation. The only trade-off is that KV caching requires more GPU memory (or CPU memory if a GPU isn't used) to store these states.

class KVCache:
    def __init__(self, max_batch_size, max_seq_len, n_kv_heads, head_dim, device):
        # Pre-allocate buffers for keys and values: (batch, seq, n_kv_heads, head_dim)
        self.cache_k = torch.zeros((max_batch_size, max_seq_len, n_kv_heads, head_dim)).to(device)
        self.cache_v = torch.zeros((max_batch_size, max_seq_len, n_kv_heads, head_dim)).to(device)

    def update(self, batch_size, start_pos, xk, xv):
        # Write the newly computed keys/values at their positions in the sequence
        self.cache_k[:batch_size, start_pos:start_pos + xk.size(1)] = xk
        self.cache_v[:batch_size, start_pos:start_pos + xv.size(1)] = xv

    def get(self, batch_size, start_pos, seq_len):
        # Read back every cached key/value up to and including the current positions
        keys = self.cache_k[:batch_size, :start_pos + seq_len]
        values = self.cache_v[:batch_size, :start_pos + seq_len]
        return keys, values
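
A hypothetical decoding loop might use it as follows (the shapes and values are made up for illustration): at each step only the newest token's keys and values are projected, and the full history is read back from the cache.

cache = KVCache(max_batch_size=1, max_seq_len=128, n_kv_heads=8, head_dim=64, device="cpu")

for t in range(10):
    xk = torch.randn(1, 1, 8, 64)   # keys for the single new token: (batch, 1, n_kv_heads, head_dim)
    xv = torch.randn(1, 1, 8, 64)   # values for the single new token
    cache.update(batch_size=1, start_pos=t, xk=xk, xv=xv)
    keys, values = cache.get(batch_size=1, start_pos=t, seq_len=1)
    # keys and values now cover positions 0..t and are fed into the attention computation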

Sampling

In sequence-to-sequence models like GPT or Transformer-based architectures, generating an output sequence (e.g., text) involves making a series of choices for each element in the sequence. The method by which we make these choices is termed 'sampling.' Various sampling techniques can be employed, each with its own set of advantages and trade-offs. In this section, we'll walk through greedy sampling, beam search, top-K and top-P sampling, temperature, and speculative sampling.

Greedy Sampling

In greedy sampling, the word with the highest conditional probability is selected as the next word in the sequence, given the previous words.

def greedy_sampling(model, input_sequence, max_length=50):
    output_sequence = []
    for _ in range(max_length):
        next_word_probabilities = model.predict(input_sequence)  # assumed interface: probabilities over the vocabulary
        next_word = int(np.argmax(next_word_probabilities))      # always pick the single most likely token
        output_sequence.append(next_word)
        input_sequence = input_sequence + [next_word]            # feed the chosen token back in (token ids as a list)
    return output_sequence
  • Advantages: It's computationally efficient and straightforward to implement.
  • Limitations: Greedy sampling often results in suboptimal and repetitive sequences. Since it doesn't explore other probable words, it can get stuck in a 'local optimum.'

Beam Search

Beam search is an extension of greedy search that aims to improve the quality of the generated sequences by maintaining a 'beam' of the most promising partial sequences at each decoding step. The core principle of beam search is to keep track of not just the single best prediction at each time step, but a fixed number B of best predictions. At each time step, the algorithm considers expanding each of these B sequences with all possible next elements and retains only the top B sequences based on their probabilities up to the current time step.

Here is a basic NumPy-based function to illustrate a simplified version of beam search:

import numpy as np

def beam_search_decoder(probs, beam_size=3):
    sequences = [[[], 1.0]]  # list of [sequence, sequence_probability]
    
    for prob in probs:  # loop through each time step
        all_candidates = []
        
        for seq, seq_prob in sequences:
            for idx, p in enumerate(prob):
                candidate = [seq + [idx], seq_prob * p]
                all_candidates.append(candidate)
        
        # Sort all candidates by probability
        ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
        
        # Select top-k based on beam size
        sequences = ordered[:beam_size]
    
    return sequences

Here's the breakdown:

  1. Initialization: Begin with a single empty sequence with a probability of 1.
  2. Sequence Expansion: At each time step, expand each sequence in the beam by all possible next elements.
  3. Pruning: Sort all possible sequences by their probabilities and retain only the top B sequences.

Beam search strikes a balance between the breadth of exploration and computational expense. It is often used in applications where the quality of the generated sequence is critical and some level of determinism is acceptable.
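
For instance, a toy run over three decoding steps and a four-token vocabulary (the probabilities are made up) shows the beam keeping only the two most probable sequences:

probs = np.array([
    [0.1, 0.4, 0.3, 0.2],
    [0.3, 0.2, 0.4, 0.1],
    [0.25, 0.25, 0.25, 0.25],
])
for seq, seq_prob in beam_search_decoder(probs, beam_size=2):
    print(seq, round(seq_prob, 4))  # e.g. [1, 2, 0] 0.04 and [1, 2, 1] 0.04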

Limitations and Challenges

  • Search Space: The algorithm still explores a limited space, defined by the beam size. A small beam size B could yield sub-optimal sequences, while a larger one is more computationally expensive.
  • Length Normalization: Beam search tends to favor shorter sequences over longer ones. Various strategies, like length normalization, have been proposed to mitigate this.

Top-K

When generating text, a language model can predict the next word based on the previous words in the sequence. One approach is to select the word with the highest probability, but this method—known as "greedy decoding"—often results in repetitive and incoherent text. This is where sampling techniques like Top-K sampling come into play.

The idea behind Top-K sampling is quite straightforward: instead of considering all possible next words in the vocabulary, limit the pool to the top-K most likely next words and sample from this narrowed distribution.

Here's how to do Top-K sampling in a simple NumPy function:

import numpy as np

def top_k_sampling(logits, k):
    top_k_indices = np.argsort(logits)[-k:]  # Get indices of top-k logits
    top_k_logits = logits[top_k_indices]  # Get the top-k logits
    top_k_probs = np.exp(top_k_logits) / np.sum(np.exp(top_k_logits))  # Convert logits to probabilities
    selected_index = np.random.choice(top_k_indices, p=top_k_probs)  # Sample from the top-k indices based on the probabilities
    return selected_index

Here's the breakdown:

  1. Select Top-K: Given the logits for the next word, we select the top-K logits, where K is a predetermined hyperparameter.
  2. Convert to Probabilities: We then convert these logits to probabilities using the Softmax function.
  3. Sampling: Finally, we sample the next word from this top-K distribution.
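
As a quick made-up example with a five-token vocabulary and k=2, the sampler only ever returns the indices of the two largest logits:

logits = np.array([2.0, 1.0, 0.5, -1.0, 3.0])
print(top_k_sampling(logits, k=2))  # returns 0 or 4, weighted by the softmax over their logits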

When to Use Top-K Sampling

Top-K sampling is often used when you want a balance between randomness and relevance in the generated text. It allows the model to explore a bit, potentially generating more creative and diverse text while still being more coherent than random sampling.

Limitations and Considerations

  • Hyperparameter Tuning: The choice of K can significantly influence the results. A smaller K will make the output more focused but less creative, while a larger K will make the output more diverse but potentially less relevant.
  • Not Adaptive: The value of K remains constant, meaning the method isn't adaptive to the context of the text being generated. This limitation has led to the development of more advanced sampling techniques like nucleus sampling.

Top-P

While Top-K sampling restricts the sampling pool to the K most likely next words, Top-P sampling, also known as "nucleus sampling," adds a twist. Instead of specifying a set number of top candidates (K), you specify a probability mass (P) and sample only from the smallest group of words that have a collective probability greater than P.

Let's implement Top-P sampling using a NumPy function for better understanding:

import numpy as np

def top_p_sampling(logits, p):
    sorted_indices = np.argsort(logits)  # Sort token indices by logit, ascending
    sorted_probs = np.exp(logits[sorted_indices]) / np.sum(np.exp(logits))  # Convert sorted logits to probabilities
    cum_probs = np.cumsum(sorted_probs)  # Cumulative probability, least likely tokens first
    valid_indices = np.where(cum_probs >= (1 - p))[0]  # Positions whose upper tail holds at least p of the mass
    min_valid_index = valid_indices[0] if len(valid_indices) > 0 else len(sorted_indices) - 1
    nucleus_indices = sorted_indices[min_valid_index:]  # The "nucleus": most likely tokens with collective probability >= p
    nucleus_probs = sorted_probs[min_valid_index:] / np.sum(sorted_probs[min_valid_index:])  # Renormalize within the nucleus
    selected_index = np.random.choice(nucleus_indices, p=nucleus_probs)  # Sample proportionally to the renormalized probabilities
    return selected_index

Here's the step-by-step breakdown:

  1. Sort and Convert: Sort the logits and convert them to probabilities.
  2. Cumulative Sum: Calculate the cumulative sum of the sorted probabilities.
  3. Thresholding: Identify the subset of words whose collective probability mass exceeds the given threshold (P).
  4. Sampling: Sample the next word from this nucleus of candidates, with probabilities renormalized to sum to one.
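
For example, with the made-up logits below and p=0.9, the nucleus ends up containing only the two most likely tokens, which together hold well over 90% of the probability mass:

logits = np.array([3.0, 2.5, 0.1, -1.0, -2.0])
print([top_p_sampling(logits, p=0.9) for _ in range(5)])  # indices drawn only from {0, 1}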

When to Use Top-P Sampling

Top-P sampling is particularly useful when you want more adaptive and context-sensitive text generation. Unlike Top-K, which has a fixed number of candidates, Top-P allows for a variable number of candidates based on the context, making it more flexible.

Limitations and Considerations

  • Computational Cost: The sorting operation increases the computational cost slightly compared to Top-K sampling.
  • Hyperparameter Sensitivity: The choice of P can significantly influence the generated text. A smaller P restricts sampling to fewer, more likely tokens, making the output more deterministic, while a larger P admits more candidates and makes it more diverse.

Overall, Top-P sampling provides an adaptive way to balance the trade-off between diversity and informativeness in generated text, and it has gained popularity in several NLP applications, from automated customer service to creative writing aids.

Temperature

Temperature is a hyperparameter used to control the randomness in the probabilistic sampling of tokens (words, in most cases) from a distribution. It's applied to the logits (the raw scores or predictions) before the Softmax operation. Intuitively, you can think of the temperature as a knob to adjust how conservatively or liberally you want to sample the next token.

Here's the basic formula to apply temperature:

import numpy as np

def apply_temperature(logits, temperature):
    logits = logits / temperature  # Apply temperature scaling
    probs = np.exp(logits) / np.sum(np.exp(logits))  # Softmax to get probabilities
    return np.random.choice(np.arange(len(logits)), p=probs)  # Sample from the distribution

Let's break down what happens:

  1. Scaling: The logits are divided by the temperature. Lower temperature (< 1) makes the model more confident in its top choices, whereas a higher temperature (> 1) makes the model more uncertain, effectively flattening the distribution.
  2. Softmax: After scaling, the logits are transformed into probabilities using the Softmax function.
  3. Sampling: Finally, a word is sampled from this distribution.
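
A quick made-up example makes the effect visible: the same logits become sharper at low temperature and flatter at high temperature.

logits = np.array([2.0, 1.0, 0.2])
for T in (0.5, 1.0, 2.0):
    probs = np.exp(logits / T) / np.sum(np.exp(logits / T))
    print(T, np.round(probs, 3))  # lower T concentrates mass on the top token, higher T spreads it out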

When to Use Temperature Scaling

Temperature is widely applicable across different sampling methods and provides fine-grained control over the randomness of output text. Whether you are using greedy decoding, Top-K, or nucleus sampling, adding a temperature parameter can help you adjust the output to meet specific quality-diversity criteria.

Limitations and Considerations

  • Hyperparameter Tuning: The choice of temperature can have a significant impact on your results.
  • Context-Insensitive: Temperature scaling is not adaptive to the context, which may or may not be a limitation based on your use-case.

Speculative Sampling

In speculative sampling, we have two models:

  1. A smaller, faster draft model (e.g. DeepMind's 7B Chinchilla model)
  2. A larger, slower target model (e.g. DeepMind's 70B Chinchilla model)

The idea is that the draft model speculates what the output is K steps into the future, while the target model determines how many of those tokens we should accept. Here's an outline of the algorithm:

  1. The draft model decodes K tokens in the regular autoregressive fashion.
  2. We get the probability outputs of the target and draft model on the new predicted sequence.
  3. We compare the target and draft model probabilities to determine how many of the K tokens we want to keep based on some rejection criteria. If a token is rejected, we resample it using a combination of the two distributions and don't accept any more tokens.
  4. If all K tokens are accepted, we can sample an additional final token from the target model probability output.
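
The implementation below also relies on a small sample helper that draws a token id from a probability distribution; a minimal version (an assumption on my part, since any categorical sampler works here) is:

def sample(probs):
    # Draw a single token id according to the given probability distribution
    return np.random.choice(len(probs), p=probs)
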
def max_fn(x):
    x_max = np.where(x > 0, x, 0)
    return x_max / np.sum(x_max)

def speculative_sampling(x, draft_model, target_model, N, K):
    # NOTE: paper indexes arrays starting from 1, python indexes from 0, so
    # we have to add an extra -1 term when indexing using n, T, or t
    n = len(x)
    T = len(x) + N

    while n < T:
        # Step 1: auto-regressive decode K tokens from draft model and get final p
        x_draft = x
        for _ in range(K):
            p = draft_model(x_draft)
            x_draft = np.append(x_draft, sample(p[-1]))

        # Step 2: target model forward passes on x_draft
        q = target_model(x_draft)

        # Step 3: append draft tokens based on rejection criterion and resample
        # a token on rejection
        all_accepted = True
        for _ in range(K):
            i = n - 1
            j = x_draft[i + 1]
            if np.random.random() < min(1, q[i][j] / p[i][j]):  # accepted
                x = np.append(x, j)
                n += 1
            else:  # rejected
                x = np.append(x, sample(max_fn(q[i] - p[i])))  # resample
                n += 1
                all_accepted = False
                break

        # Step 4: if all draft tokens were accepted, sample a final token
        if all_accepted:
            x = np.append(x, sample(q[-1]))
            n += 1

        # just keeping my sanity
        assert n == len(x), f"{n} {len(x)}"

    return x