用PyTorch/TensorFlow解决经典的MNIST问题
·
📌 前言 MNIST手写数字识别是深度学习与计算机视觉领域最经典的入门任务,涵盖了数据加载、预处理、模型构建、训练及评估的完整流程。本文将从零基础出发,分别使用PyTorch与TensorFlow(Keras)双框架实现MNIST分类,提供完整可复现的代码,帮助开发者快速掌握神经网络实战的基本范式。
🧰 一、环境准备 确保Python版本为3.8+,需安装以下核心依赖包:
pip install torch torchvision tensorflow matplotlib numpy
🖼️ 二、数据加载与预处理 神经网络对输入数据的尺度敏感,需将像素值从 [0, 255] 归一化至 [0, 1],以加速模型收敛。
2.1 PyTorch 数据加载 利用 torchvision 自动下载并加载数据,设置批次大小(Batch Size)为64。
import torch
from torchvision import datasets, transforms
# 定义预处理:转为张量并归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST全局均值与标准差
])
# 加载训练集与测试集
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
2.2 TensorFlow 数据加载 利用 tf.keras 接口直接加载numpy数组格式的数据,并进行维度扩展与归一化。
import tensorflow as tf
# 加载数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# 维度扩展:(N, 28, 28) -> (N, 28, 28, 1) 并归一化至 [0, 1]
x_train = x_train[..., tf.newaxis].astype('float32') / 255.0
x_test = x_test[..., tf.newaxis].astype('float32') / 255.0
# 构建tf.data.Dataset提升性能
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(1000)
🏗️ 三、模型构建 采用简单的卷积神经网络(CNN)结构:卷积层提取特征,池化层降维,全连接层输出分类结果。
3.1 PyTorch 模型定义
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
self.fc1 = nn.Linear(9216, 128) # 经过两次卷积与池化后的特征维度
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
model_pt = Net()
3.2 TensorFlow 模型定义
model_tf = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
🚀 四、模型训练 设置交叉熵损失函数和Adam优化器,迭代训练5个Epoch。
4.1 PyTorch 训练流程
optimizer = torch.optim.Adam(model_pt.parameters())
criterion = nn.CrossEntropyLoss()
model_pt.train()
for epoch in range(5):
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model_pt(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch: {epoch+1} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.4f}')
4.2 TensorFlow 训练流程
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model_tf.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
model_tf.fit(train_ds, epochs=5)
🧪 五、模型评估与预测 在测试集上验证模型泛化能力,并可视化部分预测结果。
5.1 PyTorch 评估
model_pt.eval()
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model_pt(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print(f'\nPyTorch 测试集准确率: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)')
5.2 TensorFlow 评估
test_loss, test_acc = model_tf.evaluate(test_ds, verbose=2)
print(f'\nTensorFlow 测试集准确率: {test_acc:.4f}')
5.3 预测结果可视化 以TensorFlow模型为例,抽取测试集部分样本进行预测并展示。
import matplotlib.pyplot as plt
import numpy as np
# 将logits转为概率
probability_model = tf.keras.Sequential([model_tf, tf.keras.layers.Softmax()])
predictions = probability_model.predict(x_test[:9])
plt.figure(figsize=(6,6))
for i in range(9):
plt.subplot(3,3,i+1)
plt.imshow(x_test[i].reshape(28,28), cmap='gray')
pred_label = np.argmax(predictions[i])
true_label = y_test[i]
color = 'blue' if pred_label == true_label else 'red'
plt.title(f'Pred:{pred_label}', color=color)
plt.axis('off')
plt.tight_layout()
plt.show()
📝 六、总结 本文使用PyTorch与TensorFlow双框架完成了MNIST手写数字识别任务。核心流程可概括为:
- 数据预处理:归一化加速收敛,维度调整适配模型输入。
- 模型构建:采用CNN提取空间特征,完成10分类输出。
- 训练与评估:交叉熵计算损失,Adam优化参数,验证集评估泛化能力。 掌握此基础范式后,可逐步替换更复杂的网络结构或迁移至更难的图像分类数据集。
更多推荐
所有评论(0)