从零开始搭建、训练并保存你的 CIFAR10 分类模型
本文详细介绍了使用PyTorch实现CIFAR10图像分类的完整训练流程。主要内容包括:1)标准九步训练法,涵盖数据准备、模型构建到模型保存全过程;2)关键步骤详解,如数据加载、网络结构设计(封装在model.py)、损失函数和优化器选择;3)实战技巧,包括训练/测试模式切换、梯度清零三部曲、TensorBoard可视化监控等。文章特别强调.item()转换、准确率计算和模型检查点保存等实用细节,
本文手把手教你如何编写一个完整的 PyTorch 训练脚本。通过 CIFAR10 图像分类任务,系统演示了从数据集加载、网络结构设计到训练/测试闭环的每一个核心步骤。文章重点分享了训练三步曲、模型状态切换以及利用 TensorBoard 实时监控训练性能的实战技巧,是一篇不可多得的深度学习训练模板指南。
1. 深度学习训练的“标准九步走”
很多初学者面对训练代码会感到头大,其实只要按照这个逻辑拆解,就会非常清晰:
-
准备数据:下载并定义
Dataset。 -
加载数据:用
DataLoader打包。 -
搭建模型:继承
nn.Module(建议封装在model.py中)。 -
损失函数:定义评价标准(如
CrossEntropyLoss)。 -
优化器:选择更新权重的策略(如
SGD或Adam)。 -
参数设置:确定
Epoch(轮数)和Learning Rate(学习率)。 -
训练循环:前向传播》》计算 Loss 》》梯度清零》》反向传播 》》权重更新。
-
验证/测试:在不更新权重的条件下,评估模型性能。
-
保存模型:持久化存储,以便后续使用。
2.步骤详解
2.1准备数据集(Dataset)
CIFAR10 简介
-
图片大小:
3 × 32 × 32 -
训练集:50000 张
-
测试集:10000 张
-
分类数:10 类
ToTensor() 会完成两件事:
-
HWC → CHW -
像素归一化到
[0, 1]
train_data = torchvision.datasets.CIFAR10(
root="./CIFAR10",
train=True,
download=True,
transform=transforms.ToTensor()
)
test_data = torchvision.datasets.CIFAR10(
root="./CIFAR10",
train=False,
download=True,
transform=transforms.ToTensor()
)
2.2查看数据集规模并加载数据集
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练集数据长度:{}".format(train_data_size))
print("测试集数据长度: {}".format(test_data_size))
输出:
训练集数据长度:50000 测试集数据长度: 10000
train_dataloader = torch.utils.data.DataLoader(
train_data,
batch_size=64,
shuffle=True,
drop_last=True
)
test_dataloader = torch.utils.data.DataLoader(
test_data,
batch_size=64,
shuffle=True,
drop_last=True
)
参数说明
-
batch_size=64:一次训练 64 张图片 -
shuffle=True:打乱数据,防止模型记忆顺序 -
drop_last=True:丢弃不足一个 batch 的数据
2.3搭建神经网络(model.py)
cifar10_net = CIFAR10_Net()
网络结构单独写在
model.py中#搭建神经网络 class CIFAR10_Net(nn.Module): def __init__(self): super(CIFAR10_Net, self).__init__() self.model=nn.Sequential( nn.Conv2d(3, 32, 5, 1, 2), nn.MaxPool2d(2, 2), nn.Conv2d(32, 32, 5, 1, 2), nn.MaxPool2d(2, 2), nn.Conv2d(32, 64, 5, 1, 2), nn.MaxPool2d(2, 2), Flatten(), nn.Linear(64*4*4, 64), nn.Linear(64, 10) ) def forward(self, x): output=self.model(x) return output if __name__ == '__main__': cifar10_net=CIFAR10_Net() input=torch.ones((64,3,32,32)) output=cifar10_net(input) print(output.shape)
这样做的好处:
-
训练逻辑和模型结构解耦
-
便于复用与调试
-
更接近真实项目代码结构
2.4定义损失函数
loss_fn = nn.CrossEntropyLoss()
为什么用 CrossEntropyLoss?
-
适用于 多分类任务
-
内部自动完成:
-
Softmax -
Log -
NLLLoss
-
⚠️ 注意:
模型最后一层不要手动加 Softmax
2.5定义优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(
cifar10_net.parameters(),
lr=learning_rate,
momentum=0.9
)
SGD + Momentum
-
SGD:随机梯度下降 -
momentum=0.9:加速收敛,减少震荡
2.6设置训练参数
total_train_step = 0
total_test_step = 0
epoch = 20
-
epoch:完整遍历训练集的次数 -
total_train_step:记录训练 step,用于 TensorBoard -
total_test_step:记录测试轮数
3.特别注意
① 训练模式与测试模式
在代码中,使用了 cifar10_net.train() 和 cifar10_net.eval()。
-
物理意义:虽然对简单的 CNN 影响较小,但如果网络中有
Dropout或BatchNorm层,这两个开关能确保它们在训练和测试时处于正确的工作状态。
② 经典的“训练三部曲”
这是每个 batch 必须做的规定动作:
optimizer.zero_grad() # 1. 擦掉旧的梯度,否则会累加
loss.backward() # 2. 计算当前误差对权重的梯度
optimizer.step() # 3. 真正动手修改权重参数
③ 测试逻辑与 torch.no_grad()
在测试环节,千万不能计算梯度,否则会浪费大量的显存和计算资源。
with torch.no_grad(): # 关闭梯度追踪,保证测试时的纯净与高效
for data in test_dataloader:
# ... 仅做前向传播
④性能监控:TensorBoard 的重要性
你使用了 writer.add_scalar 来记录 Loss 和准确率。
-
为什么要画图?:如果 Loss 曲线一直震荡不下降,说明学习率可能过大;如果训练 Loss 下降但测试 Loss 反而上升,说明模型出现了过拟合。通过 TensorBoard 的可视化,我们能一眼看出模型的健康状况。
4. 学习感悟与避坑指南
-
准确率的计算:
outputs.argmax(1) == targets是最常用的分类准确率计算方法,它找出概率最大的那个类别的索引,并与标签对比。 -
.item() 的妙用:在打印 Loss 或记录到 TensorBoard 时,务必使用
loss.item()。它可以把 Tensor 转化为 Python 的数值类型,有效防止内存溢出。 -
模型保存:代码中每一轮都保存了一个
.pth文件,这叫 Checkpoints(检查点)。这样即使中途断电,你也能从最新的进度恢复。
5.最终代码
"""
========================================
@FileName: train
@Author: ye_shun
@Email: 2942613675@qq.com
@Created: 2026/1/21 17:40
@Description: 完整的模型训练
①准备数据集 ②加载数据集 ③搭建神经网络 ④损失函数 ⑤优化器 ⑥设置训练参数 ⑦开始训练 ⑧测试 ⑨保存模型 (tensorboard画图)
========================================
"""
import torch
import torchvision
from torch import nn
from torch.nn import Flatten
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from model import *
#准备数据集
train_data = torchvision.datasets.CIFAR10(root="./CIFAR10", train=True, download=True, transform=transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10(root="./CIFAR10", train=False, download=True, transform=transforms.ToTensor())
# length长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print("训练集数据长度:{}".format(train_data_size))
print("测试集数据长度: {}".format(test_data_size))
#利用Dataloader来加载数据集
train_dataloader=torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0,drop_last=True)
test_dataloader=torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True,num_workers=0,drop_last=True)
#搭建神经网络,写在model.py里面
cifar10_net = CIFAR10_Net()
#损失函数
loss_fn=nn.CrossEntropyLoss()
#优化器
learning_rate=1e-2
optimizer=torch.optim.SGD(cifar10_net.parameters(), lr=learning_rate, momentum=0.9)
#设置训练网络的一些参数
#设置训练的次数
total_train_step = 0
#记录训练的次数
total_test_step = 0
#设置训练的轮数
epoch=20
#开始训练
cifar10_net.train()
#添加tensorboard
writer=SummaryWriter(log_dir="./logs_train")
for i in range(epoch):
print("------第{}轮训练开始------".format(i+1))
for data in train_dataloader:
imgs,targets=data
outputs=cifar10_net(imgs)
loss=loss_fn(outputs,targets)
#优化器调优
optimizer.zero_grad()#梯度清零
loss.backward()#反向传播
optimizer.step()#参数优化
total_train_step+=1
if total_train_step%100==0:
print("训练次数:{},Loss:{:.5f}".format(total_train_step,loss.item()))
writer.add_scalar("train_loss",loss.item(),total_train_step)
#测试步骤开始
cifar10_net.eval()
total_loss=0
total_accuracy=0
with torch.no_grad():
for data in test_dataloader:
imgs,targets=data
outputs=cifar10_net(imgs)
loss=loss_fn(outputs,targets)
total_loss+=loss.item()
accuracy=(outputs.argmax(1)==targets).sum()
total_accuracy+=accuracy.item()
print("整体测试集上的Loss:{:.5f}".format(total_loss/len(test_dataloader)))
print("整体测试集上的准确率:{}".format(total_accuracy / len(test_dataloader)))
writer.add_scalar("test_loss",total_loss/len(test_dataloader),total_test_step)
total_test_step+=1
torch.save(cifar10_net, "cifar10_net_{}.pth".format(i+1))
print("第{}轮模型已保存".format(i+1))
writer.close()
更多推荐
所有评论(0)