本文手把手教你如何编写一个完整的 PyTorch 训练脚本。通过 CIFAR10 图像分类任务,系统演示了从数据集加载、网络结构设计到训练/测试闭环的每一个核心步骤。文章重点分享了训练三步曲、模型状态切换以及利用 TensorBoard 实时监控训练性能的实战技巧,是一篇不可多得的深度学习训练模板指南。

1. 深度学习训练的“标准九步走”

很多初学者面对训练代码会感到头大,其实只要按照这个逻辑拆解,就会非常清晰:

  1. 准备数据:下载并定义 Dataset

  2. 加载数据:用 DataLoader 打包。

  3. 搭建模型:继承 nn.Module(建议封装在 model.py 中)。

  4. 损失函数:定义评价标准(如 CrossEntropyLoss)。

  5. 优化器:选择更新权重的策略(如 SGDAdam)。

  6. 参数设置:确定 Epoch(轮数)和 Learning Rate(学习率)。

  7. 训练循环:前向传播》》计算 Loss 》》梯度清零》》反向传播 》》权重更新。

  8. 验证/测试:在不更新权重的条件下,评估模型性能。

  9. 保存模型:持久化存储,以便后续使用。

2.步骤详解

2.1准备数据集(Dataset)

CIFAR10 简介

  • 图片大小:3 × 32 × 32

  • 训练集:50000 张

  • 测试集:10000 张

  • 分类数:10 类

ToTensor() 会完成两件事:

  • HWC → CHW

  • 像素归一化到 [0, 1]

train_data = torchvision.datasets.CIFAR10(
    root="./CIFAR10",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_data = torchvision.datasets.CIFAR10(
    root="./CIFAR10",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

2.2查看数据集规模并加载数据集

train_data_size = len(train_data)
test_data_size = len(test_data)

print("训练集数据长度:{}".format(train_data_size))
print("测试集数据长度: {}".format(test_data_size))

输出:

训练集数据长度:50000 测试集数据长度: 10000
train_dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    drop_last=True
)

test_dataloader = torch.utils.data.DataLoader(
    test_data,
    batch_size=64,
    shuffle=True,
    drop_last=True
)

参数说明

  • batch_size=64:一次训练 64 张图片

  • shuffle=True:打乱数据,防止模型记忆顺序

  • drop_last=True:丢弃不足一个 batch 的数据

2.3搭建神经网络(model.py)

cifar10_net = CIFAR10_Net()

网络结构单独写在 model.py

#搭建神经网络
class CIFAR10_Net(nn.Module):
    def __init__(self):
        super(CIFAR10_Net, self).__init__()
        self.model=nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        output=self.model(x)
        return output

if __name__ == '__main__':
    cifar10_net=CIFAR10_Net()
    input=torch.ones((64,3,32,32))
    output=cifar10_net(input)
    print(output.shape)

这样做的好处:

  • 训练逻辑和模型结构解耦

  • 便于复用与调试

  • 更接近真实项目代码结构

2.4定义损失函数

loss_fn = nn.CrossEntropyLoss()

为什么用 CrossEntropyLoss?

  • 适用于 多分类任务

  • 内部自动完成:

    • Softmax

    • Log

    • NLLLoss

⚠️ 注意:
模型最后一层不要手动加 Softmax

2.5定义优化器

learning_rate = 1e-2
optimizer = torch.optim.SGD(
    cifar10_net.parameters(),
    lr=learning_rate,
    momentum=0.9
)

SGD + Momentum

  • SGD:随机梯度下降

  • momentum=0.9:加速收敛,减少震荡

2.6设置训练参数

total_train_step = 0
total_test_step = 0
epoch = 20
  • epoch:完整遍历训练集的次数

  • total_train_step:记录训练 step,用于 TensorBoard

  • total_test_step:记录测试轮数

3.特别注意

① 训练模式与测试模式

在代码中,使用了 cifar10_net.train()cifar10_net.eval()

  • 物理意义:虽然对简单的 CNN 影响较小,但如果网络中有 DropoutBatchNorm,这两个开关能确保它们在训练和测试时处于正确的工作状态。

② 经典的“训练三部曲”

这是每个 batch 必须做的规定动作:

optimizer.zero_grad() # 1. 擦掉旧的梯度,否则会累加
loss.backward()      # 2. 计算当前误差对权重的梯度
optimizer.step()     # 3. 真正动手修改权重参数

③ 测试逻辑与 torch.no_grad()

在测试环节,千万不能计算梯度,否则会浪费大量的显存和计算资源。

with torch.no_grad(): # 关闭梯度追踪,保证测试时的纯净与高效
    for data in test_dataloader:
        # ... 仅做前向传播

④性能监控:TensorBoard 的重要性

你使用了 writer.add_scalar 来记录 Loss 和准确率。

  • 为什么要画图?:如果 Loss 曲线一直震荡不下降,说明学习率可能过大;如果训练 Loss 下降但测试 Loss 反而上升,说明模型出现了过拟合。通过 TensorBoard 的可视化,我们能一眼看出模型的健康状况。

4. 学习感悟与避坑指南

  1. 准确率的计算outputs.argmax(1) == targets 是最常用的分类准确率计算方法,它找出概率最大的那个类别的索引,并与标签对比。

  2. .item() 的妙用:在打印 Loss 或记录到 TensorBoard 时,务必使用 loss.item()。它可以把 Tensor 转化为 Python 的数值类型,有效防止内存溢出。

  3. 模型保存:代码中每一轮都保存了一个 .pth 文件,这叫 Checkpoints(检查点)。这样即使中途断电,你也能从最新的进度恢复。

5.最终代码

"""
========================================
@FileName:    train
@Author:      ye_shun
@Email:       2942613675@qq.com
@Created:     2026/1/21 17:40
@Description: 完整的模型训练
①准备数据集 ②加载数据集 ③搭建神经网络 ④损失函数 ⑤优化器 ⑥设置训练参数 ⑦开始训练 ⑧测试 ⑨保存模型 (tensorboard画图)
========================================
"""
import torch
import torchvision
from torch import nn
from torch.nn import Flatten
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from model import *
#准备数据集
train_data = torchvision.datasets.CIFAR10(root="./CIFAR10", train=True, download=True, transform=transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10(root="./CIFAR10", train=False, download=True, transform=transforms.ToTensor())

# length长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print("训练集数据长度:{}".format(train_data_size))
print("测试集数据长度: {}".format(test_data_size))

#利用Dataloader来加载数据集
train_dataloader=torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0,drop_last=True)
test_dataloader=torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True,num_workers=0,drop_last=True)


#搭建神经网络,写在model.py里面
cifar10_net = CIFAR10_Net()

#损失函数
loss_fn=nn.CrossEntropyLoss()

#优化器
learning_rate=1e-2
optimizer=torch.optim.SGD(cifar10_net.parameters(), lr=learning_rate, momentum=0.9)

#设置训练网络的一些参数
#设置训练的次数
total_train_step = 0
#记录训练的次数
total_test_step = 0
#设置训练的轮数
epoch=20
#开始训练
cifar10_net.train()
#添加tensorboard
writer=SummaryWriter(log_dir="./logs_train")
for i in range(epoch):
    print("------第{}轮训练开始------".format(i+1))
    for data  in train_dataloader:
        imgs,targets=data
        outputs=cifar10_net(imgs)
        loss=loss_fn(outputs,targets)
        #优化器调优
        optimizer.zero_grad()#梯度清零
        loss.backward()#反向传播
        optimizer.step()#参数优化
        total_train_step+=1
        if total_train_step%100==0:
            print("训练次数:{},Loss:{:.5f}".format(total_train_step,loss.item()))
            writer.add_scalar("train_loss",loss.item(),total_train_step)

    #测试步骤开始
    cifar10_net.eval()
    total_loss=0
    total_accuracy=0
    with torch.no_grad():
        for data in test_dataloader:
            imgs,targets=data
            outputs=cifar10_net(imgs)
            loss=loss_fn(outputs,targets)
            total_loss+=loss.item()
            accuracy=(outputs.argmax(1)==targets).sum()
            total_accuracy+=accuracy.item()
    print("整体测试集上的Loss:{:.5f}".format(total_loss/len(test_dataloader)))
    print("整体测试集上的准确率:{}".format(total_accuracy / len(test_dataloader)))
    writer.add_scalar("test_loss",total_loss/len(test_dataloader),total_test_step)
    total_test_step+=1
    torch.save(cifar10_net, "cifar10_net_{}.pth".format(i+1))
    print("第{}轮模型已保存".format(i+1))
writer.close()

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐