Recurrent Neural Network (RNN): logic for processing sequence data
📂 Stage: Stage 2 - Deep Learning and Sequence Models (Advanced) 🔗 Related chapters: PyTorch Basics · Long Short-Term Memory Networks (LSTM/GRU)
1. Why do we need RNN?
1.1 Limitations of traditional neural networks
Traditional fully connected networks (Dense) and convolutional networks (CNN) essentially process each input sample independently. When text is fed to them as word-frequency features, word order is lost: however the words are arranged, as long as the word frequencies are similar, the feature vectors output by the model may be very similar.
Text sequence: "This movie is ugly" vs "This movie is good". A network that only counts words sees "this", "movie" and "is" in both sentences, so the two feature vectors may end up almost the same, and the model struggles to separate the semantic difference between "good" and "ugly".
This "word frequency only, no order" approach falls short on sequence data where order clearly matters, such as natural language, speech signals, and stock prices.
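To make the loss of order concrete, here is a minimal sketch of a word-count representation. The helper `bag_of_words` and the two sentences are illustrative assumptions, not part of the original example.

```python
from collections import Counter

def bag_of_words(sentence):
    """Count how often each lowercase word occurs, discarding word order."""
    return Counter(sentence.lower().split())

# Hypothetical sentences: the same words in a different order, opposite meaning.
a = "the dog bit the man"
b = "the man bit the dog"

# The count-based features are identical even though the sentences differ.
same_features = bag_of_words(a) == bag_of_words(b)
print(same_features)
```

Any model that only sees these counts receives exactly the same input for both sentences, so it has no way to tell them apart.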
1.2 The core breakthrough of RNN: adding "memory"
The secret of the Recurrent Neural Network (RNN for short) is hidden in its name: it reuses the same neuron unit in a loop and introduces a hidden state to store "previously seen information". Simply put, when an RNN reads a sequence, each step fuses the current input with the "memory" left by the previous step to produce a new memory and an output.
For a more intuitive understanding, the RNN can be unrolled along the time steps, one step per element in the sequence.
The processing logic of each time step is the same:
- Receive the current input x_t and the hidden state of the previous step h_{t-1}
- Generate a new hidden state h_t and the current output y_t through the same RNN unit
The internal work of the RNN unit is very simple: multiply the current input and the previous hidden state by their respective weight matrices, add the results together with a bias, and squash the values into the range (-1, 1) with the hyperbolic tangent (tanh) activation function to obtain the new hidden state. This computation is identical at every step: however long the sequence is, the same set of parameters is reused, so an RNN can naturally handle sequences of any length.
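The update described above can be written as h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh), using PyTorch's parameter names. The sketch below writes that step out by hand and compares it with `nn.RNNCell`. The sizes and the helper `manual_step` are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 3, 4   # illustrative sizes, not from the text

cell = nn.RNNCell(input_size, hidden_size)  # tanh is the default nonlinearity

def manual_step(x_t, h_prev):
    """One RNN step written out: tanh(W_ih x_t + b_ih + W_hh h_prev + b_hh)."""
    return torch.tanh(
        x_t @ cell.weight_ih.T + cell.bias_ih
        + h_prev @ cell.weight_hh.T + cell.bias_hh
    )

sequence = torch.randn(5, 1, input_size)   # (seq_len, batch, input_size)
h_manual = torch.zeros(1, hidden_size)     # all-zero initial hidden state
h_builtin = torch.zeros(1, hidden_size)

# The same weights are reused at every step, whatever the sequence length.
with torch.no_grad():
    for x_t in sequence:
        h_manual = manual_step(x_t, h_manual)
        h_builtin = cell(x_t, h_builtin)

matches = torch.allclose(h_manual, h_builtin, atol=1e-6)
print(matches, h_manual.shape)
```

In PyTorch's `nn.RNN` the per-step output is the hidden state itself. A separate output y_t is usually obtained by applying an extra layer, such as `nn.Linear`, on top of h_t.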
2. The fatal problem of RNN
Although RNN solves the problem of "independent processing", it has two inherent flaws, and as a result native RNN is basically no longer used in long-sequence tasks.
2.1 Gradient vanishing and exploding
The core of training a neural network is backpropagation: working backward from the output layer step by step to compute each parameter's effect on the final loss, and then updating the parameters. What makes an RNN special is that its parameters are reused across time steps, so during backpropagation the gradient is multiplied repeatedly along the time axis (the number of multiplications grows with the length of the sequence).
- Vanishing gradient: If the factors being multiplied are generally smaller than 1, then after many multiplications the gradient becomes smaller and smaller and approaches 0. The model then has almost no ability to learn how information from long ago affects the present.
- Exploding gradient: If the factors are larger than 1, then after many multiplications the gradient grows exponentially until it overflows, causing the loss during training to become NaN and training to crash.
For example: “I was born in China…(1000 completely unrelated words mixed in)…I can speak ___”. A native RNN will most likely have forgotten "China" from the first clause by then, making it difficult to fill in "Chinese".
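The effect of repeated multiplication can be seen with plain numbers. This is a simplified sketch: it assumes every time step contributes the same scalar factor, whereas in a real RNN the factors are Jacobian matrices that vary from step to step. The helper `repeated_product` and the values 0.9, 1.1 and 100 are illustrative.

```python
def repeated_product(factor, steps):
    """Multiply a starting gradient of 1.0 by the same factor once per time step."""
    grad = 1.0
    for _ in range(steps):
        grad *= factor
    return grad

steps = 100  # hypothetical sequence length
shrinking = repeated_product(0.9, steps)   # factor slightly below 1
growing = repeated_product(1.1, steps)     # factor slightly above 1

print(f"0.9 ** {steps} = {shrinking:.3e}")
print(f"1.1 ** {steps} = {growing:.3e}")
```

With factors this close to 1, 100 steps are already enough to shrink the gradient to the order of 1e-5 or inflate it to the order of 1e4.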
Gradient explosion can be alleviated by Gradient Clipping: setting a threshold and forcibly scaling down the gradient when it exceeds the threshold. However, the vanishing gradient problem is almost impossible to fix for native RNN, which directly gave rise to improved models such as LSTM and GRU.
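In PyTorch, one common way to apply gradient clipping is `torch.nn.utils.clip_grad_norm_`, which rescales all gradients in place so that their combined norm does not exceed a threshold. The sketch below is a minimal illustration. The model sizes, the artificially inflated loss, and the threshold `max_norm = 1.0` are assumptions chosen only to trigger clipping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)  # illustrative sizes
x = torch.randn(4, 20, 8)            # (batch, seq_len, input_size)

output, _ = rnn(x)
loss = output.sum() * 1e3            # artificially large loss to inflate gradients
loss.backward()

max_norm = 1.0                       # clipping threshold (an assumed value)
# Rescales the gradients in place so their combined L2 norm is at most max_norm,
# and returns the norm measured before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm)

# Recompute the combined gradient norm to confirm the effect of clipping.
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in rnn.parameters()))
print(float(norm_before), float(norm_after))
```

In a training loop, this call would typically sit between `loss.backward()` and `optimizer.step()`.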
2.2 Other shortcomings
In addition to the gradient problem, native RNN also has some minor flaws:
- Serial calculation, limited efficiency: The calculation of each time step must wait for the completion of the previous step. It cannot be parallelized on a large scale like CNN, and the training speed is slow.
- Sensitive to initial states: The choice of the initial hidden state (usually an all-zero vector) affects the learning effect at early time steps.
3. PyTorch RNN rapid implementation
Although native RNN is not commonly used, it is the basis for understanding LSTM and GRU. Next, we use PyTorch to implement a simple text classifier and experience how to use RNN.
3.1 Unidirectional RNN text classifier
This section builds a unidirectional RNN for binary sentiment classification. The input is the token id sequence of the sentence, and the output is the logits of the positive/negative class.
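A minimal sketch of such a classifier might look as follows. The class name `RNNClassifier`, the layer sizes, the vocabulary size of 100, and the random batch are all assumptions for illustration, not values from the original tutorial.

```python
import torch
import torch.nn as nn

class RNNClassifier(nn.Module):
    """Embedding -> unidirectional RNN -> linear layer (hypothetical example model)."""

    def __init__(self, vocab_size, embed_dim=32, hidden_dim=64, num_classes=2, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, token_ids):
        embedded = self.embedding(token_ids)   # (batch, seq_len, embed_dim)
        _, h_n = self.rnn(embedded)            # h_n: (1, batch, hidden_dim)
        return self.fc(h_n[-1])                # logits: (batch, num_classes)

torch.manual_seed(0)
model = RNNClassifier(vocab_size=100)
batch = torch.randint(1, 100, (4, 10))         # 4 sentences of 10 token ids each
logits = model(batch)
print(logits.shape)
```

The final hidden state serves as the summary of the whole sentence. With padded batches of uneven length, that last state also absorbs the padding steps, so utilities such as `torch.nn.utils.rnn.pack_padded_sequence` are worth considering.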
3.2 Bidirectional RNN text classifier
Sometimes understanding a word requires looking not only at what came before it but also at the context that follows. For example: "I feel very __ today because I didn't eat hot pot." The emotion in the blank clearly depends on the clause that comes after it, "I didn't eat hot pot", and is evidently negative.
A bidirectional RNN trains two RNNs at the same time, one per direction:
- Forward RNN: processes the sequence from left to right;
- Backward RNN: processes the sequence from right to left.
Finally, the final hidden states of the two directions are concatenated as a representation of the entire sequence.
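The concatenation step can be sketched with `nn.RNN(..., bidirectional=True)`. As before, the class name `BiRNNClassifier`, the sizes, and the random batch are illustrative assumptions.

```python
import torch
import torch.nn as nn

class BiRNNClassifier(nn.Module):
    """Embedding -> bidirectional RNN -> linear layer (hypothetical example model)."""

    def __init__(self, vocab_size, embed_dim=32, hidden_dim=64, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Two directions are concatenated, so the classifier sees 2 * hidden_dim features.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, token_ids):
        embedded = self.embedding(token_ids)           # (batch, seq_len, embed_dim)
        _, h_n = self.rnn(embedded)                    # (2, batch, hidden_dim)
        # h_n[0]: forward direction after the last token;
        # h_n[1]: backward direction after the first token.
        features = torch.cat([h_n[0], h_n[1]], dim=1)  # (batch, 2 * hidden_dim)
        return self.fc(features)

torch.manual_seed(0)
bi_model = BiRNNClassifier(vocab_size=100)
bi_batch = torch.randint(0, 100, (4, 10))              # 4 sentences of 10 token ids
bi_logits = bi_model(bi_batch)
print(bi_logits.shape)
```

Compared with a unidirectional model, the only structural changes are the `bidirectional=True` flag and the doubled input size of the final linear layer.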
4. Summary and quick review
4.1 Review of core knowledge points
- The role of RNN: Introduce "memory" into sequence processing, solving the problem that traditional networks "only look at word frequency and ignore position".
- Unrolled view: After unrolling along the time steps, the same unit is reused at each time step, and its input is the current word plus the memory from the previous step.
- Fatal problems:
  - Vanishing gradients, so long-distance dependencies are not captured;
  - Exploding gradients, which make training unstable.
- Improvement direction: LSTM/GRU target vanishing gradients; gradient clipping targets exploding gradients.
4.2 PyTorch RNN quick check
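As a quick reference for the tensor shapes involved, the sketch below runs `nn.RNN` in unidirectional and bidirectional mode with `batch_first=True`. All sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

batch, seq_len, input_size, hidden_size = 2, 5, 8, 16   # illustrative sizes
x = torch.randn(batch, seq_len, input_size)

rnn = nn.RNN(input_size, hidden_size, num_layers=1, batch_first=True)
out, h_n = rnn(x)
# out: hidden state at every time step; h_n: hidden state after the last step.

bi_rnn = nn.RNN(input_size, hidden_size, batch_first=True, bidirectional=True)
bi_out, bi_h_n = bi_rnn(x)

print(out.shape, h_n.shape)        # torch.Size([2, 5, 16]) torch.Size([1, 2, 16])
print(bi_out.shape, bi_h_n.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 2, 16])
```

Note that `batch_first=True` only changes the layout of the input and of `out`. The final hidden state always has shape (num_layers * num_directions, batch, hidden_size).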
Native RNN performs very poorly in long sequence tasks (long text classification, machine translation, etc.). Please use LSTM or GRU directly in actual projects. In the next article, we will analyze LSTM in depth.
🔗 Extended reading
- Colah's Blog: Understanding LSTM Networks (a must-see introduction to classic LSTM)
- PyTorch RNN official documentation
- PyTorch implementation of gradient clipping

