# Decoding Strategies in Large Language Models

The tokenizer, Byte-Pair Encoding (BPE) in this instance, translates each token in the input text into a corresponding token ID. GPT-2 then takes these token IDs as input and produces logits, one score for each token in its vocabulary, which are converted into probabilities using a softmax function. The token with the highest probability is the model's best guess for the next token.

For example, the model assigns a probability of 17% to the token for “of” being the next token after “I have a dream”. This output essentially represents a ranked list of potential next tokens in the sequence. More formally, we denote this probability as P(of | I have a dream) = 17%.
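The conversion from logits to probabilities can be sketched in a few lines. Note that the logit values and candidate tokens below are made up for illustration; they are not actual GPT-2 outputs:

```python
import numpy as np

def softmax(logits):
    # Subtract the max for numerical stability, then normalize
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()

# Hypothetical logits for four candidate next tokens after "I have a dream"
logits = np.array([2.0, 1.0, 0.5, 0.1])
tokens = [" of", " that", " about", " to"]

probs = softmax(logits)
for token, p in zip(tokens, probs):
    print(f"P({token!r} | I have a dream) = {p:.2%}")
```

In a real model the logits vector has one entry per token in the vocabulary (50,257 for GPT-2), but the mechanics are identical.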

Autoregressive models like GPT predict the next token in a sequence based on the preceding tokens. Consider a sequence of tokens w = (w₁, w₂, …, wₙ). By the chain rule of probability, the joint probability of this sequence P(w) can be broken down as:

P(w) = P(w₁) · P(w₂ | w₁) · ⋯ · P(wₙ | w₁, …, wₙ₋₁) = ∏ᵢ₌₁ⁿ P(wᵢ | w₁, w₂, …, wᵢ₋₁)

For each token wᵢ in the sequence, P(wᵢ | w₁, w₂, …, wᵢ₋₁) represents the conditional probability of wᵢ given all the preceding tokens (w₁, w₂, …, wᵢ₋₁). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
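To make the chain rule concrete, here is a toy computation. The 17% figure comes from the example above; the other conditional probabilities are made up for illustration:

```python
import math

# Hypothetical conditional probabilities for each step of the sequence
conditionals = {
    "of | I have a dream": 0.17,
    "being | I have a dream of": 0.26,
    "a | I have a dream of being": 0.57,
}

# The joint probability is the product of the conditionals
joint = 1.0
for step, p in conditionals.items():
    joint *= p

# In practice we sum log probabilities instead, which turns the
# product into a sum and avoids numerical underflow
log_joint = sum(math.log(p) for p in conditionals.values())

print(f"P(sequence) = {joint:.4f}")
print(f"log P(sequence) = {log_joint:.4f}")
```

Multiplying many probabilities below 1 quickly drives the product toward zero, which is why implementations (including the code later in this article) work in log space.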

This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies, such as greedy search and beam search, come into play.

Greedy search is a decoding method that takes the most probable token at each step as the next token in the sequence. To put it simply, it only retains the most likely token at each stage, discarding all other potential options. Using our example:

• Step 1: Input: “I have a dream” → Most likely token: “ of”
• Step 2: Input: “I have a dream of” → Most likely token: “ being”
• Step 3: Input: “I have a dream of being” → Most likely token: “ a”
• Step 4: Input: “I have a dream of being a” → Most likely token: “ doctor”
• Step 5: Input: “I have a dream of being a doctor” → Most likely token: “.”

While this approach might sound intuitive, greedy search is short-sighted: it only considers the most probable token at each step without considering the overall effect on the sequence. This property makes it fast and efficient, as it doesn't need to keep track of multiple sequences, but it also means that it can miss better sequences that would have required slightly less probable next tokens early on.
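The greedy procedure itself is simple. Here is a minimal sketch over a hypothetical model whose next-token distributions are hard-coded (in the real setting, GPT-2's softmaxed logits would supply these distributions):

```python
# Hypothetical next-token distributions keyed by the current context
toy_model = {
    "I have a dream": {" of": 0.17, " that": 0.14, " about": 0.05},
    "I have a dream of": {" being": 0.26, " a": 0.10},
    "I have a dream of being": {" a": 0.57, " the": 0.12},
    "I have a dream of being a": {" doctor": 0.18, " teacher": 0.15},
    "I have a dream of being a doctor": {".": 0.41, " and": 0.13},
}

def greedy_decode(context, model, steps=5):
    for _ in range(steps):
        # Keep only the single most probable token; discard everything else
        distribution = model[context]
        best_token = max(distribution, key=distribution.get)
        context += best_token
    return context

print(greedy_decode("I have a dream", toy_model))
# -> I have a dream of being a doctor.
```

Because each step keeps exactly one candidate, the search never revisits a choice: if " that" had led to a better overall sequence than " of", greedy search would never find it.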

Next, let's illustrate the greedy search implementation using graphviz and networkx. We select the token ID with the highest score, compute its log probability (working in log space turns products of probabilities into sums and avoids underflow), and add it to the tree. We'll repeat this process for five tokens.

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import time
import torch

# Note: `model`, `tokenizer`, `text`, and `input_ids` are assumed to be
# defined earlier (GPT-2 and its tokenizer loaded via transformers, with
# `input_ids` the encoded prompt "I have a dream").

def get_log_prob(logits, token_id):
    # Compute the softmax of the logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    log_probabilities = torch.log(probabilities)

    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()
    return token_log_probability

def greedy_search(input_ids, node, length=5):
    if length == 0:
        return input_ids

    outputs = model(input_ids)
    predictions = outputs.logits

    # Get the predicted next sub-word (greedy: argmax over the vocabulary)
    logits = predictions[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)

    # Compute the score of the predicted token
    token_score = get_log_prob(logits, token_id)

    # Add the predicted token to the list of input ids
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)

    # Add node and edge to graph
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"

    # Recursive call
    input_ids = greedy_search(new_input_ids, current_node, length - 1)

    return input_ids

# Parameters
length = 5
beams = 1

# Create a balanced tree with height 'length'
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())

# Add 'tokenscore', 'cumscore', and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text

# Start generating text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")
```
```
Generated text: I have a dream of being a doctor.
```

Our greedy search generates the same text as the one from the transformers library: “I have a dream of being a doctor.” Let’s visualize the tree we created.

```python
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3 + 1.2 * beams**length, max(5, 2 + length)), dpi=300, facecolor='white')

    # Create positions for each node
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")

    # Normalize the colors along the range of token scores
    if score == 'token':
        scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    elif score == 'sequence':
        scores = [data['sequencescore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin = min(scores)
    vmax = max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1,
                           linewidths=4, node_color=scores, cmap=cmap)

    # Draw the edges
    nx.draw_networkx_edges(graph, pos)

    # Draw the labels
    if score == 'token':
        labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%"
                  for node, data in graph.nodes(data=True) if data['token'] is not None}
    elif score == 'sequence':
        labels = {node: data['token'].split('_')[0] + f"\n{data['sequencescore']:.2f}"
                  for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)

    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    if score == 'token':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    elif score == 'sequence':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Sequence score')
    plt.show()

# Plot graph
plot_graph(graph, length, 1.5, 'token')
```