Long short-term memory networks (LSTM) and GRU: mitigating vanishing gradients and capturing long-range dependencies

📂 Stage: Stage 2 - Deep Learning and Sequence Models (Advanced) 🎯 Prerequisites: Recurrent Neural Network (RNN) basics 🔗 Related chapters: Recurrent Neural Networks (RNN) · Sequence-to-Sequence Models (Seq2Seq)


1. The core idea of LSTM: an "information safe" for RNNs

1.1 Pain point: the "short-sightedness" of ordinary RNNs

When an ordinary RNN reads a long text, it tends to "forget" the beginning very quickly. Take a review such as: "The first 20 minutes of 'The Beginning' are a bit slow, but the rest is high-energy throughout, and I couldn't stop watching the last 10 episodes." An ordinary RNN is likely to retain only the "couldn't stop watching" at the end. The negative signal from "the first 20 minutes are a bit slow" almost disappears during backpropagation, because the gradient shrinks at each time step it travels back through. This is the typical symptom of the vanishing gradient problem.

LSTM is designed to address this "amnesia". It introduces a cell state that runs through the entire sequence; you can think of it as an information conveyor belt. The conveyor belt carries long-term memory stably, and three learnable "gates" decide:

  • What old information should be forgotten from the conveyor belt?
  • What new information should be written to the conveyor belt?
  • In the final output, what content on the conveyor belt should be taken out and used?

These three gates act like traffic lights for the data flow, letting the model control, dimension by dimension, how much information passes through.

1.2 Walking through the LSTM computation

To make this concrete, we use a sentiment analysis task and trace, step by step, how one review is processed: "The beginning of this movie is a bit boring, but the ending is so healing and tear-jerking!"

Step 1: Forget gate - clearing out historical memory

The forget gate decides which old information in the cell state should be discarded. For example, on reading "but the ending", the model needs to realize that the weight of the earlier "a bit boring" should be reduced or even erased. Concretely: concatenate the previous hidden state h_prev (the short-term memory from the previous step) with the current input x_now (the word vector), pass the result through a set of learnable weights, and apply a "switch function" (the sigmoid) whose output lies in [0, 1].

  • An output close to 1 means "keep completely";
  • An output close to 0 means "safe to forget";
  • Intermediate values mean "keep partially".
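As a minimal sketch of the forget gate described above: the weight matrix `W_f`, the bias `b_f`, and the tiny dimensions are illustrative assumptions (random here, learned in a real model), not values from this chapter.

```python
import torch

torch.manual_seed(0)
hidden_dim, input_dim = 4, 3

# Hypothetical forget-gate parameters (random stand-ins for learned weights).
W_f = torch.randn(hidden_dim, hidden_dim + input_dim)
b_f = torch.zeros(hidden_dim)

h_prev = torch.randn(hidden_dim)  # hidden state from the previous time step
x_now = torch.randn(input_dim)    # current input (word vector)

# Concatenate h_prev and x_now, apply the learnable weights, squash with sigmoid.
concat = torch.cat([h_prev, x_now])
f_t = torch.sigmoid(W_f @ concat + b_f)

print(f_t)  # one retention factor in [0, 1] per cell-state dimension
```

Each entry of `f_t` later scales the matching entry of the old cell state, so the gate works per dimension rather than as a single on/off switch.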

Step 2: Input gate + candidate state - preparing new information

This stage decides what new knowledge to add to the cell state. It uses two components, whose outputs are then combined:

  1. Input gate: again concatenates h_prev and x_now, and uses the switch function (sigmoid) to select which new information is worth remembering.
  2. Candidate state: concatenates the same h_prev and x_now, but applies an activation function with range [-1, 1] (tanh) to produce a "draft of the new content".
  3. Multiply the two element-wise: only the content "lit up" by the input gate is actually written to the conveyor belt.
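The three items above can be sketched as follows. The parameter names `W_i`, `W_g`, `b_i`, `b_g` and the random values are assumptions for illustration only.

```python
import torch

torch.manual_seed(0)
hidden_dim, input_dim = 4, 3

# Stand-in for the concatenation [h_prev, x_now].
concat = torch.randn(hidden_dim + input_dim)

# Hypothetical parameters for the input gate (i) and the candidate state (g).
W_i, b_i = torch.randn(hidden_dim, hidden_dim + input_dim), torch.zeros(hidden_dim)
W_g, b_g = torch.randn(hidden_dim, hidden_dim + input_dim), torch.zeros(hidden_dim)

i_t = torch.sigmoid(W_i @ concat + b_i)  # input gate: how much to write, in [0, 1]
g_t = torch.tanh(W_g @ concat + b_g)     # candidate state: what to write, in [-1, 1]
write = i_t * g_t                        # gated content that reaches the conveyor belt

print(i_t, g_t, write, sep="\n")
```

Because `i_t` never exceeds 1, the gated content is never larger in magnitude than the candidate draft itself.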

Step 3: Update the cell state - refreshing the conveyor belt

This is the core computation of LSTM:

  1. Multiply the forget gate output element-wise by the old cell state (the content on the conveyor belt at the previous time step);
  2. Add the content written by the input gate;
  3. The result is the updated cell state.

In this way, unimportant old information is forgotten, fresh and important information is written in, and the conveyor belt always carries the most relevant global memory for the current moment.
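A small worked example of the update, with hand-picked numbers (assumed purely for illustration) so the arithmetic can be checked by eye:

```python
import torch

f_t = torch.tensor([1.0, 0.0, 0.5])      # forget gate: keep all, drop all, keep half
c_prev = torch.tensor([2.0, -3.0, 4.0])  # old cell state (conveyor belt content)
write = torch.tensor([0.0, 0.5, -1.0])   # gated new content from the input gate stage

# Core LSTM update: scale the old memory, then add the new content.
c_now = f_t * c_prev + write

print(c_now.tolist())  # [2.0, 0.5, 1.0]
```

The first dimension is carried over untouched, the second is fully overwritten, and the third blends old and new. Because the update is additive, the direct gradient path from `c_now` back to `c_prev` is scaled only by `f_t`, which is what lets long-range signals survive when the forget gate stays close to 1.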

Step 4: Output gate + hidden state - deciding what to output

Finally, the model decides which information to pick from the conveyor belt to produce the current output (the hidden state):

  1. Output gate: again concatenates h_prev and x_now, and uses the switch function (sigmoid) to select which part of the conveyor belt to expose.
  2. Squashed conveyor belt content: pass the cell state through the [-1, 1] activation function (tanh) so that its values stay bounded.
  3. Multiply the two element-wise to get the hidden state at the current time step. It carries both the information most relevant right now and content with long-range dependencies, and is passed to the next time step or to a downstream fully connected classification layer.
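Putting the four steps together, here is a sketch of one full LSTM time step that reuses the weights of PyTorch's `nn.LSTMCell` and compares the result against it. The helper name `lstm_step` and the toy sizes are assumptions. PyTorch keeps separate weight matrices for the input and the hidden state, which is equivalent to one matrix applied to their concatenation.

```python
import torch
import torch.nn as nn

def lstm_step(x_now, h_prev, c_prev, cell: nn.LSTMCell):
    """One LSTM time step written out gate by gate, using the weights of `cell`."""
    # All four gate pre-activations at once; PyTorch stacks them in the
    # order: input gate, forget gate, candidate state, output gate.
    gates = (x_now @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)

    c_now = f * c_prev + i * g       # Step 3: refresh the conveyor belt
    h_now = o * torch.tanh(c_now)    # Step 4: expose part of it as the hidden state
    return h_now, c_now

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=4)
x_now, h_prev, c_prev = torch.randn(2, 3), torch.randn(2, 4), torch.randn(2, 4)

with torch.no_grad():
    h_manual, c_manual = lstm_step(x_now, h_prev, c_prev, cell)
    h_ref, c_ref = cell(x_now, (h_prev, c_prev))

print(torch.allclose(h_manual, h_ref, atol=1e-6),
      torch.allclose(c_manual, c_ref, atol=1e-6))
```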

1.3 PyTorch LSTM in practice: a bidirectional sentiment classifier

The following implements a complete bidirectional LSTM text classification model. A bidirectional LSTM scans the sequence both left to right and right to left, which suits tasks such as sentiment analysis that need a global understanding of the text.

import torch
import torch.nn as nn

class BiLSTMTextClassifier(nn.Module):
    """
    Bidirectional LSTM text classifier.
    Suited to short-text tasks such as sentiment analysis and news classification.
    """
    def __init__(
        self,
        vocab_size: int,        # vocabulary size
        embed_dim: int = 256,   # word embedding dimension
        hidden_dim: int = 256,  # LSTM hidden size (per direction)
        num_layers: int = 2,    # number of stacked LSTM layers
        dropout: float = 0.3,   # dropout rate, to reduce overfitting
        num_classes: int = 2    # number of classes (binary: positive/negative)
    ):
        super().__init__()

        # 1. Embedding layer: maps token IDs to dense, low-dimensional vectors
        self.embedding = nn.Embedding(
            vocab_size, embed_dim, padding_idx=0  # padding_idx=0: ID 0 is the padding token; its vector stays zero and gets no gradient
        )

        # 2. Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,   # tensors are laid out as (batch, seq, feature)
            bidirectional=True, # forward direction reads the left context, backward the right context
            dropout=dropout if num_layers > 1 else 0  # inter-layer dropout only applies with more than one layer
        )

        # 3. Classification head: the concatenated bidirectional state has size hidden_dim * 2
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        input_ids: (batch_size, seq_len), one sequence of token IDs per sample
        returns logits: (batch_size, num_classes), unnormalized class scores
        """
        # Embedding: (B, L) -> (B, L, E)
        embedded = self.embedding(input_ids)

        # LSTM
        # output: (B, L, 2H), forward and backward hidden states concatenated at each time step
        # (h_n, c_n): final hidden and cell states, each of shape (2*num_layers, B, H)
        output, (h_n, _) = self.lstm(embedded)

        # Concatenate the two directions of the last layer -> (B, 2H)
        last_layer_forward = h_n[-2]   # last layer, forward direction: final hidden state
        last_layer_backward = h_n[-1]  # last layer, backward direction: final hidden state
        final_hidden = torch.cat([last_layer_forward, last_layer_backward], dim=-1)

        # Classification head
        logits = self.classifier(final_hidden)
        return logits
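The indexing `h_n[-2]` and `h_n[-1]` relies on how PyTorch orders the final states of a stacked bidirectional LSTM. The sketch below, with small assumed sizes, checks that layout against the per-step `output` tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, E, H, num_layers = 2, 5, 8, 16, 2
lstm = nn.LSTM(E, H, num_layers=num_layers, batch_first=True, bidirectional=True)

with torch.no_grad():
    output, (h_n, c_n) = lstm(torch.randn(B, L, E))

print(output.shape)  # torch.Size([2, 5, 32])
print(h_n.shape)     # torch.Size([4, 2, 16])

# h_n is ordered (layer 0 fwd, layer 0 bwd, layer 1 fwd, layer 1 bwd),
# so the last two entries belong to the top layer.
fwd_final, bwd_final = h_n[-2], h_n[-1]

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step.
same_fwd = torch.allclose(fwd_final, output[:, -1, :H])
same_bwd = torch.allclose(bwd_final, output[:, 0, H:])
print(same_fwd, same_bwd)
```

One caveat worth keeping in mind: when a batch is right-padded, the forward direction's final state is computed after reading the padding positions. Packing the batch with `torch.nn.utils.rnn.pack_padded_sequence` is the usual way to make `h_n` reflect the last real token instead.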

2. GRU: a "lightweight edition" of LSTM

2.1 How GRU simplifies LSTM

In 2014, Cho et al. proposed the Gated Recurrent Unit (GRU), which achieves results comparable to LSTM with a simpler structure. GRU replaces LSTM's three gates with two and drops the separate cell state: "long-term memory" and "short-term memory" are folded into a single hidden state.
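To show how two gates and one state fit together, here is a sketch of a single GRU time step that follows the formulation used by PyTorch's `nn.GRUCell` and is checked against it. The helper name `gru_step` and the toy sizes are assumptions.

```python
import torch
import torch.nn as nn

def gru_step(x_now, h_prev, cell: nn.GRUCell):
    """One GRU time step written out gate by gate, using the weights of `cell`."""
    # PyTorch stacks the three blocks in the order: reset, update, candidate.
    gi = x_now @ cell.weight_ih.T + cell.bias_ih
    gh = h_prev @ cell.weight_hh.T + cell.bias_hh
    i_r, i_z, i_n = gi.chunk(3, dim=-1)
    h_r, h_z, h_n = gh.chunk(3, dim=-1)

    r = torch.sigmoid(i_r + h_r)   # reset gate: how much past state feeds the candidate
    z = torch.sigmoid(i_z + h_z)   # update gate: how much of the old state to keep
    n = torch.tanh(i_n + r * h_n)  # candidate hidden state
    return (1 - z) * n + z * h_prev  # blend of new candidate and old state

torch.manual_seed(0)
cell = nn.GRUCell(input_size=3, hidden_size=4)
x_now, h_prev = torch.randn(2, 3), torch.randn(2, 4)

with torch.no_grad():
    h_manual = gru_step(x_now, h_prev, cell)
    h_ref = cell(x_now, h_prev)

print(torch.allclose(h_manual, h_ref, atol=1e-6))
```

The update gate `z` plays the combined role of LSTM's forget and input gates: whatever fraction of the old state is kept, the remainder is filled by the candidate.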

Feature | LSTM | GRU
Gates | Forget gate, input gate, output gate (three) | Reset gate, update gate (two)
State structure | Cell state + hidden state | Single hidden state
Number of parameters | Larger | About 25% fewer at the same hidden size (3 gate weight blocks instead of 4)
Performance and efficiency | More stable on complex tasks | Comparable results on small and medium-sized tasks; faster training and inference
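The parameter gap can be checked directly by counting the weights of the two recurrent layers. This sketch assumes the same configuration as the classifiers in this chapter (256-dimensional input and hidden size, two bidirectional layers); `count_params` is a hypothetical helper.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of elements across all parameters of a module."""
    return sum(p.numel() for p in module.parameters())

config = dict(input_size=256, hidden_size=256, num_layers=2,
              batch_first=True, bidirectional=True)

lstm_params = count_params(nn.LSTM(**config))
gru_params = count_params(nn.GRU(**config))

print(lstm_params, gru_params)
print(gru_params / lstm_params)  # 0.75
```

With identical sizes the ratio is exactly 3/4, since a GRU layer holds three gate weight blocks where an LSTM layer holds four. The embedding and classifier layers are unaffected, so the saving for a whole model is smaller than this.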

2.2 PyTorch GRU in practice: same task, lighter choice

Converting the LSTM model above to GRU is simple: replace nn.LSTM with nn.GRU, and remember that GRU does not return a cell state c_n.

import torch
import torch.nn as nn

class BiGRUTextClassifier(nn.Module):
    """
    Bidirectional GRU text classifier.
    Lightweight and efficient; suited to quick prototyping or simple tasks.
    """
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 256,
        hidden_dim: int = 256,
        num_layers: int = 2,
        dropout: float = 0.3,
        num_classes: int = 2
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # nn.GRU in place of nn.LSTM
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # The classification head is unchanged
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedded = self.embedding(input_ids)
        # GRU returns only output and h_n; there is no c_n
        _, h_n = self.gru(embedded)
        # Last layer, forward and backward final states -> (B, 2H)
        final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        logits = self.classifier(final_hidden)
        return logits

3. Hands-on: a quick sentiment classification training run

3.1 Single-epoch training and validation

To make training more stable, we usually apply gradient clipping to guard against exploding gradients, and track accuracy alongside the loss.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    clip_max_norm: float = 1.0  # maximum gradient norm for clipping
) -> tuple[float, float]:
    """
    Train the model for one epoch.
    Returns: average loss per batch, accuracy over all samples.
    """
    model.train()
    total_loss = 0.0
    correct_preds = 0
    total_preds = 0

    for batch in dataloader:
        # Move the data to the target device (GPU or CPU)
        input_ids = batch["input_ids"].to(device)
        labels = batch["label"].to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(input_ids)
        loss = criterion(logits, labels)

        # Backward pass + gradient clipping + parameter update
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        # Accumulate metrics
        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct_preds += (preds == labels).sum().item()
        total_preds += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    avg_acc = correct_preds / total_preds
    return avg_loss, avg_acc
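A matching validation pass might look like the sketch below. It assumes the same batch format as the training function (dicts with "input_ids" and "label" keys); the function name `evaluate_one_epoch` and the tiny `MeanPoolClassifier` used for the smoke run are hypothetical stand-ins, not part of the chapter's models.

```python
import torch
import torch.nn as nn

@torch.no_grad()  # no gradients are needed for validation
def evaluate_one_epoch(model, dataloader, criterion, device):
    """Return (average loss per batch, accuracy over all samples)."""
    model.eval()  # disables dropout
    total_loss, correct, total = 0.0, 0, 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["label"].to(device)
        logits = model(input_ids)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / len(dataloader), correct / total

# Smoke run with a hypothetical stand-in model and two hand-made batches.
class MeanPoolClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 2)

    def forward(self, input_ids):
        return self.head(self.embedding(input_ids).mean(dim=1))

torch.manual_seed(0)
batches = [
    {"input_ids": torch.randint(1, 10, (3, 5)), "label": torch.tensor([0, 1, 1])},
    {"input_ids": torch.randint(1, 10, (3, 5)), "label": torch.tensor([1, 0, 0])},
]
val_loss, val_acc = evaluate_one_epoch(
    MeanPoolClassifier(), batches, nn.CrossEntropyLoss(), torch.device("cpu")
)
print(val_loss, val_acc)
```

The only differences from the training loop are `model.eval()`, the disabled gradients, and the absence of the optimizer and clipping steps.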

3.2 A minimal inference function

At inference time, you only need to load the trained model, disable gradient computation, and apply softmax to obtain the probability of each class.

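import torch
import torch.nn as nn

def predict_sentiment(
    model: nn.Module,
    tokenizer,  # assumed to be an already-initialized tokenizer
    text: str,
    device: torch.device
) -> dict[str, float]:
    """
    Predict the sentiment of a single text.
    Returns: a dict with the positive and negative probabilities.
    """
    model.eval()
    with torch.no_grad():  # no gradients are needed at inference time
        # Tokenize + convert to IDs (simplified; in practice the tokenizer's __call__ is more convenient)
        tokens = tokenizer.tokenize(text)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        # Add the batch dimension -> (1, seq_len)
        input_ids = torch.tensor([ids]).to(device)

        # Forward pass
        logits = model(input_ids)
        # softmax turns the unnormalized scores into probabilities
        probs = torch.softmax(logits, dim=-1).squeeze(0)  # drop the batch dimension

    # Assumes class index 1 is positive and class index 0 is negative
    return {
        "positive": round(probs[1].item(), 4),
        "negative": round(probs[0].item(), 4)
    }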
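The function above only needs a tokenizer exposing `tokenize` and `convert_tokens_to_ids`. For quick experiments, a hypothetical whitespace tokenizer with that interface could look like this; the class name, the `<pad>`/`<unk>` conventions, and the toy vocabulary are all assumptions.

```python
import torch

class WhitespaceTokenizer:
    """Hypothetical stand-in exposing the two methods predict_sentiment calls."""

    def __init__(self, vocab: list[str]):
        # ID 0 is reserved for padding (matching padding_idx=0), ID 1 for unknown words.
        self.token_to_id = {"<pad>": 0, "<unk>": 1}
        for token in vocab:
            self.token_to_id.setdefault(token, len(self.token_to_id))

    def tokenize(self, text: str) -> list[str]:
        # Lowercase and split on whitespace; real tokenizers are more careful.
        return text.lower().split()

    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]:
        # Words outside the vocabulary map to the <unk> ID.
        return [self.token_to_id.get(token, 1) for token in tokens]

tokenizer = WhitespaceTokenizer(["the", "ending", "is", "healing"])
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("The ending is so healing"))
print(ids)  # [2, 3, 4, 1, 5]

input_ids = torch.tensor([ids])  # add the batch dimension
print(input_ids.shape)  # torch.Size([1, 5])
```

The word "so" is not in the toy vocabulary, so it maps to the unknown-word ID 1.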

4. Model selection advice for 2026

4.1 LSTM vs. GRU at a glance

Dimension | LSTM | GRU
Number of parameters | More | About 25% fewer at the same hidden size
Training/inference speed | Relatively slow | Faster
Long-range memory ability | Slightly stronger in theory | Sufficient for medium and simple tasks
Typical historical applications | Machine translation, speech recognition | Sentiment analysis, text classification

4.2 Where things stand in 2026

I must be honest: although LSTM/GRU is a "required course" for every deep learning practitioner getting started with sequence modeling, pre-trained models based on the Transformer architecture (BERT, GPT, Whisper, etc.) now dominate the NLP and speech fields. These models learn powerful general-purpose representations from large-scale pre-training and, after careful downstream fine-tuning, far outperform LSTM/GRU models trained from scratch. Moreover, because self-attention is heavily optimized on hardware such as the A100 and H100, the training efficiency of a Transformer can even match that of a multi-layer LSTM.

4.3 When are LSTM/GRU still used?

Nonetheless, LSTM and GRU remain active in the following scenarios:

  1. Edge devices and low-latency scenarios: small parameter counts, fast inference, and modest hardware requirements.
  2. Small tasks with a strongly sequential nature: for example, time-series anomaly detection on sensor data, or part-of-speech tagging for low-resource languages.
  3. Academic baselines: in research, LSTM/GRU are among the most classic comparison models and an important reference point for measuring the effectiveness of new methods.

🔗 Extended reading