title: Detailed explanation of CRNN: end-to-end variable-length text recognition model | Daoman PythonAI
description: In-depth analysis of the CRNN (Convolutional Recurrent Neural Network) model and its application in OCR, including a detailed architecture analysis, a PyTorch implementation and practical application scenarios.
keywords: [CRNN, OCR, optical character recognition, text recognition, deep learning, computer vision, PyTorch, sequence recognition]

Detailed explanation of CRNN: End-to-end variable-length text recognition model

When you scan a courier waybill with your phone to auto-fill an address, when a parking-lot camera reads a license plate in seconds, or when you convert a PDF to Word to extract its text, there is a good chance that a variable-length text recognition engine is working behind the scenes. CRNN pioneered this kind of engine and remains one of the cornerstones of industrial OCR.


Introduction

In the early days of optical character recognition (OCR), "segment individual characters first, then classify each one" was the mainstream approach, but it had serious flaws:

  • It relies on complex character segmentation algorithms and cannot handle touching characters, blurred or deformed characters, or tilted text in natural scenes
  • Annotation is very expensive, since every character has to be boxed by hand
  • It struggles with text sequences of varying length and with ambiguous fragments whose character boundaries are unclear

In 2015, CRNN (Convolutional Recurrent Neural Network), proposed by Baoguang Shi et al., broke with this paradigm. Its three-stage architecture, "CNN feature extraction → BiLSTM sequence modeling → CTC alignment and decoding", recognizes variable-length text sequences fully end-to-end, without any character-level segmentation or annotation.


1. Overview of the CRNN model

1.1 Core three-stage logic

The design philosophy of CRNN is simple: read the image as a sequence from left to right, where each column of the final CNN feature map (which corresponds to a narrow vertical strip of the image) is one time step. The process is as follows:

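```mermaid
flowchart LR
    A["Input image<br/>32×W, grayscale or RGB"] --> B["CNN feature extraction<br/>image → sequential feature map<br/>(1×W_seq×512)"]
    B --> C["BiLSTM sequence modeling<br/>captures context dependencies<br/>(W_seq×Batch×Hidden*2)"]
    C --> D["Linear classification layer<br/>(W_seq×Batch×nclass)"]
    D --> E["CTC greedy decoding<br/>final text"]
```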

In one sentence: **CRNN turns the CNN from an image classification/detection tool into a producer of visual time steps for a sequence model, uses a BiLSTM to capture the linguistic and structural relationships between characters, and relies on CTC to resolve the length mismatch between the per-frame outputs and the label.**
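To make the length-mismatch point concrete: CTC treats every frame-level path that collapses to the label (merge consecutive repeats, then drop blanks) as a valid alignment and sums their probabilities. The plain-Python sketch below is illustrative only. The function names are hypothetical, blank is assumed to be index 0 as in this article's code, and the brute-force version is exponential in the number of frames. The forward (alpha) recursion computes the same sum efficiently; its negative log is the per-sample value that `nn.CTCLoss` minimizes.

```python
import itertools
import math

BLANK = 0  # assumption: blank is class 0, matching this article's CRNN code


def collapse(path, blank=BLANK):
    """CTC's many-to-one map: merge consecutive repeats, then drop blanks."""
    out, prev = [], None
    for k in path:
        if k != prev and k != blank:
            out.append(k)
        prev = k
    return tuple(out)


def brute_force_prob(probs, label, blank=BLANK):
    """Sum P(path) over all frame paths that collapse to `label` (exponential in T)."""
    total = 0.0
    for path in itertools.product(range(len(probs[0])), repeat=len(probs)):
        if collapse(path, blank) == tuple(label):
            p = 1.0
            for t, k in enumerate(path):
                p *= probs[t][k]
            total += p
    return total


def ctc_forward_prob(probs, label, blank=BLANK):
    """Same sum via the CTC forward (alpha) recursion in O(T * len(label))."""
    ext = [blank]
    for k in label:
        ext += [k, blank]  # extended label, e.g. "ab" -> [-, a, -, b, -]
    S = len(ext)
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]          # a path may start with a blank...
    if S > 1:
        alpha[1] = probs[0][ext[1]]      # ...or with the first character
    for row in probs[1:]:
        new = [0.0] * S
        for s in range(S):
            a = alpha[s] + (alpha[s - 1] if s >= 1 else 0.0)
            # the blank may be skipped only between two different characters
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[s - 2]
            new[s] = a * row[ext[s]]
        alpha = new
    # a valid path ends on the last character or on the trailing blank
    return alpha[-1] + (alpha[-2] if S > 1 else 0.0)


# Toy example: T=3 frames, classes {0: blank, 1: 'a', 2: 'b'}, uniform probabilities
uniform = [[1 / 3] * 3 for _ in range(3)]
paths = [p for p in itertools.product(range(3), repeat=3) if collapse(p) == (1, 2)]
print(len(paths))                                    # 5: aab, abb, -ab, a-b, ab-
print(brute_force_prob(uniform, [1, 2]))             # 5/27 ≈ 0.185
print(ctc_forward_prob(uniform, [1, 2]))             # same value
print(-math.log(ctc_forward_prob(uniform, [1, 2])))  # per-sample CTC loss
```

For a real model, each row of `probs` would be the softmax over the `nclass` outputs at one time step.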

1.2 Key advantages for industrial use

  • End-to-end training: only paired "image → text" data is required
  • Handles any aspect ratio: the height is fixed at 32 and the width can vary freely, as long as the number of output time steps is at least the target text length (plus one for every pair of identical adjacent characters, which CTC must separate with a blank)
  • Lightweight and efficient: inference is typically several times faster than Transformer-based models (roughly 3-10×, depending on model size and hardware), which makes it suitable for deployment on edge devices
  • Few dependencies: no lexicon is needed (a lexicon can improve accuracy, but it is optional)
  • Strong interpretability: each time step corresponds to a vertical strip of the input image, which makes errors easy to localize and debug
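The width condition above can be checked by hand. The plain-Python sketch below (hypothetical helper names) traces height and width through the `VGGTextBackbone` layers from Section 2 using the standard output-size formula of `nn.Conv2d`/`nn.MaxPool2d` (assuming dilation 1 and `ceil_mode=False`), and counts the minimum number of frames CTC needs for a label: one per character plus one blank between each pair of identical adjacent characters.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length along one axis for nn.Conv2d / nn.MaxPool2d (dilation 1, ceil_mode=False)."""
    return (n + 2 * padding - kernel) // stride + 1


def backbone_out_size(h, w):
    """Trace (height, width) through VGGTextBackbone from Section 2.
    The 3x3 convs use padding 1 and keep the size, so only the pools and the final 2x2 conv matter."""
    h, w = out_size(h, 2, 2), out_size(w, 2, 2)     # Block 1: MaxPool2d(2, 2)
    h, w = out_size(h, 2, 2), out_size(w, 2, 2)     # Block 2: MaxPool2d(2, 2)
    h, w = out_size(h, 2, 2), out_size(w, 2, 1, 1)  # Block 3: stride (2, 1), padding (0, 1)
    h, w = out_size(h, 2, 2), out_size(w, 2, 1)     # Block 4: stride (2, 1), no padding
    h, w = out_size(h, 2, 1), out_size(w, 2, 1)     # Block 5: Conv2d kernel 2, no padding
    return h, w


def min_ctc_steps(label):
    """Frames CTC needs: one per character, plus a blank between identical neighbors."""
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats


def width_fits(width, label, height=32):
    """True if an image of this width yields enough time steps for `label`."""
    _, steps = backbone_out_size(height, width)
    return steps >= min_ctc_steps(label)


print(backbone_out_size(32, 100))  # (1, 24), matching the 24 time steps in Section 2
print(min_ctc_steps("hello"))      # 6: the "ll" pair needs a blank between the two l's
print(width_fits(32, "hello"))     # True: width 32 gives 7 time steps
print(width_fits(16, "hello"))     # False: width 16 gives only 3
```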

2. PyTorch minimalist implementation

To help you get started quickly, here is a compact PyTorch implementation: a trimmed VGG backbone, two stacked BiLSTM layers, and an output layout compatible with standard CTC. The code is only about 100 lines and can be trained and used for inference as-is.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------------------
# 1. CNN feature extractor: a trimmed VGG adapted for text recognition
# ---------------------------------------------------------
class VGGTextBackbone(nn.Module):
    """
    Input height must be exactly 32; width is variable.
    Output height is 1 and output width is W/4 - 1, where W/4 is integer division
    (24 time steps for W = 100).
    """
    def __init__(self, in_channels=1, out_channels=512):
        super().__init__()
        self.backbone = nn.Sequential(
            # Block 1: 32×W → 16×W/2
            nn.Conv2d(in_channels, 64, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),

            # Block 2: 16×W/2 → 8×W/4
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),

            # Block 3: 8×W/4 → 4×(W/4+1) (width stride 1 + width padding 1 preserve the sequence length)
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 2), (2, 1), (0, 1)),

            # Block 4: 4×(W/4+1) → 2×W/4 (halve the height, width shrinks by 1)
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 2), (2, 1), (0, 0)),

            # Block 5: 2×W/4 → 1×(W/4-1) (the 2×2 conv collapses the height to 1)
            nn.Conv2d(512, out_channels, 2, 1, 0), nn.ReLU(True)
        )

    def forward(self, x):
        return self.backbone(x)

# ---------------------------------------------------------
# 2. Single-layer bidirectional LSTM with its own projection layer
# ---------------------------------------------------------
class BiLSTMEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden_dim, bidirectional=True, batch_first=False)
        self.linear = nn.Linear(hidden_dim * 2, out_dim)  # project the concatenated directions

    def forward(self, x):
        # x shape: (Time_steps, Batch, in_dim)
        lstm_out, _ = self.lstm(x)  # lstm_out: (Time_steps, Batch, hidden*2)

        # Flatten time and batch for the linear layer, then restore the shape
        T, B, H2 = lstm_out.shape
        return self.linear(lstm_out.reshape(T * B, H2)).reshape(T, B, -1)

# ---------------------------------------------------------
# 3. Complete CRNN model
# ---------------------------------------------------------
class CRNN(nn.Module):
    def __init__(self, img_h=32, in_channels=1, nclass=27, hidden_dim=256):
        """
        img_h: must be 32 (otherwise the backbone's output height is not 1)
        nclass: must include the blank (index 0) plus the target charset (indices 1..N)
        """
        super().__init__()
        assert img_h == 32, "img_h must be 32 for this backbone"
        self.backbone = VGGTextBackbone(in_channels)
        self.rnn1 = BiLSTMEncoder(512, hidden_dim, hidden_dim)
        self.rnn2 = BiLSTMEncoder(hidden_dim, hidden_dim, nclass)

    def forward(self, x):
        # Step 1: CNN feature extraction
        conv = self.backbone(x)  # (B, 512, 1, W_seq)
        B, C, H, W_seq = conv.shape
        assert H == 1, "CNN output height must be 1"

        # Step 2: turn the feature map into a sequence (the key step!)
        conv = conv.squeeze(2)        # (B, C, W_seq)
        conv = conv.permute(2, 0, 1)  # (W_seq, B, C), the default PyTorch RNN input layout

        # Step 3: sequence prediction
        rnn_out = self.rnn1(conv)
        return self.rnn2(rnn_out)  # (W_seq, B, nclass)

# ---------------------------------------------------------
# 4. Smoke test
# ---------------------------------------------------------
if __name__ == "__main__":
    # Recognize lowercase English a-z; with the blank that makes 27 classes
    model = CRNN(nclass=27)
    dummy_img = torch.randn(1, 1, 32, 100)  # one 32×100 grayscale image

    with torch.no_grad():
        output = model(dummy_img)

    print(f"Input shape: {dummy_img.shape}")
    print(f"Output shape: {output.shape}")  # torch.Size([24, 1, 27]); 24 is the number of time steps
```

3. Quick Guide to Training and Inference

3.1 Training (CTC Loss usage details)

PyTorch's built-in `nn.CTCLoss` works directly with the CRNN output, but pay attention to the following arguments:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Assumes the CRNN class from Section 2 is defined in the same module.

# 1. Setup
model = CRNN(nclass=27)
criterion = nn.CTCLoss(blank=0, reduction='mean')  # blank index is fixed at 0
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 2. One example batch
# images: (Batch, 1, 32, W)
# targets: all labels concatenated into one 1-D tensor, e.g. ['abc', 'de'] → [1, 2, 3, 4, 5]
# target_lengths: [3, 2]
# input_lengths: time steps per sample (torch.full works when all samples share the same width)
images = torch.randn(2, 1, 32, 100)
targets = torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)
target_lengths = torch.tensor([3, 2], dtype=torch.long)

# 3. Forward pass + loss
model_output = model(images)                    # (T, Batch, nclass); T = 24 for W = 100
log_probs = F.log_softmax(model_output, dim=2)  # CTC requires log-probabilities
input_lengths = torch.full((images.size(0),), log_probs.size(0), dtype=torch.long)
loss = criterion(log_probs, targets, input_lengths, target_lengths)

# 4. Backward pass + update
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
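The flat `targets` and `target_lengths` above are written by hand. A minimal sketch of building them from strings follows; `encode_batch` is a hypothetical helper, and it assumes the same layout as the model (blank at index 0, characters at 1..N):

```python
def encode_batch(texts, charset):
    """Build the flat `targets` list and per-sample `target_lengths` that nn.CTCLoss expects.
    Index 0 is reserved for blank, so charset[i] maps to i + 1 (as with nclass=27 for a-z)."""
    char2idx = {ch: i + 1 for i, ch in enumerate(charset)}
    targets, lengths = [], []
    for text in texts:
        ids = [char2idx[ch] for ch in text]  # raises KeyError for characters outside the charset
        targets.extend(ids)
        lengths.append(len(ids))
    return targets, lengths


charset = "abcdefghijklmnopqrstuvwxyz"
idx2char = {i + 1: ch for i, ch in enumerate(charset)}  # inverse map, usable by ctc_greedy_decode
targets, target_lengths = encode_batch(["abc", "de"], charset)
print(targets)         # [1, 2, 3, 4, 5]
print(target_lengths)  # [3, 2]
```

Wrap both lists in `torch.tensor(..., dtype=torch.long)` before passing them to the loss.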

3.2 Inference (greedy decoding implementation)

Greedy decoding is the simplest method. It needs no lexicon and is well suited to quick verification:

```python
def ctc_greedy_decode(output_probs, idx2char, blank_idx=0):
    """
    output_probs: (Time_steps, nclass) tensor for a single sample, e.g. model(x)[:, 0, :]
    idx2char: index → character mapping, e.g. {1: 'a', 2: 'b', ...}
    """
    # 1. Take the most likely class at each time step
    pred_indices = output_probs.argmax(dim=1).tolist()

    # 2. Merge consecutive repeats of non-blank classes, then drop all blanks
    decoded = []
    prev_idx = blank_idx
    for idx in pred_indices:
        if idx != blank_idx and idx != prev_idx:
            decoded.append(idx)
        prev_idx = idx

    # 3. Map indices to text
    return ''.join(idx2char[i] for i in decoded)
```
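Greedy decoding returns the label of the single most likely frame path, which is not always the most likely label, because CTC spreads a label's probability over many paths. A common alternative is CTC prefix beam search. Below is a minimal plain-Python sketch of the standard algorithm without a language model; the function name and the list-of-rows input format are assumptions, and it works in plain probabilities rather than log space, so it is only meant for short sequences.

```python
from collections import defaultdict


def ctc_prefix_beam_search(probs, beam_width=8, blank=0):
    """probs: T rows of per-class probabilities for one sample (after softmax),
    e.g. model(x).softmax(-1)[:, 0].tolist(). Returns (best label tuple, its probability)."""
    beams = {(): (1.0, 0.0)}  # prefix -> (P ending in blank, P ending in non-blank)
    for row in probs:
        nxt = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_b, p_nb) in beams.items():
            for c, p in enumerate(row):
                if c == blank:
                    nxt[prefix][0] += (p_b + p_nb) * p  # a blank keeps the prefix unchanged
                    continue
                extended = prefix + (c,)
                if prefix and prefix[-1] == c:
                    nxt[extended][1] += p_b * p  # "a-a": a second 'a' needs a blank in between
                    nxt[prefix][1] += p_nb * p   # "aa": the repeat merges into the same 'a'
                else:
                    nxt[extended][1] += (p_b + p_nb) * p
        # keep only the most probable prefixes
        ranked = sorted(nxt.items(), key=lambda kv: kv[1][0] + kv[1][1], reverse=True)
        beams = {k: tuple(v) for k, v in ranked[:beam_width]}
    best, (p_b, p_nb) = max(beams.items(), key=lambda kv: kv[1][0] + kv[1][1])
    return best, p_b + p_nb


# Two frames over {0: blank, 1: 'a'}: the greedy path "--" decodes to "" (p = 0.36),
# but "a" collects the paths a-, -a and aa (p = 0.24 + 0.24 + 0.16 = 0.64)
print(ctc_prefix_beam_search([[0.6, 0.4], [0.6, 0.4]]))  # ((1,), ≈0.64)
```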

4. Practical suggestions

4.1 Data processing

  • The input height must be fixed at 32. Scale the width to preserve the original aspect ratio, and cap it at 256 or 512 depending on available GPU memory
  • Grayscale input usually works as well as or better than RGB, unless color is what separates the characters from the background
  • Data augmentation: slight random tilt (-15° to 15°), random horizontal stretching (width × 0.9-1.1), Gaussian noise or blur, and contrast adjustment. These four types of augmentation typically give CRNN the largest gains
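The sizing rule and the width-stretch augmentation reduce to simple arithmetic. In the sketch below the helper names are hypothetical, and `min_w=16` is an assumed floor (not from this article) so that very narrow crops still yield a few CNN time steps; the actual pixel resize would be done by an image library such as OpenCV's `cv2.resize` or PIL's `Image.resize` using the computed width.

```python
import random


def resized_width(orig_h, orig_w, target_h=32, max_w=256, min_w=16):
    """Width after scaling to height `target_h` with the aspect ratio kept,
    clamped to [min_w, max_w]. min_w=16 is an assumed floor, not from the article."""
    w = round(orig_w * target_h / orig_h)
    return max(min_w, min(max_w, w))


def random_stretch_width(w, low=0.9, high=1.1, rng=random):
    """Random horizontal stretch (width × 0.9-1.1), applied after the resize."""
    return max(1, round(w * rng.uniform(low, high)))


print(resized_width(64, 400))   # 200
print(resized_width(48, 900))   # 600 would exceed the cap, so 256
print(resized_width(100, 20))   # 6 is below the floor, so 16
rng = random.Random(0)
print(random_stretch_width(200, rng=rng))  # somewhere in 180..220
```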

4.2 Model deployment

  • Edge devices (mobile phones, cameras): export the model to ONNX with `torch.onnx.export`, then accelerate it with ONNX Runtime, TensorRT, NCNN, or TNN; on suitable hardware inference can reach 100+ FPS
  • Cloud/server: plain PyTorch inference is sufficient, or use TensorRT for acceleration

Summary

CRNN is a milestone in OCR, marking the shift from "traditional segmentation" to "end-to-end recognition". Although Transformer-based models (such as CRNN-Transformer, PARSeq, and MASTER) now lead in accuracy, CRNN's light weight, speed, and few dependencies still make it a first choice for standardized scenarios such as license plate recognition, receipt and invoice recognition, and document text-line recognition.

We recommend first mastering the implementation in this article, then training it on SynthText, IIIT5K, or your own dataset, and finally comparing the results of different models.


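Key pitfalls:

  1. In this implementation the CTC `blank` index is 0, so the blank **must occupy the first slot of the character set**, and `nn.CTCLoss` must be created with `blank=0`.
  2. The number of time steps output by the CNN **must be at least the length of the target text**, plus one for every pair of identical adjacent characters (CTC has to put a blank between them).
  3. Is `log_softmax` required before inference? No: greedy decoding can apply `argmax` directly to the raw outputs, but training must use log-probabilities.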
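Point 3 can be checked numerically: `log_softmax` subtracts the same log-sum-exp constant from every class at a given time step, so the ranking of classes, and therefore the `argmax`, never changes. A tiny plain-Python check with made-up logits (the helper names are illustrative, not PyTorch functions):

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one time step's logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


logits = [[2.0, -1.0, 0.5], [0.1, 3.2, 3.1], [-4.0, -4.5, -3.9]]  # toy (T=3, nclass=3) output
for row in logits:
    # the same constant is subtracted from every class, so the winner is unchanged
    assert argmax(row) == argmax(log_softmax(row))
print([argmax(r) for r in logits])  # [0, 1, 2]
```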

🔗 Related Resources