The tokenizer, Byte-Pair Encoding in this instance, translates each token in the input text into a corresponding token ID. GPT-2 then takes these token IDs as input and outputs logits, one score for every token in its vocabulary, which are converted into probabilities using a softmax function to predict the next most likely token.
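As a minimal sketch of the logits-to-probabilities step (using made-up logit values for four candidate tokens, not actual GPT-2 outputs):

```python
import numpy as np

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()

# Hypothetical logits for four candidate next tokens
logits = np.array([2.0, 1.0, 0.5, -1.0])
probs = softmax(logits)

print(probs)        # probabilities, one per candidate token
print(probs.sum())  # always sums to 1
```

The highest logit always maps to the highest probability, which is why greedy decoding can work directly on logits via `argmax`.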

For example, the model assigns a probability of 17% to the token for “of” being the next token after “I have a dream”. This output essentially represents a ranked list of potential next tokens in the sequence. More formally, we denote this probability as *P(of | I have a dream) = 17%*.

Autoregressive models like GPT predict the next token in a sequence based on the preceding tokens. Consider a sequence of tokens *w = (w₁, w₂, …, wₜ)*. The joint probability of this sequence *P(w)* can be broken down as:

*P(w) = P(w₁, w₂, …, wₜ) = ∏ᵢ₌₁ᵗ P(wᵢ | w₁, w₂, …, wᵢ₋₁)*

For each token *wᵢ* in the sequence, *P(wᵢ | w₁, w₂, …, wᵢ₋₁)* represents the conditional probability of *wᵢ* given all the preceding tokens (*w₁, w₂, …, wᵢ₋₁*). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
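To make the chain rule concrete, here is a toy computation with invented per-step conditional probabilities (not real GPT-2 numbers). It also shows why implementations work in log space: summing log probabilities is numerically safer than multiplying many small probabilities, and gives the same result.

```python
import math

# Invented conditional probabilities P(w_i | w_1, ..., w_{i-1})
# for each step of a three-token continuation
step_probs = [0.17, 0.40, 0.25]

# Joint probability: product of the conditionals
joint = math.prod(step_probs)

# Equivalent computation in log space: sum of log probabilities
log_joint = sum(math.log(p) for p in step_probs)

print(joint)                # 0.017
print(math.exp(log_joint))  # same value, recovered from log space
```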

This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies, such as greedy search and beam search, come into play.

Greedy search is a decoding method that takes the most probable token at each step as the next token in the sequence. To put it simply, it only retains the most likely token at each stage, discarding all other potential options. Using our example:

**Step 1**: Input: “I have a dream” → Most likely token: “ of”
**Step 2**: Input: “I have a dream of” → Most likely token: “ being”
**Step 3**: Input: “I have a dream of being” → Most likely token: “ a”
**Step 4**: Input: “I have a dream of being a” → Most likely token: “ doctor”
**Step 5**: Input: “I have a dream of being a doctor” → Most likely token: “.”
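The loop behind these steps can be sketched in isolation. Here `next_token_probs` is a hypothetical stand-in for a model forward pass followed by softmax, using invented probabilities rather than real GPT-2 outputs:

```python
def next_token_probs(context):
    # Hypothetical stand-in for a model: maps a context string to a
    # probability distribution over candidate next tokens
    table = {
        "I have a dream":          {" of": 0.17, " that": 0.14, " about": 0.05},
        "I have a dream of":       {" being": 0.30, " a": 0.10},
        "I have a dream of being": {" a": 0.45, " the": 0.12},
    }
    return table[context]

def greedy_decode(context, steps):
    # At each step, keep only the single most probable next token
    # and discard every other candidate
    for _ in range(steps):
        probs = next_token_probs(context)
        best = max(probs, key=probs.get)
        context += best
    return context

print(greedy_decode("I have a dream", 3))  # "I have a dream of being a"
```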

While this approach might sound intuitive, it’s important to note that greedy search is short-sighted: it only considers the most probable token at each step without accounting for the overall effect on the sequence. This property makes it fast and efficient, as it doesn’t need to keep track of multiple sequences, but it also means it can miss better sequences that would have required slightly less probable next tokens along the way.
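A toy numerical example of this pitfall, with invented probabilities: the branch that looks best at the first step does not lead to the most probable two-token sequence.

```python
# Invented two-step probability tree: token "A" looks better at step 1,
# but the branch under "B" yields a higher joint sequence probability.
tree = {
    "A": (0.6, {"A1": 0.1, "A2": 0.05}),
    "B": (0.4, {"B1": 0.9, "B2": 0.05}),
}

# Greedy: pick the best first token, then the best second token under it
first = max(tree, key=lambda t: tree[t][0])             # "A"
second = max(tree[first][1], key=tree[first][1].get)    # "A1"
greedy_prob = tree[first][0] * tree[first][1][second]   # 0.6 * 0.1 = 0.06

# Exhaustive: best joint probability over all two-token paths
best_prob = max(p1 * p2 for p1, kids in tree.values() for p2 in kids.values())
# 0.4 * 0.9 = 0.36

print(greedy_prob, best_prob)  # greedy misses the better sequence
```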

Next, let’s illustrate the greedy search implementation using graphviz and networkx. We select the ID with the highest score, compute its log probability (we take the log to simplify calculations), and add it to the tree. We’ll repeat this process for five tokens.

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import time
import torch

def get_log_prob(logits, token_id):
    # Compute the softmax of the logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    log_probabilities = torch.log(probabilities)

    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()
    return token_log_probability

def greedy_search(input_ids, node, length=5):
    if length == 0:
        return input_ids

    outputs = model(input_ids)
    predictions = outputs.logits

    # Get the predicted next sub-word (greedy: argmax over the vocabulary)
    logits = predictions[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)

    # Compute the score of the predicted token
    token_score = get_log_prob(logits, token_id)

    # Add the predicted token to the list of input ids
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)

    # Add node and edge to graph
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"

    # Recursive call
    input_ids = greedy_search(new_input_ids, current_node, length - 1)
    return input_ids

# Parameters
length = 5
beams = 1

# Create a balanced tree with height 'length'
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())

# Add 'tokenscore', 'cumscore', and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text

# Start generating text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")
```

`Generated text: I have a dream of being a doctor.`

Our greedy search generates the same text as the one from the transformers library: “I have a dream of being a doctor.” Let’s visualize the tree we created.

```python
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3 + 1.2 * beams**length, max(5, 2 + length)), dpi=300, facecolor='white')

    # Create positions for each node
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")

    # Normalize the colors along the range of token scores
    if score == 'token':
        scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    elif score == 'sequence':
        scores = [data['sequencescore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin = min(scores)
    vmax = max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1, linewidths=4,
                           node_color=scores, cmap=cmap)

    # Draw the edges
    nx.draw_networkx_edges(graph, pos)

    # Draw the labels
    if score == 'token':
        labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%" for node, data in graph.nodes(data=True) if data['token'] is not None}
    elif score == 'sequence':
        labels = {node: data['token'].split('_')[0] + f"\n{data['sequencescore']:.2f}" for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)

    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    if score == 'token':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    elif score == 'sequence':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Sequence score')
    plt.show()

# Plot graph
plot_graph(graph, length, 1.5, 'token')