Demystifying GQA — Grouped Query Attention for Efficient LLM Pre-training | by Bhavin Jawade | Dec, 2023


The variant of multi-head attention powering LLMs like LLaMA-2, Mistral 7B, etc.

A “Group” of Llamas (Source — Image created by the author using Dalle-3)

In the previous article on training large-scale models, we looked at LoRA. In this article, we will examine another strategy adopted by different large language models for efficient training — Grouped Query Attention (GQA). In short, Grouped Query Attention (GQA) is a generalization of multi-head attention (MHA) and multi-query attention (MQA) — with each of them being a special case of GQA. Therefore, before we dive into Grouped Query Attention, let’s revisit traditional multi-head attention proposed by Vaswani et al. in the seminal “Attention is All You Need” paper. Following that, we will explore Multi-query attention and how it addresses challenges with MHA. Finally, we will answer the questions “What is GQA?” and “How does it give us the best of both worlds?”

Multi-head attention is a critical component of Transformer models, enabling them to efficiently process and understand complex sequences in tasks like language translation, summarization, and more. To grasp its intricacies, we must delve into the mathematical underpinnings and understand how multiple heads in the attention mechanism function.

The basic attention mechanism computes a weighted sum of values, with weights dependent on a query and a set of keys. Mathematically, this is expressed as:

Attention(Q,K,V) = softmax(Q.Kᵀ / √d_k).V

This is referred to as scaled dot product attention. In this equation, Q (Query) and K (Key) are matrices representing the queries and keys. V (Value) is the matrix for values. “d_k” is the dimensionality of keys, which is used for scaling.
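Scaled dot product attention can be sketched in a few lines of NumPy. This is a minimal illustration rather than a reference implementation: it computes softmax(Q.Kᵀ / √d_k).V for a single sequence, and the function names, the random inputs, and the shapes are assumptions chosen for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row-wise max before exponentiating for numerical stability.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_queries, d_k), K: (n_keys, d_k), V: (n_keys, d_v)."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n_queries, n_keys), scaled by sqrt(d_k)
    weights = softmax(scores, axis=-1)   # each row is a distribution over the keys
    return weights @ V, weights          # weighted sum of values, plus the weights

# Hypothetical sizes: 3 queries, 5 keys/values, d_k = 8, d_v = 16.
rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 8))
K = rng.normal(size=(5, 8))
V = rng.normal(size=(5, 16))

out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)             # one output vector of size d_v per query
print(weights.sum(axis=-1))  # every row of attention weights sums to 1
```

Each output row is a convex combination of the rows of V, which is what "a weighted sum of values" means in practice.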

Expanding with Multi-Head Attention (MHA)

Multi-head attention employs multiple ‘heads’ of attention layers, enabling the model to attend to information from different representation subspaces. In each head, there is an independent set of linear layers (projection matrices) for the query, key, and values (this is an important point that we will revisit in GQA). For each head (numbered h):

headʰ = Attention(Q.Wqʰ,K.Wkʰ,V.Wvʰ)

Concatenating Head Outputs

The outputs of the individual heads are concatenated and then linearly transformed.

MultiHead(Q,K,V) = Concat(head¹,head²,…,headʰ).Wᵒ

Wᵒ is another weight matrix that linearly transforms the concatenated vector to the final output dimension.
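The per-head projections, the concatenation, and the final output projection can be put together in a short NumPy sketch. The array layout (one stacked array of per-head projection matrices), the sizes, and the function names are assumptions made for illustration; real implementations usually fuse the heads into a single batched matrix multiplication and add masking.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores) @ V

def multi_head_attention(Q, K, V, Wq, Wk, Wv, Wo):
    """Wq, Wk, Wv: (n_heads, d_model, d_head); Wo: (n_heads * d_head, d_model)."""
    heads = [
        # Head h has its own query, key, and value projection matrices.
        attention(Q @ Wq[h], K @ Wk[h], V @ Wv[h])
        for h in range(Wq.shape[0])
    ]
    concat = np.concatenate(heads, axis=-1)  # (seq_len, n_heads * d_head)
    return concat @ Wo                       # (seq_len, d_model)

# Hypothetical sizes chosen for the example.
seq_len, d_model, n_heads = 4, 32, 8
d_head = d_model // n_heads

rng = np.random.default_rng(0)
X = rng.normal(size=(seq_len, d_model))
Wq, Wk, Wv = (rng.normal(size=(n_heads, d_model, d_head)) for _ in range(3))
Wo = rng.normal(size=(n_heads * d_head, d_model))

# Self-attention: queries, keys, and values all come from the same input X.
out = multi_head_attention(X, X, X, Wq, Wk, Wv, Wo)
print(out.shape)
```

Note that Wk and Wv hold one matrix per head here. That per-head set of key and value projections is what the later variants reduce.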

The intuition behind multi-head attention is that by applying the attention mechanism multiple times in parallel, the model can capture different types of relationships in the data.

Illustration depicting scaled dot product attention and multi-head attention within a transformer’s encoder block. (Source: sections of the diagram from the “Attention is All You Need” paper https://arxiv.org/abs/1706.03762, composition by the author)

MHA enables a nuanced understanding of the relationships between different parts of the input. However, this comes at a cost: a significant demand on memory bandwidth, especially during decoder inference.

The Memory Bandwidth Challenge in Multi-Head Attention

The crux of the issue lies in the memory overhead. Each decoding step in an autoregressive Transformer decoder requires loading the decoder weights along with all attention keys and values. This makes decoding not only computationally intensive but also memory bandwidth-intensive. As model sizes grow, this overhead also increases, making scaling up an increasingly arduous task.
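A rough back-of-the-envelope calculation shows how quickly the keys and values add up. The sketch below assumes the keys and values of all previous tokens are kept in memory and re-read at every decoding step; the model configuration (32 layers, 32 heads of dimension 128, 16-bit values) is hypothetical and not taken from any specific model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Bytes needed to hold keys and values for every layer, head, and position."""
    # Factor of 2: one tensor for keys and one for values.
    return (2 * n_layers * n_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_value)

# Hypothetical multi-head model: every one of the 32 heads has its own keys and values.
for seq_len in (1024, 4096, 16384):
    size = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=seq_len)
    print(f"{seq_len:>6} tokens: {size / 2**30:.2f} GiB")
```

Under these assumptions the stored keys and values grow linearly with sequence length and with the number of key/value heads, and all of it has to be read for each generated token.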

Multi-query attention (MQA) emerged as a solution to mitigate this bottleneck. The idea is simple yet effective: use multiple query heads but only a single key and value head. This approach significantly reduces the memory load, enhancing inference speed. It has been employed in multiple large-scale models such as PaLM, StarCoder, and Falcon.
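As an illustrative sketch (not the implementation used by any of the models named above), multi-query attention can be written in NumPy by keeping one query projection per head while using a single key projection and a single value projection. The names, shapes, and sizes are assumptions for the example, and causal masking is omitted for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def multi_query_attention(X, Wq, Wk, Wv, Wo):
    """Wq: (n_heads, d_model, d_head); Wk, Wv: (d_model, d_head), shared by all heads."""
    K = X @ Wk  # (seq_len, d_head): computed once, used by every query head
    V = X @ Wv  # (seq_len, d_head)
    heads = []
    for h in range(Wq.shape[0]):
        Qh = X @ Wq[h]  # each head still has its own queries
        weights = softmax(Qh @ K.T / np.sqrt(K.shape[-1]))
        heads.append(weights @ V)
    out = np.concatenate(heads, axis=-1) @ Wo
    return out, K, V

# Hypothetical sizes chosen for the example.
seq_len, d_model, n_heads = 4, 32, 8
d_head = d_model // n_heads

rng = np.random.default_rng(0)
X = rng.normal(size=(seq_len, d_model))
Wq = rng.normal(size=(n_heads, d_model, d_head))
Wk = rng.normal(size=(d_model, d_head))
Wv = rng.normal(size=(d_model, d_head))
Wo = rng.normal(size=(n_heads * d_head, d_model))

out, K, V = multi_query_attention(X, Wq, Wk, Wv, Wo)
print(out.shape, K.shape, V.shape)
```

Only one K and one V of shape (seq_len, d_head) need to be stored and loaded per layer, instead of one pair per head, which is where the memory saving comes from.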

When converting an existing multi-head model to multi-query attention, we average (mean-pool) the heads for keys and values so that all query heads share the same key and value head. Conceptually, this is equivalent to replicating the mean-pooled “head” H times, where H is the number of query heads.
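The mean-pooling step can be sketched as follows, assuming the per-head key and value projection matrices of a multi-head model are stacked into arrays of shape (H, d_model, d_head). The variable names and sizes are hypothetical, and random matrices stand in for trained weights.

```python
import numpy as np

# Hypothetical sizes; H is the number of query heads.
H, d_model, d_head = 8, 32, 4

rng = np.random.default_rng(0)
Wk_mha = rng.normal(size=(H, d_model, d_head))  # stand-in for trained per-head key projections
Wv_mha = rng.normal(size=(H, d_model, d_head))  # stand-in for trained per-head value projections

# Mean-pool across the head axis to obtain a single shared key/value head.
Wk_mqa = Wk_mha.mean(axis=0)  # (d_model, d_head)
Wv_mqa = Wv_mha.mean(axis=0)  # (d_model, d_head)

# Conceptually, every query head now sees the same pooled head, i.e. it is
# replicated H times. broadcast_to creates a read-only view without copying.
Wk_shared = np.broadcast_to(Wk_mqa, (H, d_model, d_head))
Wv_shared = np.broadcast_to(Wv_mqa, (H, d_model, d_head))

print(Wk_mqa.shape, Wk_shared.shape)
```

In practice the replication never needs to be materialized: storing the single pooled head is enough, since all H query heads read from it.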


