Decoding Strategies in Large Language Models


The tokenizer, Byte-Pair Encoding in this instance, translates each token in the input text into a corresponding token ID. GPT-2 then takes these token IDs as input and tries to predict the next most likely token: it outputs one logit per vocabulary entry, and a softmax function converts these logits into probabilities.

For example, the model assigns a probability of 17% to the token for “of” being the next token after “I have a dream”. This output essentially represents a ranked list of potential next tokens in the sequence. More formally, we denote this probability as P(of | I have a dream) = 17%.
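This logits-to-probabilities step can be sketched with toy numbers (the vocabulary and logit values below are invented for illustration; a real model produces one logit per vocabulary entry, 50,257 for GPT-2):

```python
import torch
import torch.nn.functional as F

# Toy logits for a 4-token vocabulary (made-up values for illustration)
vocab = ["of", "that", "about", "."]
logits = torch.tensor([1.7, 1.2, 0.8, -0.5])

# Softmax turns raw logits into a probability distribution over the vocabulary
probs = F.softmax(logits, dim=-1)

# Rank candidate next tokens by probability, as in P(of | I have a dream)
ranked = sorted(zip(vocab, probs.tolist()), key=lambda p: -p[1])
for token, p in ranked:
    print(f"P({token!r} | 'I have a dream') = {p:.1%}")
```

With these toy logits, “of” ends up on top of the ranked list, mirroring the 17% example above.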

Autoregressive models like GPT predict the next token in a sequence based on the preceding tokens. Consider a sequence of tokens w = (w₁, w₂, …, wₜ). The joint probability of this sequence P(w) can be broken down as:

P(w) = P(w₁, w₂, …, wₜ) = P(w₁) · P(w₂ | w₁) · … · P(wₜ | w₁, …, wₜ₋₁) = ∏ᵢ₌₁ᵗ P(wᵢ | w₁, w₂, …, wᵢ₋₁)

For each token wᵢ in the sequence, P(wᵢ | w₁, w₂, …, wᵢ₋₁) represents the conditional probability of wᵢ given all the preceding tokens (w₁, w₂, …, wᵢ₋₁). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
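The chain-rule factorization can be checked numerically with a few invented conditional probabilities (the values below are illustrative, not actual GPT-2 outputs):

```python
import math

# Invented conditional probabilities for the sequence "I have a dream"
cond_probs = {
    "I": 0.02,       # P(I)
    "have": 0.15,    # P(have | I)
    "a": 0.30,       # P(a | I have)
    "dream": 0.04,   # P(dream | I have a)
}

# Chain rule: the joint probability is the product of the conditionals
joint = math.prod(cond_probs.values())
print(f"P(sequence) = {joint:.2e}")

# In practice we sum log probabilities instead, to avoid numeric underflow
log_joint = sum(math.log(p) for p in cond_probs.values())
assert math.isclose(joint, math.exp(log_joint))
```

Summing log probabilities is why the implementation below works with `get_log_prob` rather than raw probabilities.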

This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies, such as greedy search and beam search, come into play.

Greedy search is a decoding method that takes the most probable token at each step as the next token in the sequence. To put it simply, it only retains the most likely token at each stage, discarding all other potential options. Using our example:

  • Step 1: Input: “I have a dream” → Most likely token: “ of”
  • Step 2: Input: “I have a dream of” → Most likely token: “ being”
  • Step 3: Input: “I have a dream of being” → Most likely token: “ a”
  • Step 4: Input: “I have a dream of being a” → Most likely token: “ doctor”
  • Step 5: Input: “I have a dream of being a doctor” → Most likely token: “.”

While this approach might sound intuitive, it’s important to note that greedy search is short-sighted: it only considers the most probable token at each step without considering the overall effect on the sequence. This property makes it fast and efficient, as it doesn’t need to keep track of multiple sequences, but it also means that it can miss out on better sequences that might have appeared with slightly less probable next tokens.
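A toy two-step search tree makes this short-sightedness concrete (the probabilities below are invented for illustration): greedy search commits to the locally best first token, even when a slightly less probable first token leads to a better overall sequence.

```python
# A toy two-step search tree (probabilities invented for illustration).
# "of" is locally more likely than "that", but the best continuation of
# "that" yields a higher joint probability.
tree = {
    "of":   {"prob": 0.40, "next": {"being": 0.30}},
    "that": {"prob": 0.35, "next": {"one": 0.60}},
}

# Greedy: take the most probable first token, then its best continuation
first = max(tree, key=lambda t: tree[t]["prob"])
second = max(tree[first]["next"], key=tree[first]["next"].get)
greedy_score = tree[first]["prob"] * tree[first]["next"][second]

# Exhaustive: score every full two-token path
best_path, best_score = None, 0.0
for t1, node in tree.items():
    for t2, p2 in node["next"].items():
        score = node["prob"] * p2
        if score > best_score:
            best_path, best_score = (t1, t2), score

print(f"greedy path: ({first}, {second}), score = {greedy_score:.2f}")
print(f"best path:   {best_path}, score = {best_score:.2f}")
```

Here greedy picks “of” (0.40 × 0.30 = 0.12) and misses the higher-probability path through “that” (0.35 × 0.60 = 0.21); beam search, covered later, exists precisely to mitigate this.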

Next, let’s implement greedy search and visualize it using graphviz and networkx. At each step, we select the token ID with the highest score, compute its log probability (we take the log to simplify calculations), and add it to the tree. We’ll repeat this process for five tokens.

import torch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load GPT-2 and its tokenizer, and encode the prompt
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
text = "I have a dream"
input_ids = tokenizer.encode(text, return_tensors='pt')

def get_log_prob(logits, token_id):
    # Compute the softmax of the logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    log_probabilities = torch.log(probabilities)

    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()
    return token_log_probability

def greedy_search(input_ids, node, length=5):
    if length == 0:
        return input_ids

    outputs = model(input_ids)
    predictions = outputs.logits

    # Get the predicted next sub-word (greedy search keeps only the argmax)
    logits = predictions[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)

    # Compute the score of the predicted token
    token_score = get_log_prob(logits, token_id)

    # Add the predicted token to the list of input ids
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)

    # Add node and edge to graph
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"

    # Recursive call
    input_ids = greedy_search(new_input_ids, current_node, length-1)

    return input_ids

# Parameters
length = 5
beams = 1

# Create a balanced tree with height 'length'
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())

# Add 'tokenscore', 'cumscore', and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text

# Start generating text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")

Generated text: I have a dream of being a doctor.

Our greedy search generates the same text as the one from the transformers library: “I have a dream of being a doctor.” Let’s visualize the tree we created.

import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3+1.2*beams**length, max(5, 2+length)), dpi=300, facecolor='white')

    # Create positions for each node
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")

    # Normalize the colors along the range of token scores
    if score == 'token':
        scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    elif score == 'sequence':
        scores = [data['sequencescore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin = min(scores)
    vmax = max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1, linewidths=4,
                           node_color=scores, cmap=cmap)

    # Draw the edges
    nx.draw_networkx_edges(graph, pos)

    # Draw the labels
    if score == 'token':
        labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%" for node, data in graph.nodes(data=True) if data['token'] is not None}
    elif score == 'sequence':
        labels = {node: data['token'].split('_')[0] + f"\n{data['sequencescore']:.2f}" for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)

    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    if score == 'token':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    elif score == 'sequence':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Sequence score')
    plt.show()

# Plot graph
plot_graph(graph, length, 1.5, 'token')


