一、前言

        在深度学习领域,数据是一项宝贵的资源。训练一个合格的模型往往需要大量的图片、文本数据。

        用全存在标签的数据来训练模型,叫做监督学习,而对于全是无标签的数据来学习,叫无监督学习。现实中,往往没有那么多的数据被打上既定的标签,而对于其中混合的大量无标签数据,我们也不能浪费,于是,半监督学习应运而生。

        半监督学习介于监督学习和无监督学习之间,本文采用的是半监督学习中最经典的伪标签策略,来加大数据的利用率,逻辑如下:

        ①通过模型来生成无标签数据的预测值

        ②筛选出其中置信度高于阈值的(如0.99),将其设置为伪标签

        ③把这些带着伪标签的数据和原有的标签数据进行混合训练

        需要注意的是,置信度阈值不可以太低,否则反而会生成很多垃圾数据,污染数据库,干扰模型训练。


二、实现步骤

        下面我将分模块来展示代码,并对其进行解释

2.1固定随机种子

        这一步的目的是保证实验可以复现,当我的模型的效果很好时,我可以通过这个固定的随机种子复现。

import random
import torch
import torch.nn as nn
import numpy as np
import os

def seed_everything(seed):
    # 固定CPU/GPU随机种子
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # 固定NumPy随机种子
    random.seed(seed)
    np.random.seed(seed)
    # 固定哈希值,避免字典、集合等哈希相关操作的随机性
    os.environ['PYTHONHASHSEED'] = str(seed)

# 固定种子为1
seed_everything(1)

2.2数据预处理

        这一步,主要是对数据集做数据增广,提升模型的泛化能力。现实生活中,真实提供的数据、图片往往会有多种形式,这一步可以将我们的数据集也经过多种形式的变换,以提升模型应对真实场景的多样性的表现。

# 图片尺寸
HW = 224  

# 训练集数据增广:
train_transform = transforms.Compose(
    [
        transforms.ToPILImage(), 
        # 随机裁剪+缩放
        transforms.RandomResizedCrop(224),  
        # 随机旋转±50°
        transforms.RandomRotation(50),      
        transforms.ToTensor()               
    ]
)

# 验证集基础变换:仅做格式转换+转张量
val_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.ToTensor()
    ]
)

        为什么验证集不做数据增广?验证集的目的是检验模型的能力

        这里的验证集其实就是对应在真实条件下的各种场景,我们在训练时,需要提升泛化能力,验证时,用原始数据进行检验即可。

2.3 定义数据集类

        数据集需要适配以下不同模式:

        有标签还是无标签?

        训练集还是验证集?(对应的数据增广不同)

from torch.utils.data import Dataset

class food_Dataset(Dataset):
    def __init__(self, path, mode="train"):
        self.mode = mode
        # 无标签数据(半监督模式)
        if mode == "semi":
            self.X = self.read_file(path)
        # 有标签数据(训练/验证模式)
        else:
            # # 有标签模式下,读取图片self.X和标签self.Y
            self.X, self.Y = self.read_file(path)
            self.Y = torch.LongTensor(self.Y)  
        # 选择对应的预处理变换
        self.transform = train_transform if mode == "train" else val_transform
    
    #读取数据
    def read_file(self, path):
        if self.mode == "semi":
            file_list = os.listdir(path)
            # uint8对应0-255像素
            xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
            for j, img_name in enumerate(file_list):
                # 合并路径
                img_path = os.path.join(path, img_name)
                # 打开并统一缩放图像尺寸
                img = Image.open(img_path).resize((HW, HW))
                xi[j, ...] = img
            print(f"读到了{len(xi)}个无标签数据")
            return xi
        else:
            # 读取有标签数据:按类别分文件夹
            for i in tqdm(range(11)):
                # 保证文件夹名是两位数(00/01/.../10)
                file_dir = path + "/%02d" % i
                file_list = os.listdir(file_dir)
                xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
                yi = np.zeros(len(file_list), dtype=np.uint8)
                for j, img_name in enumerate(file_list):
                    img_path = os.path.join(file_dir, img_name)
                    img = Image.open(img_path).resize((HW, HW))
                    xi[j, ...] = img
                     # 标签为当前类别
                    yi[j] = i 
                # 拼接所有类别的数据
                if i == 0:
                    X = xi
                    Y = yi
                else:
                    X = np.concatenate((X, xi), axis=0)
                    Y = np.concatenate((Y, yi), axis=0)
            print(f"读到了{len(Y)}个有标签数据")
            return X, Y
  
  # 取数据
    def __getitem__(self, item):
        if self.mode == "semi":
            # 半监督模式:预处理后的图像+原始图像(用于生成伪标签)
            return self.transform(self.X[item]), self.X[item]
        else:
            # 有标签模式:预处理后的图像+ 标签
            return self.transform(self.X[item]), self.Y[item]

 # 返回数据集总样本数
    def __len__(self):
        return len(self.X)

         注意,需要把标签数据转为LongTensor,这是交叉熵损失的输入要求。

2.4生成伪标签

        这个半监督学习项目的核心,用训练后的模型对无标签数据进行预测,筛选出其中置信度大于0.99的样本作为伪标签数据,并加以利用。

# 半监督数据加载器
def get_semi_loader(no_label_loder, model, device, thres):
    semiset = semiDataset(no_label_loder, model, device, thres)
    # 无符合条件则返回None,不更新伪标签数据
    return DataLoader(semiset, batch_size=16, shuffle=False) if semiset.flag else None

# 半监督数据集——打标签
class semiDataset(Dataset):
    def __init__(self, no_label_loder, model, device, thres=0.99):
        # 生成伪标签
        x, y = self.get_label(no_label_loder, model, device, thres)
        # 标记是否存在有效伪标签数据
        self.flag = False if x == [] else True  
        if self.flag:
            self.X = np.array(x)
            self.Y = torch.LongTensor(y)
            self.transform = train_transform

    def get_label(self, no_label_loder, model, device, thres):
        model.eval()  
        pred_prob, labels = [], []
        # 将模型输出转为概率分布
        soft = nn.Softmax(dim=1)  
        with torch.no_grad(): 
            for bat_x, _ in no_label_loder:
                bat_x = bat_x.to(device)
                pred = model(bat_x)
                pred_soft = soft(pred)
                # 获取每个样本的最大置信度和对应标签
                pred_max, pred_label = pred_soft.max(1)
                pred_prob.extend(pred_max.cpu().numpy().tolist())
                labels.extend(pred_label.cpu().numpy().tolist())
    
        # 筛选置信度>阈值的样本
        x, y = [], []
        for idx, prob in enumerate(pred_prob):
            if prob > thres:
                x.append(no_label_loder.dataset[idx][1])  # 原始图像
                y.append(labels[idx])
        return x, y

    def __getitem__(self, item):
        return self.transform(self.X[item]), self.Y[item]

    def __len__(self):
        return len(self.X)

2.5定义训练神经网络

目的:将3*224*224  ——> 512*7*7——>拉直展平(25088)——>全连接分类——>得到11类

结构遵循经典设计:卷积→BN→激活→池化


class myModel(nn.Module):
    def __init__(self, num_class):
        super(myModel, self).__init__()
        # 第一层卷积:3通道→64通道,尺寸不变
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        # 批量归一化:加速训练+提升稳定性
        self.bn1 = nn.BatchNorm2d(64)  
        self.relu1 = nn.ReLU()
        # 下采样:尺寸减半
        self.pool1 = nn.MaxPool2d(2)  

        # 模块化卷积层:卷积→BN→激活→池化
        self.layer1 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.pool2 = nn.MaxPool2d(2)
        # 全连接层:展平特征→分类
        self.fc1 = nn.Linear(25088, 1000)  
        self.relu2 = nn.ReLU()
         # 分类头:1000维→11类
        self.fc2 = nn.Linear(1000, num_class) 

    def forward(self, x):
        # 前向传播流程
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool2(x)
         # 展平:(batch, 512,7,7)→(batch, 25088)
        x = x.view(x.size()[0], -1) 
        x = self.fc1(x)
        x = self.relu2(x)
        x = self.fc2(x)
        return x

全连接层将展平后的特征映射到 11 个类别,完成分类任务

2.6 训练验证函数

        整个项目的训练核心逻辑,融合了监督训练和半监督训练的逻辑

        伪标签更新核心逻辑:每过三个轮次,且验证准确率大于0.6时,更新伪标签数据,防止模型还没有一个良好的效果时,就用大量伪标签数据来污染训练。同时也是给设备减负,每轮都来时间太久了。

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

def train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path):
    model = model.to(device) 
    # 初始化半监督数据加载器
    semi_loader = None  
    # 记录训练/验证的损失和准确率
    plt_train_loss, plt_val_loss = [], []
    plt_train_acc, plt_val_acc = [], []
    # 记录最高验证准确率
    max_acc = 0.0 


    for epoch in range(epochs):
        # 初始化本轮损失/准确率
        train_loss = val_loss = 0.0
        train_acc = val_acc = 0.0
        semi_loss = semi_acc = 0.0
        start_time = time.time()

        # 监督训练
        model.train()  
        for batch_x, batch_y in train_loader:
            x, target = batch_x.to(device), batch_y.to(device)
            pred = model(x)
            loss_batch = loss(pred, target)
            # 反向传播+参数更新
            loss_batch.backward()
            optimizer.step()
             # 梯度清零,避免累积
            optimizer.zero_grad() 
            # 累计损失和准确率
            train_loss += loss_batch.cpu().item()
            # 计算正确预测数:argmax取预测类别,与真实标签对比
            train_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())

        # 半监督训练
        if semi_loader is not None:
            for batch_x, batch_y in semi_loader:
                x, target = batch_x.to(device), batch_y.to(device)
                pred = model(x)
                loss_batch = loss(pred, target)
                loss_batch.backward()
                optimizer.step()
                optimizer.zero_grad()
                semi_loss += loss_batch.cpu().item()
                semi_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())
            print(f"半监督训练准确率:{semi_acc/len(semi_loader.dataset):.4f}")

        # 验证集评估
        model.eval() 
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                x, target = batch_x.to(device), batch_y.to(device)
                pred = model(x)
                loss_batch = loss(pred, target)
                val_loss += loss_batch.cpu().item()
                val_acc += np.sum(np.argmax(pred.detach().cpu().numpy(), axis=1) == target.cpu().numpy())

        # 计算并记录本轮指标
        plt_train_loss.append(train_loss / len(train_loader))
        plt_train_acc.append(train_acc / len(train_loader.dataset))
        plt_val_loss.append(val_loss / len(val_loader.dataset))
        plt_val_acc.append(val_acc / len(val_loader.dataset))

        #  条件更新伪标签:    
        # 每3个epoch+验证acc>0.6
        if epoch % 3 == 0 and plt_val_acc[-1] > 0.6:
            semi_loader = get_semi_loader(no_label_loader, model, device, thres)

        # 保存最优模型
        if val_acc > max_acc:
            torch.save(model, save_path)
            max_acc = val_acc

        # 打印本轮训练结果
        print(f'[{epoch+1}/{epochs}] {time.time()-start_time:.2f}s | '
              f'TrainLoss: {plt_train_loss[-1]:.6f} | ValLoss: {plt_val_loss[-1]:.6f} | '
              f'TrainAcc: {plt_train_acc[-1]:.6f} | ValAcc: {plt_val_acc[-1]:.6f}')

    # 可视化
    plt.plot(plt_train_loss, label='Train Loss')
    plt.plot(plt_val_loss, label='Val Loss')
    plt.title('Loss Curve')
    plt.legend()
    plt.show()

    plt.plot(plt_train_acc, label='Train Acc')
    plt.plot(plt_val_acc, label='Val Acc')
    plt.title('Accuracy Curve')
    plt.legend()
    plt.show()

2.7 参数配置与训练启动!!!

# 数据集路径
train_path = "路径"
val_path = "路径"
no_label_path = "路径"

#  加载数据集
train_set = food_Dataset(train_path, "train")
val_set = food_Dataset(val_path, "val")
no_label_set = food_Dataset(no_label_path, "semi")

# 封装为DataLoader,训练集打乱,验证、无标签集不打乱
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)
no_label_loader = DataLoader(no_label_set, batch_size=16, shuffle=False)

# 超参数与模型配置
model = myModel()  
lr = 0.001  
# 交叉熵损失函数
loss_fn = nn.CrossEntropyLoss() 
# AdamW优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
# 有GPU就用GPU,大大提升性能
device = "cuda" if torch.cuda.is_available() else "cpu"

# 训练配置 
save_path = "模型保存路径"  
epochs = 15  
thres = 0.99 

# 训练!!!启动!!!!
train_val(model, train_loader, val_loader, no_label_loader,
          device, epochs, optimizer, loss_fn, thres, save_path)

三、再次梳理半监督训练全流程

        ① 初始化模型         →         加载有标签数据和无标签数据

        ② 监督训练            →         用有标签数据训练模型

        ③ 验证集评估         →        if epoch % 3 == 0 and plt_val_acc[-1] > 0.6

        若否,则不更新伪标签数据,继续使用存在的伪标签数据(如果之前生成过的话),进入⑤

        如是,则进入④

        ④ 生成为标签数据,筛选出置信度大于0.99的无标签样本,加入伪标签数据集

        ⑤ 训练模型,保存最优模型,循环轮次

四、总结

        我在这个项目中学到了很多,最重要的就是半监督学习这一概念。同时还有数据预处理的数据增广、以及学到了什么是迁移学习。

        我会不断深入学习的

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐