Semantic Segmentation: Detailed explanation of pixel-level image understanding and U-Net architecture

Introduction

Semantic segmentation is one of the core tasks in computer vision. It goes beyond identifying what is in a picture: it assigns a semantic class label to every pixel, so that the outline and spatial extent of each object are captured precisely. From organ segmentation in medical imaging to road perception in autonomous driving, it is a foundational technology.



1. Basic concepts of semantic segmentation

The input to semantic segmentation is an image of height H, width W, and C channels. The output is a score map of height H, width W, and depth N, where N is the number of predefined classes. Taking the highest-scoring class at each pixel position (i, j) yields a label from the set {1, 2, ..., N} (implementations usually index classes from 0 to N-1).
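To make these shapes concrete, here is a minimal NumPy sketch. The toy 2x3 score array and the variable names are illustrative assumptions, not output from a real network. It shows one score map per class and the per-pixel argmax that turns scores into a label map, with classes indexed from 0 as is usual in code.

```python
import numpy as np

# Hypothetical network output for a 2x3 image with N = 3 classes.
# Layout is (N, H, W): one score map per class.
scores = np.array([
    [[2.0, 0.1, 0.3], [0.2, 0.1, 0.0]],   # class 0
    [[0.5, 1.5, 0.2], [0.1, 0.9, 0.1]],   # class 1
    [[0.1, 0.2, 0.9], [0.8, 0.3, 0.7]],   # class 2
])

# The predicted label of pixel (i, j) is the class with the highest score.
label_map = scores.argmax(axis=0)

print(scores.shape)     # (3, 2, 3)
print(label_map.shape)  # (2, 3)
print(label_map)
```

The label map has the same height and width as the input, with exactly one class index per pixel.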

It differs from other vision tasks as follows:

| Task type | Output form | Core objective |
| --- | --- | --- |
| Image classification | Single category label | Identify the content of the entire image |
| Object detection | [category, bounding box] list | Locate and identify objects |
| Semantic segmentation | [pixel × category] matrix | Pixel-level classification (instances of the same class are not distinguished) |
| Instance segmentation | [pixel × (category, instance ID)] matrix | Pixel-level classification that also distinguishes instances of the same class |
| Panoptic segmentation | Same as above (distinguishes "thing" and "stuff" classes) | Unified pixel-level scene understanding |
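The difference between the semantic and instance rows of the table can be illustrated with a toy example. The arrays below are hand-written assumptions (a single image row containing two cars), not the output of any model.

```python
import numpy as np

# Toy 1x6 image row: two separate cars (class 1) on a road (class 0).
semantic = np.array([[1, 1, 0, 0, 1, 1]])   # class ID per pixel
instance = np.array([[1, 1, 0, 0, 2, 2]])   # instance ID per pixel (0 = no instance)

# Semantic segmentation: both cars share the single label "car".
num_classes_present = len(np.unique(semantic))        # road + car

# Instance segmentation: each car additionally gets its own ID.
num_cars = len(np.unique(instance[semantic == 1]))

print(num_classes_present)
print(num_cars)
```

From the semantic map alone, the two cars cannot be told apart; the instance map is what makes them countable.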

1.2 Core application scenarios

Semantic segmentation is used across a wide range of applications:

  • Medical imaging: organ/tumor segmentation, pathology slide analysis
  • Autonomous driving: road/lane/obstacle segmentation
  • Remote sensing: land use classification, urban planning, environmental monitoring
  • Smart agriculture: crop and pest/disease monitoring, yield estimation
  • Robotics: scene understanding, grasp localization
  • Fashion/entertainment: clothing segmentation, virtual try-on, matting for film and television post-production

2. Classic semantic segmentation architecture

2.1 FCN: the pioneering fully convolutional network

FCN (Fully Convolutional Networks, 2015) is a milestone in semantic segmentation: it showed that a network can be trained end to end for pixel-level prediction.

Core Contribution

  1. Fully convolutional design: the fully connected layers of the classification network are replaced with convolutional layers, so the network accepts inputs of arbitrary size;
  2. Learned upsampling: transposed convolution (often called "deconvolution") gradually restores spatial resolution;
  3. Skip connections: coarse, high-level semantic predictions are fused with finer, low-level features from earlier encoder stages, which recovers detail lost during downsampling.
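The upsampling step in point 2 is easier to reason about with the output-size formula for a transposed convolution. The helper below is a small sketch (the function name is hypothetical) for dilation 1; it shows why the padding has to be chosen carefully if upsampled scores are to line up with skip features.

```python
def conv_transpose_out(size, kernel, stride, padding=0, output_padding=0):
    """Output length of a transposed convolution along one axis (dilation = 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# A 7x7 score map at 1/32 resolution (for example from a 224x224 input):
no_pad = conv_transpose_out(7, kernel=4, stride=2)                 # overshoots 2x
exact_2x = conv_transpose_out(7, kernel=4, stride=2, padding=1)    # exactly 2x
exact_8x = conv_transpose_out(28, kernel=16, stride=8, padding=4)  # exactly 8x

print(no_pad, exact_2x, exact_8x)
```

With kernel 4, stride 2, and padding 1 the map grows from 7 to 14, which matches a 1/16-resolution skip feature; without padding it grows to 16 and element-wise addition would fail.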

PyTorch implementation (FCN-8s)

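import torch
import torch.nn as nn
import torch.nn.functional as F

class FCN8s(nn.Module):
    def __init__(self, num_classes=21):
        super().__init__()
        # Use a pretrained VGG16 as the encoder (weights are downloaded on first use)
        vgg16 = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
        self.features = vgg16.features

        # Replace the fully connected classifier head with convolutions.
        # padding=3 keeps the 7x7 convolution from shrinking the feature map
        # (a simplification of the original padding-and-crop scheme).
        self.fc_conv = nn.Sequential(
            nn.Conv2d(512, 4096, 7, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(4096, 4096, 1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(4096, num_classes, 1)
        )

        # 1x1 convolutions that score the skip-connection features
        self.score_pool4 = nn.Conv2d(512, num_classes, 1)
        self.score_pool3 = nn.Conv2d(256, num_classes, 1)

        # Upsampling layers; padding is chosen so each one scales by exactly 2x or 8x
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, padding=1, bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(num_classes, num_classes, 4, 2, padding=1, bias=False)
        self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, padding=4, bias=False)

    def forward(self, x):
        # Assumes H and W are multiples of 32 so that the skip features line up
        input_size = x.shape[2:]
        # Encoder feature extraction
        pool3 = self.features[:17](x)        # pool3: 1/8 resolution
        pool4 = self.features[17:24](pool3)  # pool4: 1/16 resolution
        pool5 = self.features[24:](pool4)    # pool5: 1/32 resolution

        # Upsample the high-level scores 2x and fuse them with pool4
        score_fc = self.fc_conv(pool5)
        upscore2 = self.upscore2(score_fc)
        score_pool4 = self.score_pool4(pool4)
        fuse4 = upscore2 + score_pool4

        # Upsample 2x again (with separate weights) and fuse with pool3
        upscore4 = self.upscore_pool4(fuse4)
        score_pool3 = self.score_pool3(pool3)
        fuse3 = upscore4 + score_pool3

        # Final 8x upsampling; the interpolation is only a safeguard
        # that guarantees the output matches the input size exactly
        out = self.upscore8(fuse3)
        return F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)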

2.2 U-Net: The “gold standard” for medical segmentation

U-Net (2015) was originally designed for biomedical image segmentation. Its symmetric U-shaped structure and effective skip connections have made it one of the most widely used base architectures in segmentation.

Core Features

  1. Symmetric encoder-decoder: the encoder downsamples to extract semantics, and the decoder upsamples to restore spatial resolution;
  2. Concatenation skip connections: encoder features are concatenated with decoder features along the channel dimension, rather than added element-wise as in FCN, which retains more detail;
  3. Small-dataset friendly: it can achieve good results even with a small amount of labeled data.
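The channel bookkeeping behind the concatenation in point 2 can be traced without any deep learning library. The sketch below assumes the common scheme of 32 initial features that double at every level and four pooling steps; the helper name `unet_shapes` is hypothetical.

```python
def unet_shapes(size=256, init_features=32, depth=4):
    """Trace (channels, spatial size) through a U-Net with `depth` pooling steps."""
    assert size % (2 ** depth) == 0, "input size must be divisible by 2**depth"
    encoder = []
    ch, s = init_features, size
    for _ in range(depth):
        encoder.append((ch, s))      # encoder block output, kept for the skip connection
        ch, s = ch * 2, s // 2       # pooling halves the size; the next block doubles channels
    bottleneck = (ch, s)
    decoder = []
    for skip_ch, skip_s in reversed(encoder):
        up_ch = ch // 2              # transposed conv halves channels and doubles the size
        concat_ch = up_ch + skip_ch  # concatenation stacks channels instead of adding values
        decoder.append((concat_ch, up_ch, skip_s))  # (block input, block output, size)
        ch = up_ch
    return encoder, bottleneck, decoder

encoder, bottleneck, decoder = unet_shapes()
for concat_ch, out_ch, s in decoder:
    print(f"decoder block: {concat_ch} -> {out_ch} channels at {s}x{s}")
```

Because concatenation doubles the channel count at every decoder level, each decoder block takes twice as many input channels as it outputs. Element-wise addition, as in FCN, would leave the channel count unchanged.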

PyTorch implementation

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()
        features = init_features
        
        # Encoder (downsampling path)
        self.enc1 = self._conv_block(in_channels, features)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = self._conv_block(features, features*2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = self._conv_block(features*2, features*4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = self._conv_block(features*4, features*8)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        # Bottleneck
        self.bottleneck = self._conv_block(features*8, features*16)
        
        # Decoder (upsampling path)
        self.upconv4 = nn.ConvTranspose2d(features*16, features*8, 2, 2)
        self.dec4 = self._conv_block(features*16, features*8)
        self.upconv3 = nn.ConvTranspose2d(features*8, features*4, 2, 2)
        self.dec3 = self._conv_block(features*8, features*4)
        self.upconv2 = nn.ConvTranspose2d(features*4, features*2, 2, 2)
        self.dec2 = self._conv_block(features*4, features*2)
        self.upconv1 = nn.ConvTranspose2d(features*2, features, 2, 2)
        self.dec1 = self._conv_block(features*2, features)
        
        # Output layer
        self.outconv = nn.Conv2d(features, out_channels, 1)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))
        
        # Bottleneck
        bottleneck = self.bottleneck(self.pool4(enc4))
        
        # Decoder with skip connections (channel-wise concatenation)
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat([dec4, enc4], dim=1)
        dec4 = self.dec4(dec4)
        
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)
        dec3 = self.dec3(dec3)
        
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)
        dec2 = self.dec2(dec2)
        
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat([dec1, enc1], dim=1)
        dec1 = self.dec1(dec1)
        
        return self.outconv(dec1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

2.3 DeepLab: Atrous convolution and multi-scale modeling

The core of the DeepLab series (2016-2018) is atrous (dilated) convolution, which enlarges the receptive field without reducing feature-map resolution. The series also introduces ASPP (Atrous Spatial Pyramid Pooling) to capture multi-scale contextual information.
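The claim that atrous convolution enlarges the receptive field without shrinking the feature map can be checked with two standard size formulas. The helper names below are hypothetical, and the rates 6, 12, and 18 are used only as typical ASPP values.

```python
def effective_kernel(kernel, dilation):
    """Span covered by a dilated kernel along one axis."""
    return kernel + (kernel - 1) * (dilation - 1)

def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a convolution along one axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

spans = {}
sizes = {}
for rate in (1, 6, 12, 18):
    spans[rate] = effective_kernel(3, rate)
    # For a 3x3 kernel, padding equal to the dilation rate preserves the size
    sizes[rate] = conv_out(64, kernel=3, padding=rate, dilation=rate)
    print(f"rate={rate:2d}  span={spans[rate]}  output size={sizes[rate]}")
```

A 3x3 kernel with rate 18 spans 37 pixels while still using only nine weights, and with padding equal to the rate the 64-pixel feature map keeps its size.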

Core components: ASPP

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rates=(6, 12, 18)):
        super().__init__()
        modules = []
        # 1x1 convolution branch
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ))
        # 3x3 convolutions with different dilation rates
        for rate in atrous_rates:
            modules.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ))
        # Global average pooling branch
        modules.append(nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ))
        
        self.convs = nn.ModuleList(modules)
        # Projection that fuses the concatenated branch outputs
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs)*out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        # Upsample the global-pooling result back to the feature-map size
        res[-1] = F.interpolate(res[-1], size=res[0].shape[2:], mode='bilinear', align_corners=False)
        return self.project(torch.cat(res, dim=1))

3. Semantic segmentation loss function

Segmentation tasks often suffer from class imbalance (for example, tumor pixels make up a very small fraction of a medical image), so in addition to standard cross-entropy, the following specialized losses are common:

| Loss function | Applicable scenarios | Core idea |
| --- | --- | --- |
| Cross Entropy | Class-balanced datasets | Standard pixel-wise classification loss |
| Dice Loss | Sparse foreground / small-target segmentation | Optimizes the overlap between prediction and label |
| Focal Loss | Class imbalance with many hard examples | Down-weights easy examples and focuses on hard ones |
| Lovász Loss | Directly optimizing the IoU metric | Smooth surrogate of IoU |
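A small numeric example shows why overlap-based measures matter under class imbalance. The toy 10x10 mask and the two hand-made predictions below are illustrative assumptions; `dice_score` is a plain NumPy version of the Dice coefficient, not a library function.

```python
import numpy as np

def dice_score(pred, target, smooth=1e-6):
    """Dice coefficient for binary masks: 2|A ∩ B| / (|A| + |B|)."""
    intersection = (pred * target).sum()
    return (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)

# Toy 10x10 mask in which only 4 of 100 pixels are foreground.
target = np.zeros((10, 10))
target[4:6, 4:6] = 1

all_background = np.zeros((10, 10))   # predicts "nothing" everywhere
half_right = np.zeros((10, 10))
half_right[4:6, 4:5] = 1              # finds 2 of the 4 foreground pixels

for name, pred in [("all background", all_background), ("half right", half_right)]:
    accuracy = (pred == target).mean()
    print(f"{name}: accuracy={accuracy:.2f}, dice={dice_score(pred, target):.2f}")
```

Predicting background everywhere reaches 96% pixel accuracy with a Dice score near zero, whereas finding half of the foreground raises Dice to about 0.67. A loss built on overlap therefore cannot be satisfied by ignoring the rare class.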

Common loss code implementation

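class DiceLoss(nn.Module):
    """Soft Dice loss for binary (C == 1) or multi-class (C > 1) logits."""
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        # inputs: raw logits of shape (N, C, H, W)
        if inputs.shape[1] == 1:
            # Binary case: targets are 0/1 masks with N*H*W elements
            probs = torch.sigmoid(inputs).reshape(-1)
            targets = targets.reshape(-1).float()
            intersection = (probs * targets).sum()
            denom = probs.sum() + targets.sum()
        else:
            # Multi-class case: targets are class indices of shape (N, H, W).
            # Assumes no ignore label (such as 255) is present in targets.
            probs = torch.softmax(inputs, dim=1)
            one_hot = F.one_hot(targets.long(), inputs.shape[1]).permute(0, 3, 1, 2).float()
            dims = (0, 2, 3)
            intersection = (probs * one_hot).sum(dims)
            denom = probs.sum(dims) + one_hot.sum(dims)
        dice = (2. * intersection + self.smooth) / (denom + self.smooth)
        # Average over classes (a no-op in the binary case)
        return 1 - dice.mean()

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Per-pixel cross-entropy; pt is the predicted probability of the true class
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        return (self.alpha * (1-pt)**self.gamma * ce_loss).mean()

# Combined loss: weighted sum of cross-entropy and Dice
class CombinedLoss(nn.Module):
    def __init__(self, weight_ce=1.0, weight_dice=1.0):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()
        self.weight_ce = weight_ce
        self.weight_dice = weight_dice

    def forward(self, inputs, targets):
        # inputs: (N, C, H, W) logits with C > 1; targets: (N, H, W) class indices
        return self.weight_ce * self.ce(inputs, targets) + self.weight_dice * self.dice(inputs, targets)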
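The effect of the modulating factor (1 - pt)^gamma in the focal loss is easiest to see on single pixels. The sketch below uses only the standard library; `focal_term` is a hypothetical helper that evaluates the focal loss for one pixel whose true class receives probability pt.

```python
import math

def focal_term(pt, gamma=2.0, alpha=1.0):
    """Focal loss for one pixel whose true class has predicted probability pt."""
    ce = -math.log(pt)
    return alpha * (1.0 - pt) ** gamma * ce

for pt in (0.9, 0.5, 0.1):
    ce = -math.log(pt)
    print(f"pt={pt:.1f}  CE={ce:.3f}  focal={focal_term(pt):.4f}  scale={(1 - pt) ** 2:.2f}")
```

With gamma = 2, a well-classified pixel (pt = 0.9) keeps only 1% of its cross-entropy, while a hard pixel (pt = 0.1) keeps 81%, so hard pixels dominate the gradient.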

4. Data preprocessing and augmentation

In segmentation, every geometric transform applied to the image must also be applied identically to its mask. The Albumentations library is recommended because it supports synchronized image/mask transforms out of the box.
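The requirement that image and mask be transformed together can be sketched in plain NumPy. The function `paired_random_flip` below is a hypothetical helper, not part of Albumentations: it makes one random decision and applies it to both arrays, which is the same idea a library applies across a whole pipeline.

```python
import numpy as np

def paired_random_flip(image, mask, rng, p=0.5):
    """Horizontally flip image and mask together, using one shared random draw."""
    if rng.random() < p:
        image = image[:, ::-1].copy()
        mask = mask[:, ::-1].copy()
    return image, mask

# Toy 2x4 grayscale image whose bright pixels are exactly the foreground.
image = np.array([[0, 0, 9, 9],
                  [0, 0, 0, 9]])
mask = (image > 0).astype(np.uint8)

rng = np.random.default_rng(0)
for _ in range(5):
    image, mask = paired_random_flip(image, mask, rng)
    # Alignment survives any number of paired flips.
    assert np.array_equal(mask, (image > 0).astype(np.uint8))
```

If the image and the mask were flipped with two independent random draws, the labels would point at the wrong pixels in roughly half of the samples.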

Dedicated data augmentation strategy

import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.Resize(512, 512),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5),
    A.OneOf([A.OpticalDistortion(), A.GridDistortion()], p=0.3),
    A.OneOf([A.CLAHE(), A.RandomBrightnessContrast()], p=0.3),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

Custom data set class

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class SegDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sort for a deterministic order; masks are assumed to share the image file names
        self.imgs = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.imgs[idx])
        mask_path = os.path.join(self.mask_dir, self.imgs[idx])
        img = np.array(Image.open(img_path).convert("RGB"))
        # Keep the mask as uint8 during augmentation (assumes class IDs fit in 0-255);
        # OpenCV-based transforms generally do not accept int64 arrays
        mask = np.array(Image.open(mask_path), dtype=np.uint8)

        if self.transform:
            aug = self.transform(image=img, mask=mask)
            img, mask = aug["image"], aug["mask"]
        # Cross-entropy expects class indices of dtype long
        mask = torch.as_tensor(mask).long()
        return img, mask

5. Model training and evaluation

Core training process

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    for imgs, masks in tqdm(loader):
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def validate(model, loader, criterion, device, num_classes=21):
    model.eval()
    total_loss = 0.0
    total_miou = 0.0
    for imgs, masks in loader:
        imgs, masks = imgs.to(device), masks.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        total_loss += loss.item() * imgs.size(0)
        # Per-batch mIoU weighted by batch size (an approximation of dataset-level mIoU)
        preds = torch.argmax(outputs, dim=1)
        total_miou += compute_miou(preds, masks, num_classes=num_classes) * imgs.size(0)
    return total_loss / len(loader.dataset), total_miou / len(loader.dataset)

Core evaluation metric: mIoU

def compute_miou(preds, targets, num_classes):
    ious = []
    for cls in range(num_classes):
        pred_cls = preds == cls
        target_cls = targets == cls
        intersection = (pred_cls & target_cls).sum().item()
        union = (pred_cls | target_cls).sum().item()
        if union == 0:
            continue
        ious.append(intersection / union)
    return np.mean(ious) if ious else 0.0
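Averaging a per-batch mIoU, as the validation loop does, can differ from the mIoU of the whole dataset, because a class missing from one batch is skipped for that batch. A common alternative is to accumulate a confusion matrix over all batches and compute IoU once at the end. The sketch below shows that approach in NumPy; the function names are hypothetical and the two toy batches are made up for illustration.

```python
import numpy as np

def update_confusion(conf, preds, targets, num_classes):
    """Add one batch to a (num_classes x num_classes) confusion matrix.
    Rows index the true class, columns the predicted class."""
    preds = np.asarray(preds).ravel()
    targets = np.asarray(targets).ravel()
    valid = (targets >= 0) & (targets < num_classes)   # drops e.g. an ignore label of 255
    idx = targets[valid] * num_classes + preds[valid]
    conf += np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return conf

def miou_from_confusion(conf):
    """Mean IoU over the classes that occur in predictions or targets."""
    intersection = np.diag(conf)
    union = conf.sum(axis=0) + conf.sum(axis=1) - intersection
    present = union > 0                                 # skip classes that never occur
    return float((intersection[present] / union[present]).mean()) if present.any() else 0.0

conf = np.zeros((3, 3), dtype=np.int64)
conf = update_confusion(conf, preds=[0, 0, 1, 1], targets=[0, 1, 1, 1], num_classes=3)
conf = update_confusion(conf, preds=[2, 2, 0, 1], targets=[2, 2, 0, 2], num_classes=3)
print(conf)
print(miou_from_confusion(conf))
```

In a PyTorch validation loop, the predictions and masks would first be moved to the CPU and converted to NumPy arrays, or the same accumulation could be written with tensor operations.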

6. Recent developments and learning advice

  1. Transformer-based models: hybrid or pure Transformer architectures such as SegFormer, Swin-Unet, and TransUNet are stronger at modeling long-range dependencies;
  2. Real-time segmentation: lightweight architectures such as BiSeNet, DFANet, and Fast-SCNN trade off speed and accuracy for mobile and autonomous driving scenarios;
  3. Unified models: for example Mask2Former, which handles semantic, instance, and panoptic segmentation with a single architecture.

Semantic segmentation is an advanced deep learning vision task, so it helps to master CNN basics and image classification first. A good starting point is U-Net on a small medical or remote sensing dataset, followed by more advanced architectures such as DeepLab and SegFormer.

7. Summary

The core of semantic segmentation is pixel-level classification. The evolution of the classic architectures revolves around two goals, restoring spatial resolution and fusing multi-scale information:

  • FCN pioneered fully convolutional prediction and skip connections;
  • U-Net established the symmetric U-shaped encoder-decoder as a general-purpose baseline;
  • DeepLab introduced atrous convolution and ASPP to handle multi-scale context.
