【scikit-learn】sklearn.datasets 模块:数据集加载( 内置数据集、合成数据集和外部数据集接口)
sklearn.datasets是scikit-learn提供的数据集加载模块,包含内置数据集、合成数据集和外部数据集接口,用于机器学习模型的实验和测试。如果需要分类、回归或聚类测试,可使用内置数据集;如果需要定制数据,可使用合成数据集;如果需要真实数据,可使用fetch_openml()获取外部数据。load_iris()鸢尾花分类,load_wine()葡萄酒分类,load_digits()手
·
sklearn.datasets 模块
sklearn.datasets 是 scikit-learn 提供的 数据集加载模块,包含 内置数据集、合成数据集和外部数据集接口,用于 机器学习模型的实验和测试。
1. sklearn.datasets 提供的功能
| 数据集类型 | 作用 | 主要函数 |
|---|---|---|
| 内置数据集 | 经典机器学习数据集 | load_iris()、load_digits()、load_wine() |
| 合成数据集 | 生成随机数据 | make_classification()、make_regression() |
| 外部数据集接口 | 读取开放数据集 | fetch_openml()、fetch_california_housing() |
2. 内置数据集
适用于快速测试机器学习算法,数据已 格式化为 NumPy 数组和 Pandas DataFrame。
(1) 经典分类数据集
① load_iris()(鸢尾花数据集)
from sklearn.datasets import load_iris
# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
# 数据集信息
print("特征名称:", iris.feature_names)
print("类别名称:", iris.target_names)
print("数据形状:", X.shape, y.shape)
数据集信息
- 特征 (
X):sepal length,sepal width,petal length,petal width(4 个特征)。 - 类别 (
y):setosa,versicolor,virginica(3 类)。
② load_wine()(葡萄酒数据集)
from sklearn.datasets import load_wine
wine = load_wine()
X, y = wine.data, wine.target
print("特征名称:", wine.feature_names)
print("类别名称:", wine.target_names)
数据集信息
- 13 个特征(酒精、苹果酸、镁含量等)。
- 3 类红酒。
③ load_digits()(手写数字数据集)
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
digits = load_digits()
X, y = digits.data, digits.target
# 显示一个手写数字
plt.imshow(digits.images[0], cmap="gray")
plt.title(f"Label: {y[0]}")
plt.show()
数据集信息
- 1797 张 8×8 的手写数字图片。
- 10 个类别(数字
0-9)。
(2) 经典回归数据集
① load_diabetes()(糖尿病数据集)
from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target
print("特征名称:", diabetes.feature_names)
print("数据形状:", X.shape, y.shape)
数据集信息
- 442 个样本,10 个数值型特征(年龄、BMI、血压等)。
- 目标变量是连续值(糖尿病病情进展)。
② load_boston()(波士顿房价数据集,已废弃)
sklearn.datasets.load_boston()已被移除,建议使用fetch_california_housing()。
from sklearn.datasets import fetch_california_housing
housing = fetch_california_housing()
X, y = housing.data, housing.target
print("特征名称:", housing.feature_names)
print("数据形状:", X.shape, y.shape)
数据集信息
- 20640 个样本,8 个特征(房屋年龄、房间数等)。
- 目标值是房价中位数。
3. 生成合成数据
用于测试不同类型的机器学习任务(分类、回归、聚类)。
(1) 生成分类数据
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
# 生成数据(1000 个样本,2 个特征,2 个类别)
X, y = make_classification(n_samples=1000, n_features=2, n_classes=2, random_state=42)
# 可视化数据
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm")
plt.title("合成分类数据")
plt.show()
参数
n_samples=1000:样本数。n_features=2:特征数。n_classes=2:类别数。
(2) 生成回归数据
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=1000, n_features=1, noise=10, random_state=42)
plt.scatter(X, y)
plt.title("合成回归数据")
plt.show()
参数
n_features=1:1 个特征。noise=10:加入噪声。
(3) 生成聚类数据
from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=300, centers=3, random_state=42)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
plt.title("合成聚类数据")
plt.show()
参数
centers=3:3 个聚类中心。
4. 访问外部数据集
(1) fetch_openml()
fetch_openml() 从 OpenML 平台下载数据。
from sklearn.datasets import fetch_openml
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist.data, mnist.target
print("MNIST 数据形状:", X.shape, y.shape)
数据集
mnist_784:手写数字数据集(70,000 张28×28图片)。- 适用于深度学习、CNN 任务。
(2) fetch_california_housing()
housing = fetch_california_housing()
X, y = housing.data, housing.target
数据集
- 房价预测数据集(
20640个样本,8个特征)。
5. sklearn.datasets 的主要函数
| 函数 | 作用 |
|---|---|
load_iris() |
鸢尾花分类 |
load_wine() |
葡萄酒分类 |
load_digits() |
手写数字分类 |
load_diabetes() |
糖尿病回归 |
fetch_california_housing() |
加州房价预测 |
fetch_openml() |
获取 OpenML 数据集 |
make_classification() |
生成分类数据 |
make_regression() |
生成回归数据 |
make_blobs() |
生成聚类数据 |
6. 适用场景
- 测试机器学习算法(使用
load_iris()、load_digits())。 - 生成合成数据,评估模型性能(使用
make_classification())。 - 获取外部数据集进行实验(使用
fetch_openml())。
7. 结论
sklearn.datasets提供了经典数据集、合成数据集和外部数据集接口,适用于 机器学习实验。- 如果 需要分类、回归或聚类测试,可使用内置数据集;如果 需要定制数据,可使用合成数据集;如果 需要真实数据,可使用
fetch_openml()获取外部数据。
更多推荐

所有评论(0)