残差网络实战:基于MNIST数据集的手写数字识别
MNIST数据集是机器学习领域中非常经典的图像数据集,它包含了70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28×28像素,并且已经进行了居中处理,大大减少了预处理的工作量,同时也加快了模型的运行速度。在本文的实战中,MNIST数据集将作为我们训练和测试残差网络的“战场”。# 模块搭建# 网络搭建return x在上述代码中,首先定义了R
残差网络实战:基于MNIST数据集的手写数字识别
在深度学习的广阔领域中,卷积神经网络(CNN)一直是处理图像任务的主力军。随着研究的深入,网络层数的增加虽然理论上能提升模型的表达能力,但却面临梯度消失、梯度爆炸以及网络退化等问题。残差网络(ResNet)的出现,成功地解决了这些难题,为深度学习的发展开辟了新的道路。本文将通过在MNIST数据集上进行手写数字识别的实战,带大家深入了解残差网络的原理与应用。
一、MNIST数据集介绍
MNIST数据集是机器学习领域中非常经典的图像数据集,它包含了70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28×28像素,并且已经进行了居中处理,大大减少了预处理的工作量,同时也加快了模型的运行速度。在本文的实战中,MNIST数据集将作为我们训练和测试残差网络的“战场”。
二、残差网络原理
传统的神经网络在增加层数时,容易出现网络退化现象,即随着网络层数的增加,模型在训练集和测试集上的性能反而下降。残差网络通过引入残差块(Residual Block)巧妙地解决了这一问题。
残差块的核心思想是让网络学习输入与输出之间的残差,而不是直接学习复杂的映射关系。在一个残差块中,输入数据会经过一系列的卷积、激活等操作,得到一个输出,同时输入数据会直接通过一个快捷连接(Shortcut Connection)与输出相加。这样,网络学习的目标就变成了输出与输入之间的差异,使得训练过程更加容易。
数学上,假设残差块的输入为xxx,期望的输出映射为H(x)H(x)H(x),通过残差块学习到的函数为F(x)F(x)F(x),那么残差块的输出可以表示为y=F(x)+xy = F(x) + xy=F(x)+x。当F(x)=0F(x) = 0F(x)=0时,残差块的输出就等于输入,这保证了即使增加网络层数,模型的性能也不会下降,反而有机会通过学习残差来提升性能。
三、代码实现与解析
1. 数据加载与预处理
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# 下载训练数据集(包含训练图片+标签)
training_data = datasets.MNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# 下载测试数据集(包含训练图片+标签)
test_data = datasets.MNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
# 创建数据DataLoader(数据加载器)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
上述代码利用torchvision库中的datasets.MNIST类下载MNIST数据集,并使用ToTensor()将图像数据转换为PyTorch能够处理的张量格式。然后,通过DataLoader将数据集划分为大小为64的批次,这样做可以减少内存的使用,提高训练速度。
2. 设备配置
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
这段代码用于判断当前设备是否支持GPU(CUDA或苹果M系列芯片的MPS),如果支持则使用GPU进行计算,否则使用CPU,充分利用硬件资源加速模型训练。
3. 残差块与网络定义
# 模块搭建
class ResBlock(nn.Module):
def __init__(self, channels_in):
super().__init__()
self.conv1 = torch.nn.Conv2d(channels_in, 30, 5, padding=2)
self.conv2 = torch.nn.Conv2d(30, channels_in, 3, padding=1)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
return F.relu(out + x)
# 网络搭建
class ResNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5)
self.conv2 = torch.nn.Conv2d(20, 15, 3)
self.maxpool = torch.nn.MaxPool2d(2)
self.resblock1 = ResBlock(channels_in=20)
self.resblock2 = ResBlock(channels_in=15)
self.full_c = torch.nn.Linear(375, 10)
def forward(self, x):
size = x.shape[0]
x = F.relu(self.maxpool(self.conv1(x)))
x = self.resblock1(x)
x = F.relu(self.maxpool(self.conv2(x)))
x = self.resblock2(x)
x = x.view(size, -1)
x = self.full_c(x)
return x
model = ResNet().to(device)
在上述代码中,首先定义了ResBlock类,实现了残差块的结构,其中包含两个卷积层,并通过快捷连接将输入与卷积层的输出相加,再经过ReLU激活函数得到最终输出。接着,ResNet类构建了完整的残差网络,包含普通卷积层、最大池化层、残差块以及全连接层,将输入图像逐步提取特征并分类为10个数字类别。
4. 训练与测试函数
def train(dataloader, model, loss_fn, optimizer):
model.train()
batch_size_num = 1
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model.forward(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss = loss.item()
print(f"loss: {loss:>7f} [number:{batch_size_num}]")
batch_size_num += 1
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model.forward(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")
train函数用于模型的训练过程,在每个批次中,将数据传入模型得到预测结果,通过交叉熵损失函数计算损失,进行反向传播更新模型参数,并打印每一批次的损失值。test函数则用于在测试集上评估模型性能,关闭梯度计算以节省内存,计算测试集上的平均损失和准确率。
5. 模型训练与评估
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 10
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)
最后,定义交叉熵损失函数和Adam优化器,设置训练轮数为10,通过循环调用train函数进行模型训练,训练完成后调用test函数在测试集上评估模型性能。
四、总结与展望
通过在MNIST数据集上的实战,我们成功地实现了基于残差网络的手写数字识别。残差网络凭借其独特的结构设计,有效地解决了深度神经网络中的退化问题,使得我们能够构建更深、更强大的模型。在未来的研究和应用中,残差网络的思想可以应用到更多复杂的图像任务,如图像分割、目标检测等。同时,结合其他先进的技术,如注意力机制、生成对抗网络等,有望进一步提升模型的性能,为深度学习在计算机视觉领域的发展带来更多的可能性。
希望本文能帮助大家更好地理解残差网络,并在实际项目中灵活运用,开启深度学习图像识别的新征程!
更多推荐


所有评论(0)