sklearn.model_selection 模块

sklearn.model_selectionscikit-learn 提供的 模型选择模块,用于 数据划分、超参数调优和交叉验证,帮助优化机器学习模型。


1. sklearn.model_selection 提供的主要功能

方法 作用 适用场景
train_test_split 数据集拆分(训练集/测试集) 训练和测试模型
cross_val_score 交叉验证评分 评估模型性能
StratifiedKFold 分层 K 折交叉验证 类别数据不均衡
GridSearchCV 网格搜索超参数优化 选择最佳超参数
RandomizedSearchCV 随机搜索超参数优化 适用于大参数空间
KFold K 折交叉验证 评估模型稳定性
LeaveOneOut 留一交叉验证(LOO) 小数据集
ShuffleSplit 随机划分交叉验证 适用于大数据集
learning_curve 学习曲线 观察模型是否过拟合
validation_curve 验证曲线 选择合适的模型复杂度

2. train_test_split(数据集拆分)

用于 将数据集拆分为训练集和测试集,常用于 模型训练和评估

代码示例

from sklearn.model_selection import train_test_split
import numpy as np

# 示例数据
X = np.arange(10).reshape((5, 2))
y = np.array([0, 1, 0, 1, 0])

# 拆分数据(80% 训练集,20% 测试集)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print("训练集:\n", X_train)
print("测试集:\n", X_test)

输出

训练集:
[[6 7]
 [2 3]
 [0 1]
 [4 5]]
测试集:
[[8 9]]

主要参数

参数 说明
test_size 测试集比例(0.2 代表 20%
random_state 设置随机种子,确保结果可复现
stratify y 分层抽样(适用于类别不均衡)

3. cross_val_score(交叉验证评分)

用于 评估模型的泛化能力

代码示例

from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 训练随机森林,并进行 5 折交叉验证
model = RandomForestClassifier()
scores = cross_val_score(model, X, y, cv=5)

print("交叉验证得分:", scores)
print("平均得分:", scores.mean())

输出

交叉验证得分: [0.96 0.98 0.94 0.96 0.96]
平均得分: 0.96

主要参数

参数 说明
cv 交叉验证的折数(默认为 5
scoring 评估指标(如 accuracy, f1, roc_auc

4. GridSearchCV(网格搜索超参数优化)

用于 遍历所有可能的超参数组合,找到最优参数

代码示例

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# 设定参数网格
param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}

# 进行网格搜索
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)
print("最佳得分:", grid_search.best_score_)

输出

最佳参数: {'C': 1, 'kernel': 'linear'}
最佳得分: 0.98

主要参数

参数 说明
param_grid 需要搜索的超参数
cv 交叉验证折数
scoring 评估指标

5. RandomizedSearchCV(随机搜索超参数优化)

在大参数空间中更高效,随机选择部分参数进行搜索。

代码示例

from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform

# 设定参数分布
param_dist = {'C': uniform(0.1, 10), 'kernel': ['linear', 'rbf']}

# 进行随机搜索
random_search = RandomizedSearchCV(SVC(), param_dist, n_iter=5, cv=5, random_state=42)
random_search.fit(X, y)

print("最佳参数:", random_search.best_params_)

适用于

  • 大规模超参数搜索
  • 参数范围较大

6. KFold(K 折交叉验证)

用于 将数据集划分为 K 份,每次用 K-1 份训练,1 份测试

代码示例

from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)

for train_index, test_index in kf.split(X):
    print("训练集索引:", train_index, "测试集索引:", test_index)

适用于

  • 数据量较少
  • 提高模型的泛化能力

7. StratifiedKFold(分层 K 折交叉验证)

用于类别不均衡数据,确保每折中类别比例相同。

代码示例

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for train_index, test_index in skf.split(X, y):
    print("训练集索引:", train_index, "测试集索引:", test_index)

8. LeaveOneOut(留一交叉验证,LOO)

每次用 N-1 个样本训练,1 个样本测试

from sklearn.model_selection import LeaveOneOut

loo = LeaveOneOut()
for train_index, test_index in loo.split(X):
    print("训练集索引:", train_index, "测试集索引:", test_index)

适用于

  • 小数据集
  • 极端精确评估

9. learning_curve(学习曲线)

用于判断模型是否过拟合或欠拟合

from sklearn.model_selection import learning_curve
import matplotlib.pyplot as plt

train_sizes, train_scores, test_scores = learning_curve(SVC(), X, y, cv=5)

plt.plot(train_sizes, train_scores.mean(axis=1), label="训练集")
plt.plot(train_sizes, test_scores.mean(axis=1), label="测试集")
plt.legend()
plt.show()

10. validation_curve(验证曲线)

观察模型在不同超参数下的表现

from sklearn.model_selection import validation_curve

param_range = [0.1, 1, 10]
train_scores, test_scores = validation_curve(SVC(), X, y, param_name="C", param_range=param_range, cv=5)

plt.plot(param_range, train_scores.mean(axis=1), label="训练集")
plt.plot(param_range, test_scores.mean(axis=1), label="测试集")
plt.legend()
plt.show()

11. 结论

  • sklearn.model_selection 提供了数据划分、超参数优化、交叉验证等工具,是 提高模型性能的关键
  • 适用于 分类、回归、聚类 任务,可结合 GridSearchCVKFoldlearning_curve 等方法 优化模型
Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐