Multi-Head Attention Explained: The Core Technique That Lets a Model Observe Language from Different Dimensions | Daoman PythonAI

# Multi-Head Attention: Letting the Model Observe Language from Different Dimensions

📂 Stage: Stage 3, Transformer Revolution (Core) 🔗 Prerequisites: Self-Attention Computation · Positional Encoding


1. Why is single-head attention not enough?

Let's first quickly review the core logic of single-head self-attention: one set of projection matrices turns each word in the input sequence into a query vector, a key vector, and a value vector. Each word's query is then compared against the keys of every word in the sequence, and the resulting weights are used to average the value vectors into that word's new representation.

This "one query per word" design looks efficient, but real language is more complex: a word usually relates to the other words in several different ways at once.
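To make the review concrete, here is a minimal single-head self-attention sketch in PyTorch. It assumes batch-first input of shape (batch, seq_len, d_model); the class name SingleHeadSelfAttention is hypothetical, and the output projection counted in Section 4 is omitted for brevity.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)


class SingleHeadSelfAttention(nn.Module):
    """Hypothetical sketch: one set of Q/K/V projections, so one attention map per word."""

    def __init__(self, d_model: int):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_model)  # query projection
        self.w_k = nn.Linear(d_model, d_model)  # key projection
        self.w_v = nn.Linear(d_model, d_model)  # value projection

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        # Compare every query with every key: (batch, seq_len, seq_len)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        weights = torch.softmax(scores, dim=-1)  # each row is a distribution summing to 1
        # Weighted average of the value vectors
        return weights @ v, weights


attn = SingleHeadSelfAttention(d_model=16)
x = torch.randn(2, 5, 16)  # 2 sentences, 5 words each, 16-dim random embeddings
out, weights = attn(x)
print(out.shape, weights.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 5, 5])
```

Note that weights holds exactly one probability distribution per word; that single distribution is the limitation discussed next.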

1.1 An everyday example

For example, take the sentence "Programmer Xiao Wang stayed up late to write a technical blog":

  • From the grammatical perspective, "Xiao Wang" is the subject of "write", "blog" is the object of "write", and "stayed up late" is an adverbial modifying "write";
  • From the perspective of semantic roles, "programmer" describes Xiao Wang's identity, and "technical" specifies the topic of "blog";
  • From the perspective of implied meaning, "stayed up late" suggests the post was written against a deadline, yet may well be packed with useful content.

With only one set of queries, keys, and values, a single attention head produces just one attention distribution per word, so it tends to capture one kind of relationship at the expense of the others: either it locks onto the most obvious subject-verb-object structure, or it spreads its weights thinly and picks up irrelevant details, failing to cover all the useful relationships at once.


2. The core idea of multi-head attention: divide and conquer

The original Transformer paper proposes an elegant solution: split the single "all-purpose expert" into H "specialist" heads that work in parallel.

2.1 The basic process, step by step

  1. Head splitting: the embedding dimension (denoted d_model) is divided evenly into H parts (H = num_heads), each of size head_dim, so d_model = num_heads × head_dim;
  2. Independent projection: each head has its own query, key, and value projection matrices and produces its own sub-queries, sub-keys, and sub-values in parallel;
  3. Separate computation: each head runs self-attention independently on its own sub-vectors and produces its own sub-output;
  4. Concatenation: the sub-outputs of all heads are concatenated in order, restoring the d_model dimension;
  5. Linear integration: an output matrix linearly transforms the concatenated vector to produce the final multi-head attention output.
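The five steps above can be sketched from scratch as follows. This is a minimal illustration, not PyTorch's internal code: the class ManualMultiHeadAttention is hypothetical, inputs are assumed batch-first, and no mask or dropout is applied. As is common in practice, the H per-head projection matrices (each d_model × head_dim) are stored side by side in one d_model × d_model linear layer per role and then split with view.

```python
import math

import torch
import torch.nn as nn


class ManualMultiHeadAttention(nn.Module):
    """Hypothetical from-scratch sketch of the five steps (no mask, no dropout)."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # Step 1: d_model = num_heads * head_dim
        # Step 2: per-head Q/K/V matrices (d_model x head_dim each), stored side by side
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # Step 5: output (integration) matrix

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, d_model) -> (batch, num_heads, seq, head_dim)
        b, s, _ = t.shape
        return t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q = self._split_heads(self.w_q(x))
        k = self._split_heads(self.w_k(x))
        v = self._split_heads(self.w_v(x))
        # Step 3: every head runs scaled dot-product attention on its own slice
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)  # (b, h, s, s)
        heads = torch.softmax(scores, dim=-1) @ v  # (b, h, s, head_dim)
        # Step 4: concatenate the heads back to (b, s, d_model)
        concat = heads.transpose(1, 2).reshape(b, s, d)
        # Step 5: linear integration
        return self.w_o(concat)


mha = ManualMultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
print(mha(x).shape)  # torch.Size([2, 10, 512])
```

With matching weights, this sketch should agree with nn.MultiheadAttention (batch_first=True, no dropout) up to floating-point error, since both split the projected vectors into contiguous head_dim slices.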

2.2 Revisiting the idea with a "team of experts" analogy

We can think of multi-head attention as an NLP semantic analysis team:

  • Head 1 is the "grammar analyst": it focuses only on structural relationships such as subject, predicate, object, attributive, and adverbial;
  • Head 2 is the "entity identifier": it focuses on identity and topic associations among names of people, places, and things;
  • Head 3 is the "emotion/logic miner": it looks specifically for implied information such as "stayed up late to meet a deadline" or "happy to share";
  • ... (more heads, each with its own specialty, can be configured as needed)

After each expert finishes its own analysis, the team leader (the final linear output matrix) merges everyone's reports into one complete, comprehensive report: the final semantic representation. Keep in mind that the roles above are only an illustration; in a real model nobody assigns them, and whatever specialization a head ends up with is learned during training.
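To look at each "expert's report" separately, nn.MultiheadAttention can return one attention map per head instead of the head-averaged map, via average_attn_weights=False (available in PyTorch 1.11 and later). The sketch below uses random inputs and an untrained module, so any differences between heads are random; meaningful specializations only appear after training.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
x = torch.randn(1, 6, 64)  # one "sentence" of 6 tokens (random stand-in embeddings)

# average_attn_weights=False keeps one attention map per head instead of averaging them
_, per_head = mha(x, x, x, need_weights=True, average_attn_weights=False)
print(per_head.shape)  # torch.Size([1, 4, 6, 6]): (batch, num_heads, query_len, key_len)

for h in range(per_head.size(1)):
    # For each head: which token does token 0 attend to most strongly?
    top = per_head[0, h, 0].argmax().item()
    print(f"head {h}: token 0 attends most to token {top}")
```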


3. PyTorch's built-in implementation: getting started quickly

PyTorch's nn.MultiheadAttention already encapsulates all of this core logic, so we don't have to hand-write the projection, splitting, and concatenation steps; we can call it directly.

Below is a single Transformer encoder layer built around multi-head attention. It can be copied and run directly (batch_first requires PyTorch 1.9 or later):

```python
from __future__ import annotations  # lets "torch.Tensor | None" work on Python < 3.10

import torch
import torch.nn as nn
import torch.nn.functional as F


# A single Transformer encoder layer with multi-head self-attention
class TransformerEncoderLayerDemo(nn.Module):
    def __init__(
        self,
        d_model: int = 512,    # input/output embedding dimension
        num_heads: int = 8,    # number of attention heads
        d_ff: int = 2048,      # hidden dimension of the feedforward network
        dropout: float = 0.1,  # dropout rate to reduce overfitting
    ):
        super().__init__()

        # Core: multi-head self-attention module
        # batch_first=True uses the more intuitive (batch, seq, feature) layout
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Feedforward network
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

        # Residual connections + LayerNorm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        # Step 1: multi-head self-attention + residual + LayerNorm
        # In self-attention, query, key, and value are all x.
        # nn.MultiheadAttention returns (output, attention weights); need_weights=False
        # skips computing the weights, so the second item is None.
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + self.dropout(attn_output))  # residual: input + sublayer output

        # Step 2: feedforward network + residual + LayerNorm
        ff_output = self.linear2(F.gelu(self.linear1(x)))
        x = self.norm2(x + self.dropout(ff_output))

        return x


# ---------------- Test code ----------------
if __name__ == "__main__":
    # Build one encoder layer
    encoder_layer = TransformerEncoderLayerDemo(d_model=512, num_heads=8)

    # Test input: (batch_size, seq_len, d_model)
    # batch_size=32: process 32 sentences at once
    # seq_len=100: 100 tokens per sentence (real data is padded or truncated to this length)
    # d_model=512: each token embedding has 512 dimensions
    test_input = torch.randn(32, 100, 512)

    # Forward pass
    test_output = encoder_layer(test_input)

    # The output shape should match the input shape
    print(f"Input shape: {test_input.shape}")    # torch.Size([32, 100, 512])
    print(f"Output shape: {test_output.shape}")  # torch.Size([32, 100, 512])
```
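The mask argument of forward is passed straight through to nn.MultiheadAttention as attn_mask. As a sketch, assuming a decoder-style causal mask is what you want, a boolean mask marks with True every position a query is not allowed to attend to:

```python
import torch
import torch.nn as nn

seq_len = 5
# Boolean attn_mask: True = "may NOT attend". The upper triangle (key j > query i)
# hides future tokens, which makes this a causal mask.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(2, seq_len, 32)
out, weights = mha(x, x, x, attn_mask=causal_mask)  # weights are averaged over heads

print(out.shape)   # torch.Size([2, 5, 32])
print(weights[0])  # entries above the diagonal are 0
```

The same causal_mask could be passed as mask to TransformerEncoderLayerDemo. Padding tokens are usually hidden with nn.MultiheadAttention's separate key_padding_mask argument, which this demo layer does not expose.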

4. Parameter count of multi-head attention: no increase in the total!

A common worry is that splitting attention into H heads will make the number of parameters explode. It does not: multi-head attention has exactly as many parameters as single-head attention with the same d_model.

Let's verify this with the configuration of BERT-base, a classic introductory NLP model:

  • BERT-base core parameters: d_model = 768, num_heads = 12, so head_dim = 768 / 12 = 64

4.1 Parameters of single-head attention

A single head requires 4 projection matrices:

  • The query, key, and value projection matrices are each d_model × d_model (with no splitting, the whole dimension is projected at once);
  • The final output (integration) matrix is also d_model × d_model;
  • Total weight parameters (biases ignored) = 3 × (768 × 768) + 768 × 768 = 4 × 768 × 768 ≈ 2.36M

4.2 Parameters of multi-head attention

After splitting into 12 heads:

  • Each head's query, key, and value projection matrices become d_model × head_dim (each head handles only its own slice of the dimensions);
  • Across all 12 heads these total 3 × 12 × (768 × 64) = 3 × 768 × (12 × 64) = 3 × 768 × 768;
  • The final output matrix is still d_model × d_model = 768 × 768;
  • Total weight parameters: again 4 × 768 × 768 ≈ 2.36M!

The conclusion is clear: multi-head attention only reorganizes the computation, partitioning the same parameters among specialized heads, while the overall size stays exactly the same. It is like spending the same money to get more comprehensive service.
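The counts above can be checked directly against nn.MultiheadAttention. In this minimal sketch, count_params is a hypothetical helper; note that with the default bias=True, PyTorch also adds 4 × d_model bias terms, so the exact total is slightly above the weights-only 4 × 768 × 768.

```python
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    """Hypothetical helper: total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())


d_model = 768  # BERT-base
single = nn.MultiheadAttention(embed_dim=d_model, num_heads=1)
multi = nn.MultiheadAttention(embed_dim=d_model, num_heads=12)

print(count_params(single))   # 2362368
print(count_params(multi))    # 2362368
print(4 * d_model * d_model)  # 2359296 (weights only; the 4 * 768 = 3072 biases make up the gap)
# Q, K, V weights are stored stacked in a single (3 * d_model, d_model) matrix
print(multi.in_proj_weight.shape)  # torch.Size([2304, 768])
```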


5. Reference multi-head configurations of classic models

In actual development we don't need to invent a configuration ourselves: starting from the proven values used by mainstream models is usually the safest choice.

The following table summarizes the multi-head-related core configurations of several widely used models, from introductory to large-scale:

| Model | d_model (embedding dimension) | num_heads (number of heads) | d_ff (feedforward hidden dimension) | dropout (dropout rate) |
|---|---|---|---|---|
| BERT-base | 768 | 12 | 3072 | 0.1 |
| BERT-large | 1024 | 16 | 4096 | 0.1 |
| GPT-2 small | 768 | 12 | 3072 | 0.1 |
| GPT-2 medium | 1024 | 16 | 4096 | 0.1 |
| Llama 2 7B | 4096 | 32 | 11008 | 0.0 |

5.1 Summary of configuration experience

  1. Number of heads vs. embedding dimension: BERT and GPT-2 keep head_dim = 64 (the value used in the original Transformer paper, with d_model = 512 and 8 heads), so num_heads = d_model / 64. For example, d_model = 768 → 12 heads and d_model = 1024 → 16 heads, matching the table above. Llama 2 7B instead uses head_dim = 128 (4096 / 32 heads);
  2. Feedforward hidden dimension: usually about 4 × d_model (BERT-base: 3072 = 4 × 768; BERT-large: 4096 = 4 × 1024). Llama 2 7B uses a different (SwiGLU) feedforward design, and its ratio is about 2.7 (11008 / 4096 ≈ 2.69);
  3. Dropout rate: about 0.1 is typical for BERT- and GPT-2-scale models, while large pretrained models such as Llama 2 use no dropout (0.0).
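These rules of thumb can be captured in a small helper. suggest_config is hypothetical and simply encodes the two heuristics above (a fixed head_dim and d_ff ≈ 4 × d_model); treat its output as a starting point, not a substitute for a model's published configuration.

```python
def suggest_config(d_model: int, head_dim: int = 64, ff_mult: int = 4) -> dict:
    """Hypothetical helper: derive num_heads and d_ff from d_model via rules of thumb."""
    if d_model % head_dim != 0:
        raise ValueError(f"d_model={d_model} is not divisible by head_dim={head_dim}")
    return {
        "d_model": d_model,
        "num_heads": d_model // head_dim,  # d_model = num_heads * head_dim
        "d_ff": ff_mult * d_model,         # feedforward hidden size about 4x d_model
    }


print(suggest_config(768))   # {'d_model': 768, 'num_heads': 12, 'd_ff': 3072}  (BERT-base)
print(suggest_config(1024))  # {'d_model': 1024, 'num_heads': 16, 'd_ff': 4096}  (BERT-large)
# Llama 2 7B uses head_dim=128, and its SwiGLU feedforward is not 4x d_model:
print(suggest_config(4096, head_dim=128))  # {'d_model': 4096, 'num_heads': 32, 'd_ff': 16384}
print(11008 / 4096)  # 2.6875, the actual d_ff ratio of Llama 2 7B
```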

6. Quick summary

Let's summarize the core of multi-head attention in three sentences:

  1. Problem solved: single-head attention cannot capture several kinds of linguistic relationships at the same time;
  2. Core idea: split the embedding dimension evenly into H parts, let each head learn different semantic patterns with its own parameters, then concatenate and integrate the results;
  3. Core advantage: with the same number of parameters as single-head attention, it produces a more comprehensive and accurate semantic representation.

💡 Practical tip: if your GPU memory is limited, you can reduce the embedding dimension d_model while keeping head_dim = 64 (that is, use fewer heads). This cuts both the parameter count and the computation. Avoid changing head_dim casually, though: 64 (or 128 in larger models such as Llama 2) is a widely used, well-tested default.


🔗 Further learning resources