Model lightweighting: A detailed guide to MobileNet, quantization, pruning and edge deployment

Introduction

Model lightweighting is the last mile of deploying deep learning. It can greatly reduce a model's parameter count, computation, and memory footprint with little or no loss of accuracy, so that powerful AI models run smoothly on resource-constrained edge devices such as mobile phones, smart cameras, and IoT chips. This tutorial walks you systematically through the MobileNet family of lightweight architectures, model quantization, network pruning, and deployment practice, covering the full path from principles to code.

📂 Stage: Stage 2 - Deep Learning Vision Basics (CNN) 🔗 Related chapters: 3D Vision Fundamentals · Inference Acceleration Frameworks


1. Model lightweighting: why it matters and how to evaluate it

1.1 Why lightweight models?

| Scenario pain point | Value of lightweighting |
| --- | --- |
| Edge devices have limited compute and memory | Lowers hardware requirements |
| Inference latency is too high for real-time use | Raises inference speed (FPS) |
| Mobile devices have limited power and bandwidth | Extends battery life and reduces round-trips to the cloud (which also protects privacy) |
| Deployment costs are high | Cuts spending on cloud services and dedicated AI chips |

1.2 Key evaluation metrics

  • Parameter count (Params): the total number of model weights, which directly determines storage size (Float32 = 4 bytes per parameter)
  • Computation (FLOPs): the number of floating-point operations per inference, a hardware-independent estimate of compute cost; actual speed also depends on memory access and on how well the hardware supports each operator
  • Inference latency: the time taken by one complete inference; edge applications typically need millisecond-level latency
  • Peak memory: the maximum memory (RAM or GPU memory) occupied during inference
  • Accuracy retention: the Top-1/Top-5 accuracy of the lightweight model as a fraction of the original model's accuracy
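The first three metrics are easy to measure directly in PyTorch. The helpers below are a minimal sketch (the function names are hypothetical, not a library API): the size estimate assumes Float32 storage and ignores buffers such as BatchNorm running statistics, and the latency figure is a rough CPU wall-clock average, so benchmark on the actual target device for real numbers.

```python
import time

import torch
import torch.nn as nn


def count_params(model: nn.Module) -> int:
    """Total number of learnable parameters."""
    return sum(p.numel() for p in model.parameters())


def model_size_mb(model: nn.Module, bytes_per_param: int = 4) -> float:
    """Approximate weight storage in MB (Float32 = 4 bytes per parameter; buffers ignored)."""
    return count_params(model) * bytes_per_param / (1024 ** 2)


@torch.no_grad()
def measure_latency_ms(model: nn.Module, input_shape=(1, 3, 224, 224),
                       warmup: int = 5, runs: int = 20) -> float:
    """Average wall-clock time of one forward pass on CPU, in milliseconds."""
    model.eval()
    x = torch.randn(*input_shape)
    for _ in range(warmup):  # warm-up runs are excluded from timing
        model(x)
    start = time.perf_counter()
    for _ in range(runs):
        model(x)
    return (time.perf_counter() - start) / runs * 1000.0


if __name__ == "__main__":
    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
    print(count_params(net), f"{model_size_mb(net):.4f} MB",
          f"{measure_latency_ms(net, (1, 3, 64, 64)):.2f} ms")
```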

2. Lightweight network architectures: optimizing at the design stage

2.1 Core unit: Depthwise separable convolution

Depthwise separable convolution is the cornerstone of the MobileNet family. It factorizes a standard convolution into two steps:

  1. Depthwise convolution: each input channel is convolved separately by its own 3×3 filter, so channels are not mixed
  2. Pointwise convolution: a 1×1 convolution mixes the outputs of all channels and sets the number of output channels
import torch
import torch.nn as nn
import torch.nn.functional as F

class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable convolution + BN + ReLU
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Depthwise conv: groups=in_channels, so each channel is convolved independently
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, 
            kernel_size=kernel_size, stride=stride, 
            padding=padding, groups=in_channels, bias=False
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        # Pointwise conv: 1×1 convolution mixes information across channels
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, 
            kernel_size=1, stride=1, padding=0, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.relu(self.bn1(self.depthwise(x)))
        x = self.relu(self.bn2(self.pointwise(x)))
        return x

💡 Computation comparison: assume a stride-1 convolution with an $H \times W$ output feature map, $C_{in}$ input channels, and $C_{out}$ output channels. Counting multiply-accumulate operations (commonly reported as FLOPs):

$$\text{Standard } 3\times3 \text{ convolution: } 9 \cdot C_{in} \cdot C_{out} \cdot H \cdot W$$

$$\text{Depthwise separable convolution: } (9 \cdot C_{in} + C_{in} \cdot C_{out}) \cdot H \cdot W$$

Their ratio is $\frac{9 C_{in} + C_{in} C_{out}}{9 C_{in} C_{out}} = \frac{1}{C_{out}} + \frac{1}{9}$, so when the number of channels is large the computation drops by roughly 8 to 9 times.
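Plugging concrete numbers into these formulas makes the saving tangible. The plain-Python sketch below assumes an illustrative layer with C_in = C_out = 256 on a 14 × 14 feature map; like the formulas, it counts multiply-accumulates of a stride-1 layer and ignores BatchNorm and activations.

```python
def standard_conv_macs(c_in: int, c_out: int, h: int, w: int, k: int = 3) -> int:
    """Multiply-accumulates of a k x k standard convolution (stride 1, same padding)."""
    return k * k * c_in * c_out * h * w


def separable_conv_macs(c_in: int, c_out: int, h: int, w: int, k: int = 3) -> int:
    """Depthwise k x k conv (one filter per channel) followed by a 1x1 pointwise conv."""
    depthwise = k * k * c_in * h * w
    pointwise = c_in * c_out * h * w
    return depthwise + pointwise


if __name__ == "__main__":
    std = standard_conv_macs(256, 256, 14, 14)
    sep = separable_conv_macs(256, 256, 14, 14)
    print(std, sep, round(std / sep, 2))  # 115605504 13296640 8.69
```

With 256 output channels the reduction is already about 8.7x, and it approaches 9x as C_out grows, matching the ratio derived above.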


2.2 MobileNetV1/V2 core implementation

MobileNetV1

MobileNetV1 uses a width multiplier (width_multiplier, α in the original paper) to scale the number of channels in every layer proportionally, letting you trade accuracy for speed. The following code implements a complete V1 network whose width can be adjusted as needed.

def make_divisible(v, divisor=8, min_value=None):
    """确保通道数是8的倍数(硬件友好)"""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class MobileNetV1(nn.Module):
    def __init__(self, num_classes=1000, width_multiplier=1.0):
        super().__init__()
        input_channel = make_divisible(32 * width_multiplier)
        # Stem: standard 3×3 conv with stride 2 (downsampling)
        self.stem = nn.Sequential(
            nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(inplace=True)
        )
        # V1 config: [out_channels, repeats, stride of the first block]
        config = [[64,1,1],[128,2,2],[256,2,2],[512,6,2],[1024,2,2]]
        layers = []
        for c, n, s in config:
            c = make_divisible(c * width_multiplier)
            layers.append(DepthwiseSeparableConv(input_channel, c, stride=s))
            input_channel = c
            layers.extend([DepthwiseSeparableConv(input_channel, c, stride=1) for _ in range(n-1)])
        self.features = nn.Sequential(*layers)
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.classifier = nn.Linear(input_channel, num_classes)
    
    def forward(self, x):
        x = self.stem(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
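A quick way to see what the width multiplier buys is to count weights analytically. The plain-Python sketch below re-derives the convolution and linear weight counts of the V1 configuration above (BatchNorm parameters are ignored, and v1_weight_count is a hypothetical helper, not part of any library). At α = 1.0 it comes to about 4.2 million weights, in line with the MobileNet paper, and it shows the roughly α² scaling the paper describes: halving the width cuts the body's weights to about a quarter.

```python
def make_divisible(v, divisor=8, min_value=None):
    """Same rounding rule as above: nearest multiple of `divisor`, never below 90% of v."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


# [out_channels, repeats, stride of the first block], as in the V1 code above
V1_CONFIG = [[64, 1, 1], [128, 2, 2], [256, 2, 2], [512, 6, 2], [1024, 2, 2]]


def v1_weight_count(alpha=1.0, num_classes=1000, include_head=True):
    """Conv + linear weights of MobileNetV1 at width multiplier `alpha` (BatchNorm ignored)."""
    c_in = make_divisible(32 * alpha)
    total = 3 * c_in * 3 * 3  # stem: standard 3x3 conv, no bias
    for c, n, _stride in V1_CONFIG:
        c = make_divisible(c * alpha)
        for _ in range(n):
            total += 3 * 3 * c_in  # depthwise 3x3, one filter per input channel
            total += c_in * c      # pointwise 1x1
            c_in = c
    if include_head:
        total += c_in * num_classes + num_classes  # nn.Linear with bias
    return total


if __name__ == "__main__":
    for a in (1.0, 0.75, 0.5, 0.25):
        print(a, v1_weight_count(a))
```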

Improvements in MobileNetV2: Inverted Residual + Linear Bottleneck

MobileNetV2 makes two key changes to address the shortcomings of V1:

  1. Inverted residual structure: a 1×1 convolution first expands the channels (by the expansion factor expand_ratio, 6 for most blocks in the paper) so that the depthwise convolution extracts features in a higher-dimensional space; a 3×3 depthwise convolution follows; finally a 1×1 convolution projects back to a small number of channels. This narrow → wide → narrow shape is the inverse of the classic ResNet bottleneck (wide → narrow → wide), hence "inverted", and when the stride is 1 and the input and output channel counts match, a shortcut connects the two narrow ends.
  2. Linear bottleneck: the projection layer has no ReLU, because applying ReLU in a low-dimensional space destroys information.
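class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual block"""
    def __init__(self, in_channels, out_channels, stride, expand_ratio=6):
        super().__init__()
        self.stride = stride
        hidden_dim = int(in_channels * expand_ratio)
        self.use_res = self.stride == 1 and in_channels == out_channels
        
        layers = []
        # Expansion layer (only added when expand_ratio != 1)
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True)
            ])
        # Depthwise convolution
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True)
        ])
        # Projection layer (linear bottleneck: no ReLU)
        layers.extend([
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        ])
        self.conv = nn.Sequential(*layers)
        # Acts like "+" in float mode and becomes a quantized add after static quantization
        self.skip_add = nn.quantized.FloatFunctional()
    
    def forward(self, x):
        return self.skip_add.add(x, self.conv(x)) if self.use_res else self.conv(x)

class MobileNetV2(nn.Module):
    """MobileNetV2 assembled from InvertedResidual blocks (standard ImageNet configuration)"""
    def __init__(self, num_classes=1000, width_mult=1.0):
        super().__init__()
        input_channel = make_divisible(32 * width_mult)
        last_channel = make_divisible(1280 * max(1.0, width_mult))
        # Stem: standard 3×3 conv with stride 2
        self.stem = nn.Sequential(nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
                                  nn.BatchNorm2d(input_channel), nn.ReLU6(inplace=True))
        # V2 config: [expand_ratio t, out_channels c, repeats n, stride of the first block s]
        config = [[1,16,1,1],[6,24,2,2],[6,32,3,2],[6,64,4,2],[6,96,3,1],[6,160,3,2],[6,320,1,1]]
        layers = []
        for t, c, n, s in config:
            c = make_divisible(c * width_mult)
            for i in range(n):
                layers.append(InvertedResidual(input_channel, c, s if i == 0 else 1, expand_ratio=t))
                input_channel = c
        self.features = nn.Sequential(*layers)
        # Final 1×1 conv + classification head
        self.last_conv = nn.Sequential(nn.Conv2d(input_channel, last_channel, 1, bias=False),
                                       nn.BatchNorm2d(last_channel), nn.ReLU6(inplace=True))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(last_channel, num_classes))
    
    def forward(self, x):
        x = self.avgpool(self.last_conv(self.features(self.stem(x))))
        return self.classifier(torch.flatten(x, 1))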

3. Model quantization: reduce numerical precision

Quantization converts a model's floating-point weights and activations into low-precision integers (such as Int8). Going from Float32 to Int8 shrinks weight storage by 4x, and inference is often 2 to 3x faster, especially on hardware with dedicated Int8 support.
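Under the hood, 8-bit quantization maps each float x to an integer q through a scale and a zero point: q = clamp(round(x / scale) + zero_point, qmin, qmax), and x is recovered approximately as (q - zero_point) × scale. The plain-Python sketch below uses the affine (asymmetric) unsigned 8-bit range [0, 255] commonly used for activations; real toolkits derive these parameters with observers over calibration data, so the helper names here are illustrative only.

```python
def affine_qparams(x_min: float, x_max: float, qmin: int = 0, qmax: int = 255):
    """Scale and zero point that map [x_min, x_max] onto the integer range [qmin, qmax]."""
    # Extend the range to include 0.0 so that zero is represented exactly
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:  # all-zero tensor: any positive scale works
        scale = 1.0
    zero_point = int(round(qmin - x_min / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x: float, scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> int:
    """Float -> integer code, clamped to the representable range."""
    return max(qmin, min(qmax, int(round(x / scale)) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Integer code -> approximate float value."""
    return (q - zero_point) * scale


if __name__ == "__main__":
    scale, zp = affine_qparams(-1.0, 3.0)
    for x in (-1.0, 0.0, 0.7, 3.0):
        q = quantize(x, scale, zp)
        print(x, q, dequantize(q, scale, zp))  # round-trip error is at most about scale / 2
```

Storing q in one byte instead of a 4-byte float is where the 4x size reduction comes from; the rounding error bounded by scale / 2 is the accuracy cost.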

3.1 PyTorch static quantization (most commonly used for deployment)

Static quantization follows a standard flow: train the Float32 model → fuse Conv + BN → insert observers and calibrate the quantization parameters on sample data → convert to an Int8 model.

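import torch.quantization as quant  # older alias of torch.ao.quantization; same functions

class QuantizableMobileNetV2(MobileNetV2):
    """Quantizable V2 with quant/dequant stubs.
    Assumes a float MobileNetV2 built from the InvertedResidual blocks above, with
    constructor signature (num_classes, width_mult). In eager-mode quantization,
    residual additions must use nn.quantized.FloatFunctional rather than a plain `+`."""
    def __init__(self, num_classes=1000, width_mult=1.0):
        super().__init__(num_classes, width_mult)
        self.quant = quant.QuantStub()      # quantize the float input
        self.dequant = quant.DeQuantStub()  # dequantize the output back to float
    
    def forward(self, x):
        x = self.quant(x)
        x = super().forward(x)
        return self.dequant(x)
    
    def fuse_model(self):
        """Fuse every adjacent Conv2d + BatchNorm2d pair (call in eval mode) to cut compute
        and improve quantization accuracy. fuse_modules does not support ReLU6, so ReLU6
        stays a separate module."""
        for m in list(self.modules()):
            if isinstance(m, nn.Sequential):
                pairs = [[str(i), str(i + 1)] for i in range(len(m) - 1)
                         if isinstance(m[i], nn.Conv2d) and isinstance(m[i + 1], nn.BatchNorm2d)]
                if pairs:
                    quant.fuse_modules(m, pairs, inplace=True)

# Full static quantization flow (simplified)
def static_quantize(model, calib_loader, backend='qnnpack'):
    # backend: 'qnnpack' for ARM, 'fbgemm' for x86; the Int8 model runs on this engine
    torch.backends.quantized.engine = backend
    model.eval()
    model.fuse_model()
    model.qconfig = quant.get_default_qconfig(backend)
    # Prepare: insert observers
    quant.prepare(model, inplace=True)
    # Calibrate: run a small amount of data to record activation ranges
    with torch.no_grad():
        for data, _ in calib_loader:
            model(data)
    # Convert to an Int8 model
    quant.convert(model, inplace=True)
    return model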
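To check the flow end to end without a full network, the self-contained sketch below runs it on a hypothetical toy model (TinyNet) using torch.ao.quantization, the current home of the same functions. It picks the fbgemm engine when available and falls back to qnnpack; the random calibration batches are placeholders for real data. Comparing serialized state_dict sizes shows the weight compression in practice.

```python
import copy
import io

import torch
import torch.nn as nn
import torch.ao.quantization as tq


class TinyNet(nn.Module):
    """Hypothetical toy Conv-BN-ReLU + Linear model with quant/dequant stubs."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = nn.Conv2d(64, 256, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, 10)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.pool(self.relu(self.bn(self.conv(x))))
        x = self.fc(torch.flatten(x, 1))
        return self.dequant(x)


def state_dict_bytes(model: nn.Module) -> int:
    """Serialized size of the model's state_dict, in bytes."""
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    return buf.getbuffer().nbytes


def quantize_static(float_model: nn.Module, calib_batches) -> nn.Module:
    """Fuse -> prepare -> calibrate -> convert, on a copy of `float_model`."""
    engines = torch.backends.quantized.supported_engines
    backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
    torch.backends.quantized.engine = backend
    model = copy.deepcopy(float_model).eval()
    tq.fuse_modules(model, [["conv", "bn", "relu"]], inplace=True)  # Conv+BN+ReLU -> one module
    model.qconfig = tq.get_default_qconfig(backend)
    tq.prepare(model, inplace=True)  # insert observers
    with torch.no_grad():            # calibration: record activation ranges
        for x in calib_batches:
            model(x)
    tq.convert(model, inplace=True)  # swap in Int8 modules
    return model


if __name__ == "__main__":
    torch.manual_seed(0)
    fp32 = TinyNet().eval()
    int8 = quantize_static(fp32, [torch.randn(4, 64, 16, 16) for _ in range(8)])
    x = torch.randn(2, 64, 16, 16)
    print(state_dict_bytes(fp32) / state_dict_bytes(int8))  # size ratio, close to 4 here
    print((fp32(x) - int8(x)).abs().max())                   # quantization error
```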

4. Model pruning: remove redundant connections

Pruning removes unimportant weights or channels to cut a model's computation and, given a suitable storage format or structural change, its size. PyTorch provides the built-in torch.nn.utils.prune module.

4.1 Commonly used pruning methods

  • Unstructured pruning: zeroes out individual weights with the smallest absolute value (L1 norm). Tensor shapes and the network structure stay the same, so real speedups require sparse kernels or hardware that exploits sparsity.
  • Structured pruning: removes entire channels or filters (for example, ranked by L2 norm), shrinking layer shapes so that general-purpose hardware benefits directly. Note that prune.ln_structured only zeroes the selected channels; physically removing them means rebuilding the affected layers.
import torch.nn.utils.prune as prune

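def pruning_demo(model):
    # Target layer: first conv of the first block (adjust the path to your model's layout)
    conv = model.features[0].conv[0]
    # 1. Unstructured pruning: zero the 20% of weights with the smallest absolute value (L1)
    prune.l1_unstructured(conv, name='weight', amount=0.2)
    # 2. Structured pruning: zero the 16 output channels (dim=0) with the smallest L2 norm
    prune.ln_structured(conv, name='weight', amount=16, n=2, dim=0)
    # 3. Make pruning permanent: fold the mask into `weight` and drop weight_orig / weight_mask.
    #    The tensor keeps its dense shape (pruned entries are zeros), so this alone gives no speedup.
    prune.remove(conv, 'weight')
    return model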
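torch.nn.utils.prune works by masking: even after prune.remove, the weight tensor keeps its original dense shape and the pruned entries are simply zeros. The sketch below applies the same two calls to a stand-alone toy nn.Conv2d (an illustrative layer, not part of the tutorial's models) and measures the resulting sparsity.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def sparsity(t: torch.Tensor) -> float:
    """Fraction of entries that are exactly zero."""
    return (t == 0).sum().item() / t.numel()


conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)  # weight shape (8, 3, 3, 3): 216 entries

prune.l1_unstructured(conv, name="weight", amount=0.2)  # zero the 43 smallest-magnitude weights
unstructured_sparsity = sparsity(conv.weight)

prune.ln_structured(conv, name="weight", amount=4, n=2, dim=0)  # zero 4 whole output filters
prune.remove(conv, "weight")  # bake the combined mask into `weight`

# Count output filters whose weights are all zero
zero_filters = int((conv.weight.flatten(1).abs().sum(dim=1) == 0).sum())
print(tuple(conv.weight.shape), round(unstructured_sparsity, 3),
      zero_filters, round(sparsity(conv.weight), 3))
```

The shape printed is still (8, 3, 3, 3): the layer stores and computes exactly as many weights as before pruning.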

⚠️ Pruning usually requires a short round of fine-tuning to recover accuracy. Without hardware or kernels specialized for sparsity, structured pruning is the more practical choice.
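Because torch.nn.utils.prune only zeroes channels, getting real savings from structured pruning means copying the kept filters into smaller layers. The sketch below does this for a hypothetical Conv → BN → Conv chain: it keeps the filters of the first conv with the largest L2 norm, and slices the BatchNorm channels and the next conv's input channels to match. It assumes groups == 1 (so it does not apply as-is to depthwise layers), and shrink_conv_bn_conv is an illustrative helper, not a PyTorch API. Outputs only stay identical when the removed channels contribute nothing after BatchNorm; in general a pruned channel's constant BN offset is lost, which is one reason fine-tuning follows.

```python
import torch
import torch.nn as nn


def shrink_conv_bn_conv(conv1: nn.Conv2d, bn1: nn.BatchNorm2d, conv2: nn.Conv2d, keep: int):
    """Physically keep the `keep` output filters of conv1 with the largest L2 norm,
    plus the matching BN channels and conv2 input channels. Assumes groups == 1."""
    norms = conv1.weight.detach().flatten(1).norm(p=2, dim=1)         # one norm per filter
    idx = torch.argsort(norms, descending=True)[:keep].sort().values  # kept filters, original order

    new1 = nn.Conv2d(conv1.in_channels, keep, conv1.kernel_size, conv1.stride,
                     conv1.padding, conv1.dilation, bias=conv1.bias is not None)
    new_bn = nn.BatchNorm2d(keep, eps=bn1.eps, momentum=bn1.momentum)
    new2 = nn.Conv2d(keep, conv2.out_channels, conv2.kernel_size, conv2.stride,
                     conv2.padding, conv2.dilation, bias=conv2.bias is not None)
    with torch.no_grad():
        new1.weight.copy_(conv1.weight[idx])
        if conv1.bias is not None:
            new1.bias.copy_(conv1.bias[idx])
        for name in ("weight", "bias", "running_mean", "running_var"):
            getattr(new_bn, name).copy_(getattr(bn1, name)[idx])
        new2.weight.copy_(conv2.weight[:, idx])  # slice the input channels of the next conv
        if conv2.bias is not None:
            new2.bias.copy_(conv2.bias)
    return new1, new_bn, new2


if __name__ == "__main__":
    torch.manual_seed(0)
    c1, b1, c2 = nn.Conv2d(8, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16), nn.Conv2d(16, 32, 1)
    with torch.no_grad():
        c1.weight[:6] = 0.0  # pretend structured pruning zeroed 6 filters
    before = nn.Sequential(c1, b1, c2).eval()
    after = nn.Sequential(*shrink_conv_bn_conv(c1, b1, c2, keep=10)).eval()
    x = torch.randn(2, 8, 12, 12)
    print(after[0].weight.shape, torch.allclose(before(x), after(x), atol=1e-5))
```

In this toy case the freshly initialized BatchNorm maps zeroed channels to zero, so the shrunk chain reproduces the original output with 6 fewer channels.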


5. Deployment recommendations

5.1 Combining lightweighting techniques

In practice, models are often compressed step by step in the order "architecture design → pruning → quantization → fine-tuning":

  1. Architecture selection: start from a lightweight baseline such as MobileNetV2, MobileNetV3-Large, or MobileNetV3-Small
  2. Pruning: use structured pruning (such as channel pruning) to remove redundancy
  3. Quantization: apply static quantization or quantization-aware training to further reduce numerical precision
  4. Fine-tuning: retrain briefly on a small amount of data to recover the slight accuracy loss
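Quantization-aware training (QAT) combines steps 3 and 4: fake-quantization modules are inserted during training so the network learns to tolerate Int8 rounding. Below is a minimal eager-mode sketch on a hypothetical toy model using torch.ao.quantization.prepare_qat and get_default_qat_qconfig; the random batches and three training steps are placeholders for a real fine-tuning loop, and Conv + BN fusion is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq


class ToyNet(nn.Module):
    """Hypothetical toy classifier with quant/dequant stubs at the boundaries."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.pool(self.relu(self.conv(self.quant(x))))
        return self.dequant(self.fc(torch.flatten(x, 1)))


def qat_finetune(model: nn.Module, batches, lr: float = 1e-3) -> nn.Module:
    """Prepare for QAT, fine-tune on (input, label) batches, then convert to Int8."""
    engines = torch.backends.quantized.supported_engines
    backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
    torch.backends.quantized.engine = backend
    model.train()                                   # prepare_qat expects training mode
    model.qconfig = tq.get_default_qat_qconfig(backend)
    tq.prepare_qat(model, inplace=True)             # insert fake-quantization modules
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for x, y in batches:                            # placeholder fine-tuning loop
        loss = nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    return tq.convert(model)                        # returns an Int8 copy


if __name__ == "__main__":
    torch.manual_seed(0)
    data = [(torch.randn(8, 3, 16, 16), torch.randint(0, 10, (8,))) for _ in range(3)]
    int8_model = qat_finetune(ToyNet(), data)
    print(int8_model(torch.randn(1, 3, 16, 16)).shape)
```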

5.2 Common deployment frameworks

| Framework | Supported hardware | Strengths |
| --- | --- | --- |
| PyTorch Mobile | Android / iOS / ARM Linux | Native PyTorch support, friendly API |
| TensorFlow Lite | Mobile and embedded devices across platforms | Mature ecosystem, complete toolchain |
| ONNX Runtime | Cross-hardware (CPU / GPU / NPU) | Runs models converted from multiple frameworks, fast inference |

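On the PyTorch route, a common first step is serializing the model to TorchScript so it can be loaded without the Python class definitions. The sketch below traces a toy model, saves it, and checks that the reloaded copy gives the same output; the file name and model are illustrative assumptions. For PyTorch Mobile specifically, the traced module is usually also passed through torch.utils.mobile_optimizer.optimize_for_mobile and saved with _save_for_lite_interpreter, steps omitted here.

```python
import os
import tempfile

import torch
import torch.nn as nn


def export_torchscript(model: nn.Module, example: torch.Tensor, path: str):
    """Trace `model` with an example input and save it as a TorchScript archive."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(path)
    return traced


if __name__ == "__main__":
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
    x = torch.randn(1, 3, 32, 32)
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "model_traced.pt")
        export_torchscript(net, x, path)
        loaded = torch.jit.load(path)  # no Python class definition needed
        print(torch.allclose(net(x), loaded(x)))
```

Tracing records the operations executed for the example input, so models with data-dependent control flow need torch.jit.script instead.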
The core of model lightweighting is the trade-off: first pin down what the application requires in accuracy, latency, and storage, then choose the most suitable combination of techniques. A good place to start is **MobileNetV2 + static quantization**; once you are comfortable with it, explore advanced techniques such as pruning and knowledge distillation.

Summary

Model lightweighting is a key step in deploying deep learning, moving AI from the cloud onto edge devices such as mobile phones, cameras, and drones. By mastering the MobileNet architectures, model quantization, network pruning, and related techniques, you can build inference models that are efficient, accurate, and lightweight. I hope this tutorial serves as a reliable reference on your path to lightweight models; keep weighing trade-offs and iterating in practice to find the deployment solution that best fits your business scenario.