Siamese-triplet实战:FashionMNIST和MNIST数据集上的性能对比

【免费下载链接】siamese-triplet Siamese and triplet networks with online pair/triplet mining in PyTorch 【免费下载链接】siamese-triplet 项目地址: https://gitcode.com/gh_mirrors/si/siamese-triplet

Siamese网络和Triplet网络是深度学习中用于学习特征嵌入的强大技术,它们能够将图像映射到紧凑的欧几里得空间中,使得相似样本的距离更近,不相似样本的距离更远。本文将通过实战分析,对比这两种网络在FashionMNIST和MNIST数据集上的性能表现,揭示在线负样本挖掘技术的优势。

为什么选择Siamese和Triplet网络?

在传统的分类任务中,我们通常使用softmax交叉熵损失函数训练网络,但这种方法学习到的特征嵌入可能不具备良好的度量特性。Siamese网络和Triplet网络通过对比学习的方式,直接优化特征空间中的距离关系,从而学习到更具判别力的特征表示。

核心概念解析

Siamese网络:接受一对图像作为输入,训练目标是使同类样本的距离最小化,不同类样本的距离大于某个边界值。它使用对比损失函数(Contrastive Loss)来优化。

Triplet网络:接受三元组(锚点、正样本、负样本)作为输入,目标是让锚点与正样本的距离小于锚点与负样本的距离加上一个边界值。它使用三元组损失函数(Triplet Loss)来优化。

实验设置与数据集对比

MNIST数据集

MNIST是一个相对简单的数据集,包含0-9的手写数字图像,共10个类别。虽然它常用于基准测试,但由于类别间差异明显,网络相对容易学习。

FashionMNIST数据集

FashionMNIST是Zalando推出的服装图像数据集,同样包含10个类别,但类别间差异更细微,复杂度更高。这使得它成为评估度量学习方法的更好测试平台。

两个实验都使用相同的网络架构:32通道5x5卷积 → PReLU → 2x2最大池化 → 64通道5x5卷积 → PReLU → 2x2最大池化 → 256全连接 → PReLU → 256全连接 → PReLU → 2维输出。

性能对比分析

基础分类方法(Softmax)

首先,我们使用传统的softmax分类作为基线。虽然MNIST上能达到99%的准确率,但学习到的2维嵌入在特征空间中缺乏良好的度量特性。

MNIST Softmax训练集特征分布 MNIST数据集上Softmax分类器的训练集特征分布

MNIST Softmax测试集特征分布 MNIST数据集上Softmax分类器的测试集特征分布

Siamese网络 vs 在线对比损失

传统Siamese网络:随机选择正负样本对进行训练。在MNIST数据集上,这种方法的嵌入已经相当不错。

MNIST Siamese网络训练集特征分布 MNIST数据集上Siamese网络的训练集特征分布

在线对比损失(Online Contrastive Loss):在mini-batch内进行负样本挖掘,更高效地利用计算资源。

MNIST在线对比损失测试集特征分布 MNIST数据集上在线对比损失的测试集特征分布

在FashionMNIST数据集上,在线负样本挖掘的优势更加明显:

FashionMNIST Siamese网络特征分布 FashionMNIST数据集上传统Siamese网络的特征分布

FashionMNIST在线对比损失特征分布 FashionMNIST数据集上在线对比损失的特征分布

Triplet网络 vs 在线三元组损失

传统Triplet网络:随机选择三元组进行训练。学习到的嵌入在同类样本中不如Siamese网络紧密,但优化目标不同。

在线三元组损失(Online Triplet Loss):在mini-batch内选择困难三元组,显著提高了训练效率。

在FashionMNIST数据集上的对比更加明显:

FashionMNIST Triplet网络特征分布 FashionMNIST数据集上传统Triplet网络的特征分布

FashionMNIST在线三元组损失特征分布 FashionMNIST数据集上在线三元组损失的特征分布

关键技术优势:在线负样本挖掘

在线负样本挖掘技术解决了传统Siamese和Triplet网络的几个关键问题:

  1. 计算效率:可能的样本对/三元组数量随样本数呈二次/三次增长,在线挖掘避免了处理所有组合
  2. 样本质量:随着训练进行,越来越多的样本对变得"简单",在线挖掘能持续提供困难样本
  3. 计算复用:每个图像嵌入可以复用于多个样本对/三元组计算

实战代码结构

项目的核心代码模块包括:

  • datasets.py - 包含SiameseMNIST和TripletMNIST数据集包装器
  • networks.py - 实现嵌入网络、分类网络、Siamese网络和Triplet网络
  • losses.py - 提供对比损失、三元组损失及其在线版本
  • trainer.py - 统一的训练函数fit
  • metrics.py - 评估指标
  • utils.py - 样本对和三元组选择器

实验结果总结

通过对MNIST和FashionMNIST数据集的对比实验,我们发现:

  1. 数据集复杂度影响:在简单的MNIST数据集上,各种方法差异不大;但在更复杂的FashionMNIST数据集上,在线负样本挖掘的优势明显
  2. 嵌入质量:Siamese网络学习到的同类嵌入更紧密,Triplet网络更注重相对距离关系
  3. 训练效率:在线负样本挖掘显著提高了训练效率,特别是在类别数较多的复杂数据集上
  4. 泛化能力:在线方法在测试集上表现出更好的泛化性能

应用建议

对于实际应用,我们建议:

  1. 简单数据集:可以使用传统Siamese或Triplet网络
  2. 复杂数据集:强烈推荐使用在线负样本挖掘技术
  3. 类别数多:在线方法的优势更加明显
  4. 计算资源有限:在线方法能更高效地利用计算资源

未来展望

该项目仍在积极开发中,未来的改进方向包括:

  • 优化三元组选择策略
  • 开发可比较的评估指标
  • 在更复杂的数据集上测试
  • 支持少样本学习场景

通过本文的实战分析,我们可以看到Siamese-triplet网络在度量学习中的强大能力,特别是在线负样本挖掘技术对复杂数据集处理的重要性。无论是学术研究还是工业应用,这些技术都为特征学习和相似性度量提供了有力的工具。

【免费下载链接】siamese-triplet Siamese and triplet networks with online pair/triplet mining in PyTorch 【免费下载链接】siamese-triplet 项目地址: https://gitcode.com/gh_mirrors/si/siamese-triplet

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐