用深度学习搞事情,模型搭建和训练是绕不开的两步。而 PyTorch,作为一个“又灵活又好用”的深度学习框架,简直就是写代码的快乐源泉。今天我们就从 0 到 1,实战 PyTorch 的模型搭建和训练流程。说白了,看完你就能自己搭个神经网络,喂点数据进去,再让它干点活。

安装 PyTorch

要用 PyTorch,得先装上它。PyTorch 的安装稍微有点讲究,主要是要根据你的硬件选择 CPU 版本还是 GPU 版本。

基本安装命令

如果你只用 CPU,就这么装:

pip install torch torchvision

如果你有 NVIDIA 的 GPU(并且装了 CUDA),可以用下面的命令:

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

温馨提示:

GPU 能大大加速训练速度,但得先确保你的显卡支持 CUDA,并且驱动版本合格。

数据准备

模型训练需要数据,就像人学技能需要教材。PyTorch提供了torch.utils.data模块,让你可以轻松加载和管理数据。

使用内置数据集

为了简单,我用PyTorch自带的 MNIST 数据集(手写数字)。我们可以用torchvision.datasets直接下载并加载它。

import torch
from torchvision import datasets, transforms

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 标准化到 [-1, 1]
])

# 加载训练和测试数据
train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# 数据加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

运行后,MNIST数据集会被下载到./data目录下,同时数据会被转换成PyTorch的张量格式。

温馨提示:

batch_size 是每次喂给模型的数据量,太小会导致训练慢,太大则可能爆显存。


资源分享

为了方便大家学习,我整理了PyTorch全套学习资料,包含配套教程讲义和源码

除此之外还有100G人工智能学习资料

包含数学与Python编程基础、深度学习+机器学习入门到实战,计算机视觉+自然语言处理+大模型资料合集

不仅有配套教程讲义还有对应源码数据集。

更有零基础入门学习路线,不论你处于什么阶段,这份资料都能帮助你更好地入门到进阶。

需要的兄弟可以按照这个图的方式免费获取


 


搭建模型

模型就是一个“函数”,它接收输入(比如图片),输出预测结果(比如是个7)。在 PyTorch里,可以通过继承torch.nn.Module来定义自己的模型。

定义一个简单的神经网络

import torch.nn as nn

# 定义模型
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # 输入层到隐藏层
        self.fc2 = nn.Linear(128, 10)       # 隐藏层到输出层

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # 展平输入 (28x28 -> 784)
        x = torch.relu(self.fc1(x))  # 激活函数 ReLU
        x = self.fc2(x)
        return x

这里的forward方法定义了模型的前向传播过程,而torch.relu是一个激活函数,用来引入非线性。

温馨提示:

view(-1, 28 * 28) 是用来把二维图片展平成一维向量,方便传入全连接层。

定义损失函数和优化器

模型需要优化,就得有“目标”和“规则”。在 PyTorch 里,目标是损失函数,规则是优化器。

定义损失函数

分类任务通常用交叉熵损失函数:

loss_fn = nn.CrossEntropyLoss()

交叉熵会计算预测结果和真实标签之间的误差,误差越小,模型就越靠谱。

定义优化器

优化器负责调整模型的参数,让它越来越接近正确答案。我们用经典的 Adam 优化器:

import torch.optim as optim

model = SimpleNN()  # 初始化模型
optimizer = optim.Adam(model.parameters(), lr=0.001)

lr 是学习率,决定了每次参数调整的步伐大小。

温馨提示:

学习率太大模型会不稳定,太小则训练太慢。一般从0.001开始调。

模型训练

训练过程就是:喂数据 -> 计算预测值 -> 计算损失 -> 调整参数。我们用循环来完成这个流程。

训练代码:

# 训练模型
epochs = 5
for epoch in range(epochs):
    model.train()  # 设置为训练模式
    total_loss = 0

    for images, labels in train_loader:
        optimizer.zero_grad()  # 清空上一步的梯度
        outputs = model(images)  # 前向传播
        loss = loss_fn(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")

运行后,你会看到每个epoch的损失值逐渐下降,说明模型学得还不错。

温馨提示:

每次训练前用model.train(),告诉 PyTorch 这是训练阶段(会启用 Dropout 等机制)。

模型评估

模型训练完了,得看看它到底行不行。我们用测试集来评估准确率。

测试代码:

model.eval()  # 设置为评估模式
correct = 0
total = 0

with torch.no_grad():  # 不计算梯度
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # 获取预测结果
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

运行后,你会看到模型在测试集上的准确率,比如97.85%,还挺靠谱。

温馨提示:

🔒 测试时用torch.no_grad(),可以节省内存和计算资源。

保存和加载模型

训练好的模型可以保存下来,下次直接加载就能用。

保存模型

torch.save(model.state_dict(), "model.pth")

这会把模型的参数保存到model.pth 文件里。

加载模型

model = SimpleNN()  # 重新初始化模型
model.load_state_dict(torch.load("model.pth"))

加载后,模型就恢复到了训练好的状态,可以直接用来预测。

常见问题和解决办法

  • 训练慢:尝试使用 GPU 加速,只需在代码开头加一句device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),然后把模型和数据都转到device上。

  • 损失不下降:检查数据是否标准化,学习率是否合适,或者试试更复杂的模型结构。

  • 过拟合:如果训练集表现很好,但测试集表现差,可以加一些正则化手段,比如 Dropout。

PyTorch 的模型搭建和训练其实没那么复杂,大概流程就是:准备数据 -> 定义模型 -> 训练模型 -> 评估效果。熟悉这些步骤后,你就可以胜任大部分深度学习任务了!

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐