title: Detailed explanation of Siamese Network: similarity learning and face recognition | Daoman PythonAI
description: In-depth analysis of the Siamese Network model and its applications in similarity learning, face recognition, signature verification, and more, including a detailed architecture breakdown, a PyTorch implementation, and practical application scenarios.
keywords: [Twin Network, Siamese Network, Similarity Learning, One-Shot Learning, Deep Learning, Computer Vision, PyTorch]

Detailed explanation of Siamese Network: similarity learning and face recognition

Introduction

In traditional deep learning classification, a model needs a large amount of labeled data from a fixed set of categories to converge. In practice, though, we often run into challenges like these:

  • Does a company have to retrain its face recognition model every time it hires a new employee?
  • What if a museum appraising antique calligraphy and paintings has only one authentic reference sample?
  • How does an e-commerce search find niche shoes in a similar style?

The Siamese network steps away from "direct classification" and instead learns a distance (or similarity) between samples, which makes it a natural fit for this type of problem. This article walks through the core principles, a minimal PyTorch implementation, the key supporting components, and practical techniques.


1. Quick overview of core principles

1.1 The nature of twins: two “identical” sub-networks

A twin (Siamese) network consists of two sub-networks with exactly the same structure and fully shared parameters, just like identical twins. In practice, this is a single network applied to both inputs.

1.2 Process breakdown (the diagram at a glance)

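flowchart LR
    A["Input sample X1"] --> B["Sub-network 1<br/>(feature extractor)"] --> C["Feature vector f(X1)"]
    D["Input sample X2"] --> E["Sub-network 2<br/>(= sub-network 1, shared weights)"] --> F["Feature vector f(X2)"]
    C & F --> G["Distance layer<br/>Euclidean / cosine / Manhattan"]
    G --> H["Similarity decision<br/>threshold comparison"]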

1.3 Why use shared weights?

  • Half the parameters: only one set of weights is trained, which makes training more efficient
  • Feature-space consistency: f(X1) and f(X2) live in the same coordinate system, so they are directly comparable
  • Better generalization: the two branches cannot learn different feature mappings for the same kind of input
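A minimal sketch of what weight sharing means in code, assuming a hypothetical toy linear encoder (16-dim input, 8-dim embedding) that is not part of the article: a Siamese network is one module called twice, so it has half the parameters of two independent branches, and the distance it produces is symmetric in its two inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy encoder: 16-dim input -> 8-dim embedding
encoder = nn.Linear(16, 8)

def siamese_distance(x1, x2):
    # Both inputs go through the SAME module, i.e. the same weights
    return (encoder(x1) - encoder(x2)).norm(dim=1)

def count_params(*modules):
    return sum(p.numel() for m in modules for p in m.parameters())

# Shared weights: one set of parameters (8*16 weights + 8 biases = 136)
shared_params = count_params(encoder)
# Two independent branches would need twice as many
separate_params = count_params(nn.Linear(16, 8), nn.Linear(16, 8))

x1, x2 = torch.randn(4, 16), torch.randn(4, 16)
d12 = siamese_distance(x1, x2)
d21 = siamese_distance(x2, x1)

print(shared_params, separate_params)  # 136 272
print(torch.allclose(d12, d21))        # True: d(X1, X2) == d(X2, X1)
```

With two unshared branches, swapping the inputs would generally change the distance, which is exactly the inconsistency weight sharing rules out.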

2. Minimal PyTorch implementation

We start with a basic version that runs directly on grayscale MNIST (1x28x28 images), focusing on the logic rather than an elaborate architecture.

2.1 Basic twin network

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSiamese(nn.Module):
    def __init__(self, feature_dim=128):
        super().__init__()
        # Shared ("identical twin") CNN feature extractor for 1x28x28 inputs
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),   # 28 -> 24 -> 12
            nn.Conv2d(32, 64, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),  # 12 -> 8 -> 4
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 256), nn.ReLU(),
            nn.Linear(256, feature_dim),
        )

    # Forward pass for a single sample
    def forward_one(self, x):
        # L2-normalize embeddings to unit length so distances are easier to compare
        return F.normalize(self.feature_extractor(x), p=2, dim=1)

    # Forward pass for a pair (both inputs share the same weights)
    def forward(self, x1, x2):
        return self.forward_one(x1), self.forward_one(x2)

2.2 Key supporting components

Contrastive Loss

Contrastive loss is the core loss for twin networks. It shrinks the distance between similar samples and enlarges the distance between dissimilar samples, and a dissimilar pair that is already farther apart than margin contributes no loss.

The core idea is simple:

  • When two samples are similar (label = 0), penalize their squared feature distance directly, pushing the distance toward 0.
  • When two samples are dissimilar (label = 1), penalize them only when their distance is smaller than margin. In other words, once a dissimilar pair is far enough apart (beyond margin), it incurs no further penalty, and the model can "lazily" ignore it.
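Written as a formula (the notation is ours: $d_i$ is the Euclidean distance between the two embeddings of pair $i$, $y_i$ its label with 0 = similar and 1 = dissimilar, $m$ the margin, and $N$ the number of pairs in the batch), the two cases combine into:

$$
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}\Big[(1-y_i)\,d_i^{2} + y_i\,\max\big(0,\; m - d_i\big)^{2}\Big],
\qquad d_i = \big\lVert f(X_1^{(i)}) - f(X_2^{(i)}) \big\rVert_2
$$

For a similar pair only the first term is active; for a dissimilar pair only the second, and it vanishes as soon as $d_i \ge m$.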

This part of the logic can be clearly expressed in code:

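class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        # Embeddings are L2-normalized, so Euclidean distances lie in [0, 2];
        # the margin must be below 2 for "far enough" negatives to reach zero loss
        self.margin = margin

    def forward(self, feat1, feat2, label):
        # Euclidean distance per pair, shape (B,); for unit vectors this is
        # sqrt(2 * (1 - cosine similarity))
        dist = F.pairwise_distance(feat1, feat2)
        label = label.float()
        # Similar pairs (label=0): penalize squared distance;
        # dissimilar pairs (label=1): penalize (margin - dist)^2, clamped at 0
        loss = torch.mean(
            (1 - label) * torch.pow(dist, 2) +
            label * torch.pow(torch.clamp(self.margin - dist, min=0.0), 2)
        )
        return loss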
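To see the margin behave as described, here is a small numeric check. It re-implements the same per-pair formula as the ContrastiveLoss class in a standalone function (the name `contrastive_loss` is ours), with margin = 1.0 and hand-picked, unnormalized 2-D vectors so the distances are easy to read off.

```python
import torch

def contrastive_loss(feat1, feat2, label, margin=1.0):
    # label 0 = similar pair, label 1 = dissimilar pair
    dist = (feat1 - feat2).norm(dim=1)                          # Euclidean distance per pair
    label = label.float()
    per_pair = ((1 - label) * dist.pow(2)                       # similar: penalize d^2
                + label * torch.clamp(margin - dist, min=0.0).pow(2))  # dissimilar: hinge on margin
    return per_pair.mean(), per_pair

origin = torch.zeros(2)
a = torch.tensor([0.6, 0.8])   # distance 1.0 from the origin
b = torch.tensor([0.3, 0.4])   # distance 0.5 from the origin

feat1 = torch.stack([origin, origin, origin, origin])
feat2 = torch.stack([origin, a, a, b])
label = torch.tensor([0, 0, 1, 1])

loss, per_pair = contrastive_loss(feat1, feat2, label)
# Pair 0: similar,    d = 0   -> 0
# Pair 1: similar,    d = 1.0 -> 1.0  (pulled together)
# Pair 2: dissimilar, d = 1.0 = margin -> 0  (already far enough, ignored)
# Pair 3: dissimilar, d = 0.5 -> (1.0 - 0.5)^2 = 0.25  (pushed apart)
print(per_pair, loss)
```

Pair 2 is the "lazy" case from the bullet list: it is dissimilar, but because it already sits at the margin it contributes nothing to the gradient.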

Fast similarity inference

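def is_similar(feat1, feat2, threshold=0.8):
    """
    Decide whether two feature vectors are similar.
    feat1/feat2: L2-normalized feature vectors of shape (D,) or (1, D)
    threshold: cosine-similarity threshold (cosine is more intuitive to read
               than Euclidean distance once features are L2-normalized)
    """
    cos_sim = F.cosine_similarity(feat1.flatten(), feat2.flatten(), dim=0).item()
    return cos_sim > threshold, cos_sim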
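Because the features are L2-normalized, a cosine-similarity threshold and a Euclidean-distance threshold are interchangeable: for unit vectors, $\lVert a-b\rVert^2 = 2 - 2\cos(a,b)$. The two helpers below (hypothetical names, not from any library) convert between them, which is handy when a threshold found in Euclidean distance has to be plugged into `is_similar`:

```python
import math

def cos_to_dist(cos_threshold):
    # For unit vectors: ||a - b|| = sqrt(2 * (1 - cos))
    return math.sqrt(2.0 * (1.0 - cos_threshold))

def dist_to_cos(dist_threshold):
    # Inverse mapping: cos = 1 - ||a - b||^2 / 2
    return 1.0 - dist_threshold ** 2 / 2.0

print(round(cos_to_dist(0.8), 3))  # 0.632
```

So `is_similar(..., threshold=0.8)` accepts the same pairs as requiring a Euclidean distance below about 0.632.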

3. Practical core: sample pair construction and data augmentation

3.1 Sample pair construction (key to training)

The input to a twin network is a pair of samples rather than a single sample. Positive and negative pairs should be balanced (usually 1:1):

import random

def make_pairs(dataset, num_pairs_per_class=3):
    """
    Build sample pairs from a PyTorch Dataset.
    dataset: __getitem__ must return (img, label)
    """
    pairs = []
    pair_labels = []  # 0 = similar, 1 = dissimilar

    # Group sample indices by class (this loads every sample once;
    # for large datasets, read the labels from metadata instead)
    class_to_indices = {}
    for idx, (_, label) in enumerate(dataset):
        label = int(label)  # tensor labels would hash by identity, not by value
        class_to_indices.setdefault(label, []).append(idx)
    class_list = list(class_to_indices.keys())

    # Build pairs
    for c in class_list:
        # Positive pairs: two samples from the same class
        indices_c = class_to_indices[c]
        for _ in range(num_pairs_per_class):
            if len(indices_c) >= 2:
                i, j = random.sample(indices_c, 2)
                pairs.append([dataset[i][0], dataset[j][0]])
                pair_labels.append(0)
        # Negative pairs: samples from two different classes
        other_classes = [c_other for c_other in class_list if c_other != c]
        for _ in range(num_pairs_per_class):
            i = random.choice(indices_c)
            c_other = random.choice(other_classes)
            j = random.choice(class_to_indices[c_other])
            pairs.append([dataset[i][0], dataset[j][0]])
            pair_labels.append(1)

    return pairs, pair_labels
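`make_pairs` returns plain Python lists, while `find_best_threshold` in section 4.1 iterates over batches of `(x1, x2, label)`. One way to bridge the two is a thin `Dataset` wrapper. The class name `PairDataset` is ours, and the sketch assumes both images in every pair are already tensors of the same shape (toy random tensors stand in for real `make_pairs` output here).

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    """Hypothetical wrapper: serves (x1, x2, label) triples to a DataLoader."""

    def __init__(self, pairs, pair_labels):
        assert len(pairs) == len(pair_labels)
        self.pairs = pairs
        self.labels = pair_labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        x1, x2 = self.pairs[idx]
        # Float label (0 = similar, 1 = dissimilar) so it can go straight into the loss
        return x1, x2, torch.tensor(self.labels[idx], dtype=torch.float32)

# Toy data standing in for make_pairs() output: four pairs of 1x28x28 "images"
pairs = [[torch.rand(1, 28, 28), torch.rand(1, 28, 28)] for _ in range(4)]
pair_labels = [0, 1, 0, 1]

loader = DataLoader(PairDataset(pairs, pair_labels), batch_size=2, shuffle=True)
x1, x2, label = next(iter(loader))
print(x1.shape, x2.shape, label.shape)
# torch.Size([2, 1, 28, 28]) torch.Size([2, 1, 28, 28]) torch.Size([2])
```

The default collate function stacks each field separately, so a batch arrives as two image tensors plus one label vector.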

3.2 Notes on data augmentation

  • You can apply different, mild augmentations (such as random brightness changes) to the two images of a positive pair
  • Avoid strong augmentations that change what the sample means (for example, rotating handwritten digits by more than 90 degrees can make a 6 look like a 9)
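As a sketch of the first point, each image in a positive pair can receive its own independent, mild perturbation. The helper below (`mild_augment` and `augment_pair` are hypothetical names) uses plain tensor ops only: a random brightness factor in an assumed range of 0.8 to 1.2, plus a little Gaussian noise, clamped back to [0, 1]. In a real pipeline, torchvision's `ColorJitter` and a small-angle `RandomRotation` are the usual building blocks.

```python
import torch

def mild_augment(img, brightness=0.2, noise_std=0.02, generator=None):
    """Hypothetical mild augmentation for images with values in [0, 1]."""
    # Random brightness factor in [1 - brightness, 1 + brightness]
    factor = 1.0 + (torch.rand(1, generator=generator).item() * 2 - 1) * brightness
    # Small additive Gaussian noise
    noise = torch.randn(img.shape, generator=generator) * noise_std
    return (img * factor + noise).clamp(0.0, 1.0)

def augment_pair(x1, x2, generator=None):
    # Each image gets its OWN random draw, so the two views differ slightly
    # while still showing the same underlying content
    return mild_augment(x1, generator=generator), mild_augment(x2, generator=generator)

g = torch.Generator().manual_seed(0)
img = torch.rand(1, 28, 28)
a1, a2 = augment_pair(img, img.clone(), generator=g)
print(a1.shape, bool((a1 - a2).abs().max() > 0))  # torch.Size([1, 28, 28]) True
```

Keeping the perturbation this small preserves the identity of the sample, which is what lets the pair keep its "similar" label.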

4. Common practical problems and optimization

4.1 How to choose the threshold?

Don't pick a threshold off the top of your head! Use the ROC curve on the validation set to find the best one:

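import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_curve

def find_best_threshold(model, val_loader, device):
    model.eval()
    all_dists = []
    all_labels = []

    with torch.no_grad():
        for x1, x2, label in val_loader:
            x1, x2 = x1.to(device), x2.to(device)
            feat1, feat2 = model(x1, x2)
            dist = F.pairwise_distance(feat1, feat2).cpu().numpy()
            all_dists.extend(dist)
            all_labels.extend(label.cpu().numpy())

    # Label 1 = dissimilar and a larger distance means more dissimilar, so distance is the score.
    # Pick the Euclidean threshold that maximizes Youden's J statistic (TPR - FPR);
    # pairs with distance >= threshold are then predicted dissimilar
    fpr, tpr, thresholds = roc_curve(all_labels, all_dists, pos_label=1)
    best_idx = np.argmax(tpr - fpr)
    return thresholds[best_idx]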

4.2 What if inference is too slow?

  • Precompute gallery features: store the features of registered faces/signatures/products in a database or cache instead of re-extracting them for every query
  • Model quantization: use torch.quantization (torch.ao.quantization in recent PyTorch releases) to convert the model from FP32 to INT8, which can make CPU inference roughly 3-4 times faster (actual gains depend on the model and hardware)
  • ONNX/TensorRT export: run the model on a dedicated inference engine when deploying to production
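A minimal sketch of the first point, assuming the gallery features were extracted once with the trained model and L2-normalized: the gallery is stored as an `(N, D)` matrix, and each query costs a single matrix-vector product, because for unit vectors the dot product equals the cosine similarity. The function name `search_gallery`, the 0.8 threshold, and the random toy gallery are illustrative.

```python
import torch
import torch.nn.functional as F

def search_gallery(query_feat, gallery_feats, threshold=0.8):
    """
    query_feat:    (D,)   L2-normalized query embedding
    gallery_feats: (N, D) L2-normalized embeddings, precomputed and cached
    Returns (best_index, best_score), or (None, best_score) if nothing passes the threshold.
    """
    scores = gallery_feats @ query_feat          # (N,) cosine similarities
    best_score, best_idx = scores.max(dim=0)
    if best_score.item() < threshold:
        return None, best_score.item()
    return best_idx.item(), best_score.item()

torch.manual_seed(0)
# Toy "registered" gallery: 5 identities with 128-dim embeddings, computed once
gallery = F.normalize(torch.randn(5, 128), dim=1)

# A query that is a slightly perturbed copy of identity 3
query = F.normalize(gallery[3] + 0.02 * torch.randn(128), dim=0)
print(search_gallery(query, gallery))  # matches index 3 with a score close to 1
```

Only the query needs a forward pass at lookup time; adding a new person means appending one row to the cached matrix.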

5. Typical application scenarios

Twin networks are built for "few samples + similarity judgment". Typical scenarios include:

  1. Face recognition/attendance: a new employee only needs to provide 1-3 photos, and no retraining is required
  2. Signature/fingerprint verification: authentic samples are scarce, and verification only compares similarity
  3. E-commerce same-style/similar-style search: match a user-uploaded image against products with similar styles in the catalog
  4. Defect detection: only a small number of normal samples are available, and a new sample is judged by its distance to those normal samples
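For defect detection, the distance comparison can be sketched as a nearest-neighbour check: embed the normal samples once, then flag a new sample whose distance to the closest normal embedding exceeds a threshold. The function name `is_defect`, the 0.5 threshold, and the random toy embeddings are assumptions for illustration; in practice the threshold would be tuned on validation data, as in section 4.1.

```python
import torch
import torch.nn.functional as F

def is_defect(feat, normal_feats, dist_threshold=0.5):
    """
    feat:         (D,)   embedding of the sample under inspection
    normal_feats: (M, D) embeddings of known-good samples
    Flags a defect when even the nearest normal sample is too far away.
    """
    dists = (normal_feats - feat).norm(dim=1)   # (M,) Euclidean distances
    min_dist = dists.min().item()
    return min_dist > dist_threshold, min_dist

torch.manual_seed(0)
normal_feats = F.normalize(torch.randn(10, 64), dim=1)

ok_sample = F.normalize(normal_feats[0] + 0.02 * torch.randn(64), dim=0)  # close to a normal sample
odd_sample = F.normalize(torch.randn(64), dim=0)                           # unrelated embedding

print(is_defect(ok_sample, normal_feats)[0], is_defect(odd_sample, normal_feats)[0])  # False True
```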

6. Summary

The twin network is a simple but powerful architecture for similarity learning that addresses the pain points of traditional classification in "few samples, new categories" scenarios.

Key takeaways:

  1. Two identical sub-networks: shared parameters give a consistent feature space
  2. Contrastive loss: shrinks the distance between similar pairs and widens the distance between dissimilar pairs
  3. Pair-based training: balancing positive and negative pairs is the most important detail of all

For higher accuracy, the natural next steps are Triplet Loss, FaceNet, or Transformer-based similarity models.
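As a pointer for that next step: PyTorch ships a built-in `nn.TripletMarginLoss`, which replaces pairs with (anchor, positive, negative) triplets and penalizes a triplet unless the anchor is at least `margin` closer to the positive than to the negative. A minimal sketch with hand-picked 2-D embeddings (the values and margin = 1.0 are illustrative):

```python
import torch
import torch.nn as nn

# Per triplet: max(d(a, p) - d(a, n) + margin, 0); reduction="none" keeps each value
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2.0, reduction="none")

anchor   = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
positive = torch.tensor([[0.3, 0.4], [0.3, 0.4]])   # d(anchor, positive) = 0.5
negative = torch.tensor([[3.0, 4.0], [0.6, 0.8]])   # d(anchor, negative) = 5.0 and 1.0

losses = triplet_loss(anchor, positive, negative)
# First triplet: 0.5 - 5.0 + 1.0 < 0 -> 0 (negative already far enough)
# Second triplet: 0.5 - 1.0 + 1.0 = 0.5 (negative still too close)
print(losses)
```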