【scikit-learn】sklearn.model_selection 模块:模型选择模块,用于 数据划分、超参数调优和交叉验证
sklearn.model_selection是scikit-learn提供的模型选择模块,用于数据划分、超参数调优和交叉验证,帮助优化机器学习模型。适用于分类、回归、聚类任务,可结合GridSearchCV、KFold、learning_curve等方法优化模型。train_test_split数据集拆分(训练集/测试集),cross_val_score交叉验证评分,StratifiedKFol
·
sklearn.model_selection 模块
sklearn.model_selection 是 scikit-learn 提供的 模型选择模块,用于 数据划分、超参数调优和交叉验证,帮助优化机器学习模型。
1. sklearn.model_selection 提供的主要功能
| 方法 | 作用 | 适用场景 |
|---|---|---|
train_test_split |
数据集拆分(训练集/测试集) | 训练和测试模型 |
cross_val_score |
交叉验证评分 | 评估模型性能 |
StratifiedKFold |
分层 K 折交叉验证 | 类别数据不均衡 |
GridSearchCV |
网格搜索超参数优化 | 选择最佳超参数 |
RandomizedSearchCV |
随机搜索超参数优化 | 适用于大参数空间 |
KFold |
K 折交叉验证 | 评估模型稳定性 |
LeaveOneOut |
留一交叉验证(LOO) | 小数据集 |
ShuffleSplit |
随机划分交叉验证 | 适用于大数据集 |
learning_curve |
学习曲线 | 观察模型是否过拟合 |
validation_curve |
验证曲线 | 选择合适的模型复杂度 |
2. train_test_split(数据集拆分)
用于 将数据集拆分为训练集和测试集,常用于 模型训练和评估。
代码示例
from sklearn.model_selection import train_test_split
import numpy as np
# 示例数据
X = np.arange(10).reshape((5, 2))
y = np.array([0, 1, 0, 1, 0])
# 拆分数据(80% 训练集,20% 测试集)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集:\n", X_train)
print("测试集:\n", X_test)
输出
训练集:
[[6 7]
[2 3]
[0 1]
[4 5]]
测试集:
[[8 9]]
主要参数
| 参数 | 说明 |
|---|---|
test_size |
测试集比例(0.2 代表 20%) |
random_state |
设置随机种子,确保结果可复现 |
stratify |
按 y 分层抽样(适用于类别不均衡) |
3. cross_val_score(交叉验证评分)
用于 评估模型的泛化能力。
代码示例
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
# 训练随机森林,并进行 5 折交叉验证
model = RandomForestClassifier()
scores = cross_val_score(model, X, y, cv=5)
print("交叉验证得分:", scores)
print("平均得分:", scores.mean())
输出
交叉验证得分: [0.96 0.98 0.94 0.96 0.96]
平均得分: 0.96
主要参数
| 参数 | 说明 |
|---|---|
cv |
交叉验证的折数(默认为 5) |
scoring |
评估指标(如 accuracy, f1, roc_auc) |
4. GridSearchCV(网格搜索超参数优化)
用于 遍历所有可能的超参数组合,找到最优参数。
代码示例
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
# 设定参数网格
param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}
# 进行网格搜索
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)
print("最佳得分:", grid_search.best_score_)
输出
最佳参数: {'C': 1, 'kernel': 'linear'}
最佳得分: 0.98
主要参数
| 参数 | 说明 |
|---|---|
param_grid |
需要搜索的超参数 |
cv |
交叉验证折数 |
scoring |
评估指标 |
5. RandomizedSearchCV(随机搜索超参数优化)
在大参数空间中更高效,随机选择部分参数进行搜索。
代码示例
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform
# 设定参数分布
param_dist = {'C': uniform(0.1, 10), 'kernel': ['linear', 'rbf']}
# 进行随机搜索
random_search = RandomizedSearchCV(SVC(), param_dist, n_iter=5, cv=5, random_state=42)
random_search.fit(X, y)
print("最佳参数:", random_search.best_params_)
适用于
- 大规模超参数搜索
- 参数范围较大
6. KFold(K 折交叉验证)
用于 将数据集划分为 K 份,每次用 K-1 份训练,1 份测试。
代码示例
from sklearn.model_selection import KFold
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):
print("训练集索引:", train_index, "测试集索引:", test_index)
适用于
- 数据量较少
- 提高模型的泛化能力
7. StratifiedKFold(分层 K 折交叉验证)
用于类别不均衡数据,确保每折中类别比例相同。
代码示例
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X, y):
print("训练集索引:", train_index, "测试集索引:", test_index)
8. LeaveOneOut(留一交叉验证,LOO)
每次用 N-1 个样本训练,1 个样本测试。
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
for train_index, test_index in loo.split(X):
print("训练集索引:", train_index, "测试集索引:", test_index)
适用于
- 小数据集
- 极端精确评估
9. learning_curve(学习曲线)
用于判断模型是否过拟合或欠拟合。
from sklearn.model_selection import learning_curve
import matplotlib.pyplot as plt
train_sizes, train_scores, test_scores = learning_curve(SVC(), X, y, cv=5)
plt.plot(train_sizes, train_scores.mean(axis=1), label="训练集")
plt.plot(train_sizes, test_scores.mean(axis=1), label="测试集")
plt.legend()
plt.show()
10. validation_curve(验证曲线)
观察模型在不同超参数下的表现。
from sklearn.model_selection import validation_curve
param_range = [0.1, 1, 10]
train_scores, test_scores = validation_curve(SVC(), X, y, param_name="C", param_range=param_range, cv=5)
plt.plot(param_range, train_scores.mean(axis=1), label="训练集")
plt.plot(param_range, test_scores.mean(axis=1), label="测试集")
plt.legend()
plt.show()
11. 结论
sklearn.model_selection提供了数据划分、超参数优化、交叉验证等工具,是 提高模型性能的关键。- 适用于 分类、回归、聚类 任务,可结合
GridSearchCV、KFold、learning_curve等方法 优化模型。
更多推荐

所有评论(0)