机器学习实操 第一部分 机器学习基础 第3章 训练分类器

内容概要

第3章深入探讨了分类任务,这是监督学习中的一种常见类型。通过使用MNIST数据集(一个包含70,000张手写数字图像的数据集),本章介绍了如何训练分类器来识别图像中的数字。内容涵盖二元分类、多类分类、性能评估指标(如精确率、召回率、F1分数)、ROC曲线和AUC值等关键概念。此外,还讨论了多标签和多输出分类任务,以及如何通过错误分析来改进分类器。
在这里插入图片描述

主要内容

  1. MNIST数据集

    • 数据加载与预览:通过fetch_openml函数获取MNIST数据集,并展示其基本结构。
    • 数据可视化:使用Matplotlib绘制图像以直观展示数据。
  2. 二元分类

    • 任务简化:将问题简化为仅识别数字5的二元分类任务。
    • 模型训练:使用SGDClassifier训练模型。
    • 性能评估:通过交叉验证评估模型性能,并引入混淆矩阵、精确率、召回率和F1分数等评估指标。
  3. 多类分类

    • 策略介绍:介绍了一对多(OvR)和一对一(OvO)策略。
    • 模型实现:使用SVCSGDClassifier进行多类分类,并分析其性能。
  4. 多标签和多输出分类

    • 多标签分类:训练模型以输出多个标签。
    • 多输出分类:构建系统以去除图像噪声,演示多输出分类任务。
  5. 性能评估指标

    • 精确率与召回率的权衡:通过调整决策阈值来平衡精确率和召回率。
    • ROC曲线与AUC值:使用ROC曲线和AUC值评估分类器性能。
  6. 错误分析

    • 混淆矩阵分析:通过可视化混淆矩阵分析模型错误。
    • 数据增强:通过数据增强技术提高模型性能。

关键代码和算法

3.1 MNIST数据集加载与预览

from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt

mnist = fetch_openml('mnist_784', as_frame=False)
X, y = mnist.data, mnist.target
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

def plot_digit(image_data):
    image = image_data.reshape(28, 28)
    plt.imshow(image, cmap="binary")
    plt.axis("off")

some_digit = X[0]
plot_digit(some_digit)
plt.show()

3.2 二元分类器训练与评估

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
cm = confusion_matrix(y_train_5, y_train_pred)
precision = precision_score(y_train_5, y_train_pred)
recall = recall_score(y_train_5, y_train_pred)
f1 = f1_score(y_train_5, y_train_pred)

3.3 多类分类器训练

from sklearn.svm import SVC

svm_clf = SVC(random_state=42)
svm_clf.fit(X_train[:2000], y_train[:2000])  # 仅训练前2000个样本以节省时间

y_train_pred = svm_clf.predict(X_train[:2000])
cm = confusion_matrix(y_train[:2000], y_train_pred)

3.4 多标签分类器训练

from sklearn.neighbors import KNeighborsClassifier
import numpy as np

y_train_large = (y_train >= '7')
y_train_odd = (y_train.astype('int8') % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)

y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)
f1 = f1_score(y_multilabel, y_train_knn_pred, average="macro")

3.5 多输出分类器训练

np.random.seed(42)
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train_mod, y_train_mod)

clean_digit = knn_clf.predict([X_test_mod[0]])
plot_digit(clean_digit)
plt.show()

精彩语录

  1. 中文:分类任务的性能评估比回归任务复杂得多,因此本章将重点介绍如何评估分类器的性能。
    英文原文:Evaluating a classifier is often significantly trickier than evaluating a regressor, so we will spend a large part of this chapter on this topic.
    解释:强调了分类任务评估的复杂性。

  2. 中文:在处理不平衡数据集时,准确率并不是一个理想的性能指标。
    英文原文:Accuracy is generally not the preferred performance measure for classifiers, especially when you are dealing with skewed datasets.
    解释:指出了准确率的局限性,并引出混淆矩阵和其他评估指标的必要性。

  3. 中文:精确率和召回率之间存在权衡,提高一个通常会降低另一个。
    英文原文:There is a trade-off between precision and recall: increasing one typically reduces the other.
    解释:总结了精确率和召回率之间的关系。

  4. 中文:F1分数是精确率和召回率的调和平均,适用于需要单个评估指标的场景。
    英文原文:The F score is the harmonic mean of precision and recall, making it suitable for scenarios where a single metric is needed.
    解释:介绍了F1分数的定义及其适用场景。

  5. 中文:ROC曲线和AUC值是评估分类器性能的有力工具,尤其适用于二元分类任务。
    英文原文:The ROC curve and AUC score are powerful tools for evaluating classifier performance, particularly for binary classification tasks.
    解释:强调了ROC曲线和AUC值的重要性。

总结

通过本章的学习,读者将掌握分类任务的核心概念和实践技巧。这些内容包括如何训练二元分类器和多类分类器,如何评估分类器的性能,以及如何通过错误分析和数据增强来改进模型。这些技能对于实际应用中的分类问题至关重要。

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐