Self-Attention: the mathematical essence of the Q, K, V matrices

📂 Stage: Stage 3, Transformer Revolution (Core) · 🔗 Prerequisites: Introduction to standard Seq2Seq attention · Word embedding and positional encoding basics · 🧱 Next modules: Multi-Head Attention · Transformer encoder-decoder


1. Getting started: how does Self-Attention differ from standard Seq2Seq attention?

The core idea of the attention mechanism boils down to one phrase: "pick out the key points". Whether we are reading a book, looking at a picture, or processing a sequence, we want the model to direct its "gaze" to the most important places. Self-Attention and the attention in classic Seq2Seq both pick out key points, but they select from completely different ranges.

1.1 The two kinds of attention in one table

| Dimension | Standard Seq2Seq attention | Self-Attention (the heart of the Transformer) |
|---|---|---|
| Source of Q/K/V | Q comes only from the decoder's current step; K and V both come from the encoder | Q, K, and V all come from the same input sequence |
| Problem it solves | Lets the decoder consult the encoder's information when generating the next word | Lets each word in the sequence re-understand itself: it fuses information from every other word in the sequence to update its own representation |
| Typical example | When translating "cat" into Chinese "猫", the decoder's Q attends to the encoder's "cat" | In "The animal didn't cross the street because it was too tired", a trained model's Self-Attention for "it" tends to put the highest weight on "animal" |

One-sentence summary: **standard attention is "the decoder looks at the encoder"; Self-Attention is "the sequence looks at itself".**
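To make the table concrete, here is a minimal sketch. It assumes toy random tensors, shared random projection matrices, and a hypothetical helper `attend` that is not part of the original text. The two calls run the same scaled dot-product attention, and only the source of Q, K, and V differs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16  # toy model width (an assumption for illustration)

def attend(q, k, v):
    """Hypothetical helper: softmax(q k^T / sqrt(d)) v, returning output and weights."""
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

# Random stand-ins for learned projection matrices
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

# Standard Seq2Seq attention: one decoder step queries the encoder outputs
encoder_out = torch.randn(5, d)     # 5 source tokens
decoder_state = torch.randn(1, d)   # the decoder's current step
ctx, w_cross = attend(decoder_state @ W_q, encoder_out @ W_k, encoder_out @ W_v)

# Self-Attention: every token of one sequence queries that same sequence
x = torch.randn(5, d)
out, w_self = attend(x @ W_q, x @ W_k, x @ W_v)

print(w_cross.shape)  # torch.Size([1, 5]): one query over 5 source tokens
print(w_self.shape)   # torch.Size([5, 5]): every token over every token
```

The shape of the weight matrix shows the difference directly: cross-attention yields one row per decoder query, while Self-Attention yields a square matrix over the sequence itself.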

🧠 A quick thought: because Self-Attention relates every position in a sequence to every other position through matrix operations, with no recurrence between steps, a Transformer encoder can update all words in parallel, whereas an RNN must process them one step at a time. This is the key to the Transformer's performance revolution.


2. Core: the physical meaning of the Q, K, V matrices and the full computation

The core tools of Self-Attention are three learnable projection matrices: W_q, W_k, and W_v. They act like three different pairs of "glasses" that let the same word vector play three different roles.

2.1 An everyday metaphor for Q/K/V

Suppose you have a pile of sticky notes to organize. Each note represents a word (a vector x), and your task is to write a richer, more context-aware version of every note.

  • 🕵️ Query (Q): the note says, "What kind of information am I looking for right now?"
  • 🏷️ Key (K): the note says, "What core labels describe me?"
  • 📦 Value (V): the note says, "If someone picks me, what specific content can I share with them?"

The whole Self-Attention operation works like an all-hands matchmaking session:

  1. Write Q, K, and V on every note (projected via W_q, W_k, and W_v).
  2. For note A, match its Q against the K of every note, including its own, using a dot product to get the matching scores.
  3. Scale the scores by 1/√d_k and apply softmax, turning them into weights between 0 and 1 that sum to 1.
  4. Use these weights to take a weighted sum of every note's V, which gives note A its new representation.
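The four steps can be traced by hand on a tiny example. This sketch assumes three hand-picked 2-d vectors and identity projection matrices (so Q = K = V = x). A trained model would never have such projections. The point is only to make the arithmetic of each step visible.

```python
import torch

# Three toy "notes" (tokens), each a 2-d vector; values are hand-picked
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])

# Step 1: projections. Identity matrices are an assumption that keeps Q = K = V = x
W_q = W_k = W_v = torch.eye(2)
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# Step 2: match note A (row 0) against every note's K, then scale by sqrt(d_k)
d_k = Q.size(-1)
scores_A = Q[0] @ K.T / d_k ** 0.5   # raw dot products are [1, 0, 1], then divided by sqrt(2)

# Step 3: softmax turns the scores into weights in (0, 1) that sum to 1
weights_A = torch.softmax(scores_A, dim=-1)

# Step 4: note A's new representation is the weighted sum of all V rows
new_A = weights_A @ V
print(scores_A, weights_A, new_A)
```

With these numbers, note A matches itself and the third note equally well, so those two get equal weights that are larger than the weight given to the second note.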

As a result, every word is re-expressed with the "collective wisdom" of the entire sequence.

2.2 Single-head Self-Attention in pure PyTorch

The following code implements single-head Self-Attention in full and annotates the tensor shape at every key step. When you study the Transformer, keeping track of shapes is essential.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        # Three independent, learnable linear projections that map the input into Q/K/V space
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        # x has shape (batch_size, seq_len, embed_dim)
        # it can be word embeddings or the output of the previous Transformer layer

        # ── Step 1: project the input to get Q / K / V ──
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)  # (batch, seq_len, embed_dim)
        V = self.W_v(x)  # (batch, seq_len, embed_dim)

        # ── Steps 2-4: scaled dot-product attention ──
        output, attention_weights = self._scaled_dot_product(Q, K, V, mask)
        return output, attention_weights

    def _scaled_dot_product(self, Q, K, V, mask=None):
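        d_k = Q.size(-1)          # dimension after projection (equal to embed_dim here)

        # ── Step 2: compute similarities (dot products) and scale ──
        # Why scale? Large dot products push softmax toward a one-hot output, where gradients nearly vanish.
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq_len, seq_len)
        scores = scores / (d_k ** 0.5)                 # divide by √d_k

        # ── Optional: mask out positions that must not be attended to ──
        # Used in the decoder to hide future tokens, or to ignore <pad> padding tokens;
        # mask must broadcast to (batch, seq_len, seq_len), with 0 marking blocked positions
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # large negative value, so softmax gives ~0 there

        # ── Step 3: softmax turns the scores into attention weights ──
        attention_weights = F.softmax(scores, dim=-1)  # each row sums to 1

        # ── Step 4: weight V by the attention weights to get the output ──
        output = torch.matmul(attention_weights, V)     # (batch, seq_len, embed_dim)
        return output, attention_weights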
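The optional mask step is easiest to see with a causal (look-ahead) mask. This minimal sketch assumes random toy Q and K. It uses the same convention as the code above: positions where `mask == 0` are blocked.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 4, 8
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)

# Causal mask: 1 = may attend, 0 = blocked (lower triangle, including the diagonal)
mask = torch.tril(torch.ones(seq_len, seq_len))

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5      # (1, seq_len, seq_len)
scores = scores.masked_fill(mask == 0, -1e9)       # the mask broadcasts over the batch
weights = F.softmax(scores, dim=-1)

print(weights[0])  # token i only receives weight from tokens 0..i
```

Because exp(-1e9) underflows to zero in float32, the blocked entries come out as exactly 0, and each row still sums to 1 over the visible positions.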

💡 **Lost in the dimensions? Just remember this rule:**

  • The dot product of Q and Kᵀ produces a (seq_len × seq_len) "relationship matrix"; row i holds the attention scores of word i over every word.
  • After softmax normalization, these scores become weights. Multiplying the weights by V yields new representations that incorporate global information.
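The rule can be checked numerically. This sketch assumes PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` is available. The function's default scale is 1/√d, the same as the manual formula. The sketch prints the shapes and cross-checks the hand-written computation against the built-in one.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d = 2, 5, 16
Q, K, V = (torch.randn(batch, seq_len, d) for _ in range(3))

scores = Q @ K.transpose(-2, -1) / d ** 0.5   # (2, 5, 5): the "relationship matrix"
weights = torch.softmax(scores, dim=-1)       # each row sums to 1
out = weights @ V                             # (2, 5, 16): same shape as the input

# Cross-check against PyTorch's built-in implementation (PyTorch >= 2.0)
ref = F.scaled_dot_product_attention(Q, K, V)
print(scores.shape, out.shape, torch.allclose(out, ref, atol=1e-5))
```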

3. Advanced: Multi-Head Attention

Single-head Self-Attention already solves many problems, but a single head can learn only one "attention pattern" at a time. It is like looking at the world with one eye closed: you still see everything, but you lose a clear sense of three-dimensional depth.

3.1 Why multiple heads?

Multi-head attention is like wearing several different pairs of glasses at once, each focusing on a different kind of language feature. The roles below are illustrative. Heads are not assigned these roles explicitly; what each head ends up capturing emerges from training:

  • 🧐 Head 1: grammatical relationships (subject-verb agreement)
  • 🧐 Head 2: semantic relationships (cat ↔ meow)
  • 🧐 Head 3: coreference (it → animal)
  • 🧐 Head 4: long-range dependencies (because ... so ...)

Each attention head has its own independent set of projection matrices (W_q, W_k, W_v), so training can teach each head entirely different matching rules. Finally, the outputs of all heads are concatenated and passed through one more linear transformation (W_o), which produces a semantically richer word representation.
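In practice, giving each head its own W_q, W_k, and W_v is usually implemented as one large projection whose output is split into heads, which is what the module in 3.2 does with `view`. This small sketch assumes toy sizes. It checks that the split is equivalent to giving each head its own slice of the weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads
x = torch.randn(3, embed_dim)                 # 3 tokens

W_q = nn.Linear(embed_dim, embed_dim)         # one "big" projection shared by all heads
Q = W_q(x).view(3, num_heads, head_dim)       # split the last dimension into heads

matches = []
for h in range(num_heads):
    rows = slice(h * head_dim, (h + 1) * head_dim)
    # Head h effectively owns a (head_dim x embed_dim) slice of the weight and bias
    Q_h = x @ W_q.weight[rows].T + W_q.bias[rows]
    matches.append(torch.allclose(Q[:, h], Q_h, atol=1e-6))

print(matches)
```

Because the slices are disjoint rows of the same weight matrix, the heads are trained as independent projections even though they share one `nn.Linear` module.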

3.2 Multi-head attention in pure PyTorch

Below is a ready-to-use multi-head attention module. For generality, it accepts Q, K, and V from different inputs (encoder-decoder attention) as well as from the same input (Self-Attention).

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads   # dimension allotted to each head
        self.scale = self.head_dim ** -0.5       # scaling factor, equivalent to 1/√d_k

        # Separate Q/K/V projections (they could also be fused into one Linear; the result is equivalent)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        # Final linear layer that fuses the concatenated head outputs
        self.W_o = nn.Linear(embed_dim, embed_dim)

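    def _split_into_heads(self, x, batch_size):
        # Split the last dimension into (num_heads × head_dim)
        # (batch, seq_len, embed_dim) -> (batch, seq_len, num_heads, head_dim)
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        # Reorder to (batch, num_heads, seq_len, head_dim) so each head is computed independently
        return x.permute(0, 2, 1, 3)

    def forward(self, Q_input, K_input, V_input, mask=None):
        """
        Flexible usage:
          - Encoder Self-Attention: Q_input = K_input = V_input = x
          - Decoder Self-Attention: same inputs, plus a causal mask that hides future tokens
          - Cross-attention: Q comes from the decoder, K/V come from the encoder
        mask must broadcast to (batch, heads, seq_len_Q, seq_len_KV); positions where mask == 0 are blocked.
        """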
        batch_size = Q_input.size(0)

        # ── 1. Linear projections ──
        Q = self.W_q(Q_input)   # (batch, seq_len_Q, embed_dim)
        K = self.W_k(K_input)   # (batch, seq_len_KV, embed_dim)
        V = self.W_v(V_input)   # (batch, seq_len_KV, embed_dim)

        # ── 2. Split into heads ──
        Q = self._split_into_heads(Q, batch_size)  # (batch, heads, seq_len_Q, head_dim)
        K = self._split_into_heads(K, batch_size)  # (batch, heads, seq_len_KV, head_dim)
        V = self._split_into_heads(V, batch_size)  # (batch, heads, seq_len_KV, head_dim)

        # ── 3. Each head runs scaled dot-product attention independently ──
        # attention scores: (batch, heads, seq_len_Q, seq_len_KV)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)        # attention weights
        head_outputs = torch.matmul(attention_weights, V)    # (batch, heads, seq_len_Q, head_dim)

        # ── 4. Merge the heads and apply the final linear fusion ──
        # (batch, heads, seq_len_Q, head_dim) -> (batch, seq_len_Q, heads, head_dim)
        head_outputs = head_outputs.permute(0, 2, 1, 3).contiguous()
        # concatenate all heads -> (batch, seq_len_Q, embed_dim)
        concatenated = head_outputs.view(batch_size, -1, self.num_heads * self.head_dim)
        # one more linear transform lets the features learned by different heads work together
        final_output = self.W_o(concatenated)

        return final_output, attention_weights

Usage notes:

  • MultiHeadAttention is the core building block when assembling a Transformer layer.
  • PyTorch also ships an official torch.nn.MultiheadAttention, but by default it expects inputs in (seq_len, batch, embed_dim) order rather than the (batch, seq_len, embed_dim) layout used here. Pass batch_first=True when constructing it.
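A minimal usage sketch of the built-in module follows. It assumes toy sizes and PyTorch's documented `batch_first` and boolean `attn_mask` semantics. Note that the built-in boolean mask convention is the inverse of the one used in this article: `True` means "not allowed to attend".

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, embed_dim, num_heads = 2, 5, 16, 4
x = torch.randn(batch, seq_len, embed_dim)   # (batch, seq_len, embed_dim)

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Boolean attn_mask in nn.MultiheadAttention: True = NOT allowed to attend,
# the opposite of the "mask == 0 means blocked" convention used above
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

out, weights = mha(x, x, x, attn_mask=causal)   # self-attention: query = key = value
print(out.shape)      # torch.Size([2, 5, 16])
print(weights.shape)  # torch.Size([2, 5, 5]): averaged over heads by default
```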

4. Summary: one flowchart and two key advantages

4.1 The minimal single-head Self-Attention flow

flowchart LR
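    A["Input sequence<br/>word embeddings / previous layer output"] --> B["Project with W_q / W_k / W_v<br/>to get Q / K / V"]
    B --> C["Q·K^T + scaling<br/>similarity matrix"]
    C --> D["softmax normalization<br/>attention weights"]
    D --> E["weights · V<br/>output fused with global context"]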

Each pass through this process lets every word in the sequence "communicate" with every word, including itself.

4.2 Why is Self-Attention so powerful?

  1. Long-range dependencies in one step. Any two words in the sequence, however far apart, are connected by an interaction path of length 1. By comparison, an RNN needs O(n) sequential steps, and a stack of convolutions needs O(n/k) layers with contiguous kernels of width k, or O(log_k n) layers with dilated convolutions. Self-Attention can therefore directly capture a dependency such as a subject at the start of a sentence and its predicate at the end.

  2. Fully parallel and GPU-friendly. The Q/K/V projections, similarity scores, and softmax normalization for all words are computed together as large tensor operations that run in parallel on a GPU. This is the key to why Transformers train quickly and scale easily to large models.
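Both advantages can be probed with a toy experiment. The sketch below makes two simplifying assumptions: a projection-free attention layer (Q = K = V = x) and a bare tanh RNN cell with random weights. First, it checks that after a single attention layer, the output at position 0 already receives gradient from the input at the last position. Second, it checks that computing attention one row at a time gives the same result as the batched computation. This shows that no row depends on another row's output, unlike the RNN loop.

```python
import torch

torch.manual_seed(0)
seq_len, d = 50, 8

# --- Advantage 1: a one-step path between distant positions ---
x = torch.randn(seq_len, d, requires_grad=True)
out = torch.softmax(x @ x.T / d ** 0.5, dim=-1) @ x    # one layer, Q = K = V = x
out[0].sum().backward()                                 # output at the first position
grad_from_last = x.grad[-1].abs().sum().item()          # influence of the last input
print("gradient from last token:", grad_from_last)      # expected to be nonzero

# --- Advantage 2: no step-to-step dependency ---
x = x.detach()
W_h, W_x = torch.randn(d, d) * 0.3, torch.randn(d, d) * 0.3
h = torch.zeros(d)
for t in range(seq_len):            # RNN: step t must wait for h from step t-1
    h = torch.tanh(h @ W_h + x[t] @ W_x)

attn_all = torch.softmax(x @ x.T / d ** 0.5, dim=-1) @ x    # all rows at once
attn_rows = torch.stack([torch.softmax(x[i] @ x.T / d ** 0.5, dim=-1) @ x
                         for i in range(seq_len)])           # one row at a time
print("rows independent:", torch.allclose(attn_all, attn_rows, atol=1e-5))
```

The RNN loop cannot be reordered or split across workers, because each step consumes the previous hidden state. The attention rows can, and that is exactly what a single batched matrix multiply exploits on a GPU.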



📘 Once you have mastered Q/K/V and multi-head attention, the next step is to build a complete Transformer encoder. Stay tuned!