sklearn.datasets 模块

sklearn.datasetsscikit-learn 提供的 数据集加载模块,包含 内置数据集、合成数据集和外部数据集接口,用于 机器学习模型的实验和测试


1. sklearn.datasets 提供的功能

数据集类型 作用 主要函数
内置数据集 经典机器学习数据集 load_iris()load_digits()load_wine()
合成数据集 生成随机数据 make_classification()make_regression()
外部数据集接口 读取开放数据集 fetch_openml()fetch_california_housing()

2. 内置数据集

适用于快速测试机器学习算法,数据已 格式化为 NumPy 数组和 Pandas DataFrame

(1) 经典分类数据集

load_iris()(鸢尾花数据集)
from sklearn.datasets import load_iris

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 数据集信息
print("特征名称:", iris.feature_names)
print("类别名称:", iris.target_names)
print("数据形状:", X.shape, y.shape)

数据集信息

  • 特征 (X)sepal length, sepal width, petal length, petal width(4 个特征)。
  • 类别 (y)setosa, versicolor, virginica(3 类)。

load_wine()(葡萄酒数据集)
from sklearn.datasets import load_wine

wine = load_wine()
X, y = wine.data, wine.target

print("特征名称:", wine.feature_names)
print("类别名称:", wine.target_names)

数据集信息

  • 13 个特征(酒精、苹果酸、镁含量等)。
  • 3 类红酒

load_digits()(手写数字数据集)
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt

digits = load_digits()
X, y = digits.data, digits.target

# 显示一个手写数字
plt.imshow(digits.images[0], cmap="gray")
plt.title(f"Label: {y[0]}")
plt.show()

数据集信息

  • 1797 张 8×8 的手写数字图片
  • 10 个类别(数字 0-9

(2) 经典回归数据集

load_diabetes()(糖尿病数据集)
from sklearn.datasets import load_diabetes

diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target

print("特征名称:", diabetes.feature_names)
print("数据形状:", X.shape, y.shape)

数据集信息

  • 442 个样本,10 个数值型特征(年龄、BMI、血压等)。
  • 目标变量是连续值(糖尿病病情进展)

load_boston()(波士顿房价数据集,已废弃)

sklearn.datasets.load_boston() 已被移除,建议使用 fetch_california_housing()

from sklearn.datasets import fetch_california_housing

housing = fetch_california_housing()
X, y = housing.data, housing.target

print("特征名称:", housing.feature_names)
print("数据形状:", X.shape, y.shape)

数据集信息

  • 20640 个样本,8 个特征(房屋年龄、房间数等)。
  • 目标值是房价中位数

3. 生成合成数据

用于测试不同类型的机器学习任务(分类、回归、聚类)

(1) 生成分类数据

from sklearn.datasets import make_classification
import matplotlib.pyplot as plt

# 生成数据(1000 个样本,2 个特征,2 个类别)
X, y = make_classification(n_samples=1000, n_features=2, n_classes=2, random_state=42)

# 可视化数据
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm")
plt.title("合成分类数据")
plt.show()

参数

  • n_samples=1000:样本数。
  • n_features=2:特征数。
  • n_classes=2:类别数。

(2) 生成回归数据

from sklearn.datasets import make_regression

X, y = make_regression(n_samples=1000, n_features=1, noise=10, random_state=42)

plt.scatter(X, y)
plt.title("合成回归数据")
plt.show()

参数

  • n_features=1:1 个特征。
  • noise=10:加入噪声。

(3) 生成聚类数据

from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=300, centers=3, random_state=42)

plt.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
plt.title("合成聚类数据")
plt.show()

参数

  • centers=3:3 个聚类中心。

4. 访问外部数据集

(1) fetch_openml()

fetch_openml() 从 OpenML 平台下载数据

from sklearn.datasets import fetch_openml

mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist.data, mnist.target

print("MNIST 数据形状:", X.shape, y.shape)

数据集

  • mnist_784:手写数字数据集(70,000 张 28×28 图片)。
  • 适用于深度学习、CNN 任务

(2) fetch_california_housing()

housing = fetch_california_housing()
X, y = housing.data, housing.target

数据集

  • 房价预测数据集20640 个样本,8 个特征)。

5. sklearn.datasets 的主要函数

函数 作用
load_iris() 鸢尾花分类
load_wine() 葡萄酒分类
load_digits() 手写数字分类
load_diabetes() 糖尿病回归
fetch_california_housing() 加州房价预测
fetch_openml() 获取 OpenML 数据集
make_classification() 生成分类数据
make_regression() 生成回归数据
make_blobs() 生成聚类数据

6. 适用场景

  • 测试机器学习算法(使用 load_iris()load_digits())。
  • 生成合成数据,评估模型性能(使用 make_classification())。
  • 获取外部数据集进行实验(使用 fetch_openml())。

7. 结论

  • sklearn.datasets 提供了经典数据集、合成数据集和外部数据集接口,适用于 机器学习实验
  • 如果 需要分类、回归或聚类测试,可使用内置数据集;如果 需要定制数据,可使用合成数据集;如果 需要真实数据,可使用 fetch_openml() 获取外部数据
Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐