机器学习之决策树剪枝处理
后剪枝的操作顺序是 从叶节点→内部节点→根节点,逐层向上检查每个子树是否需要剪枝,具体步骤如下:确定剪枝候选对象,遍历完整决策树的所有非叶节点(即内部节点),把每个内部节点及其下属的所有分支看作一个 “待剪枝子树”。1)如果 操作 B 的误差 ≤ 操作 A 的误差:说明剪掉该分支后,模型在验证集上的性能没有下降,甚至更好,此时执行剪枝(保留操作 B 的叶节点)。将该子树的根节点(内部节点)直接替换
一、为什么决策树要进行剪枝处理?
决策树剪枝的核心目的,是解决模型过拟合问题,提升其泛化能力—— 让模型不仅在训练数据上表现好,还能在未见过的新数据上稳定预测。过拟合会导致模型泛化能力差完全生长的决策树在训练集上准确率极高,但面对新数据时误差会急剧升高。这是因为模型学到的不是数据的通用规律,而是训练数据的专属特性,这种现象就是过拟合。
剪枝的本质就是去掉这些对泛化没有帮助的冗余分支,让模型更简洁、更鲁棒。
二、剪枝处理的核心思路
- 预剪枝:在树的构建过程中就提前停止生长,比如限制树的最大深度、设置节点的最小样本数。优点是计算效率高,缺点是可能因 “剪得太早” 导致欠拟合。
- 后剪枝:先让树完全生长,再从下往上剪掉对模型性能贡献小的分支。优点是效果更优,不容易欠拟合,缺点是计算成本更高。
三、剪枝处理的代码示例
1.预剪枝
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建预剪枝决策树
pre_pruned_tree = DecisionTreeClassifier(
max_depth=3, # 最大深度
min_samples_split=10, # 最小分裂样本数
min_samples_leaf=5, # 最小叶子节点样本数
min_impurity_decrease=0.01, # 最小不纯度减少量
random_state=42
)
# 训练模型
pre_pruned_tree.fit(X_train, y_train)
# 评估性能
train_score = pre_pruned_tree.score(X_train, y_train)
test_score = pre_pruned_tree.score(X_test, y_test)
print(f"预剪枝决策树 - 训练集准确率: {train_score:.3f}")
print(f"预剪枝决策树 - 测试集准确率: {test_score:.3f}")
# 可视化决策树
plt.figure(figsize=(12, 8))
plot_tree(pre_pruned_tree,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True,
rounded=True)
plt.title("预剪枝决策树")
plt.show()
2.后剪枝
以分类决策树为例(CART 算法的后剪枝流程最具代表性):
1. 前提准备:生成 “完全生长树”
用训练集训练一棵未剪枝的完整决策树,让树生长到极致:
1)每个叶节点的样本都属于同一类别;
2)或没有剩余特征可用于划分;
3)或节点样本数小于设定的最小阈值。
准备一个验证集(用于评估剪枝后的模型性能,不能用训练集,否则会误判过拟合分支的价值)。
2.后剪枝的核心步骤:自底向上回溯剪枝
后剪枝的操作顺序是 从叶节点→内部节点→根节点,逐层向上检查每个子树是否需要剪枝,具体步骤如下:确定剪枝候选对象,遍历完整决策树的所有非叶节点(即内部节点),把每个内部节点及其下属的所有分支看作一个 “待剪枝子树”。
对每个候选子树评估剪枝收益
对每个待剪枝子树,执行两个操作并对比性能:
1)操作 A:保留原分支
用验证集计算该子树对应的预测误差(如分类任务的误分类率:验证集中被分错的样本数 / 总样本数)。
2)操作 B:剪掉分支(子树替换为叶节点)
将该子树的根节点(内部节点)直接替换为叶节点,叶节点的类别由子树内所有样本的多数类决定(或回归任务的均值)。
同样用验证集计算替换后的预测误差。
判断是否剪枝:
1)如果 操作 B 的误差 ≤ 操作 A 的误差:说明剪掉该分支后,模型在验证集上的性能没有下降,甚至更好,此时执行剪枝(保留操作 B 的叶节点)。
2)如果 操作 B 的误差 > 操作 A 的误差:说明该分支对模型泛化有帮助,保留原分支。
逐层向上迭代剪枝
对所有候选子树完成一轮评估和剪枝后,生成一棵新的简化树。
然后对新树重复步骤 1-3,继续向上检查更上层的子树,直到没有任何分支可以被剪枝(即所有分支剪枝后误差都会上升)为止。
3. 最终确定最优剪枝树
经过多轮迭代剪枝后,会得到一系列不同复杂度的剪枝树(每剪一次得到一棵)。
选择验证集误差最小的那棵树作为最终模型;若有多棵树误差相同,优先选择结构最简单(节点数最少、深度最浅) 的树。
四、实验总结
预剪枝是在树的构建过程中提前停止生长(如限制树深度、设置节点最小样本数),优点是计算效率高、实现简单,缺点是容易因 “剪得太早” 导致欠拟合。
它更适合以下情况:
1.中小规模数据集 + 对计算效率要求高的场景
2.数据噪声少、特征规律明显的问题
3.需要快速迭代调参的场景
后剪枝是先生成完整树,再自底向上剪掉冗余分支,优点是泛化能力更强、不容易欠拟合,缺点是计算成本高(需要生成完整树 + 多轮验证评估)。
它更适合以下情况:
1.大规模数据集 + 对模型精度要求高的场景
2.数据噪声多、容易过拟合的问题
3.对模型稳定性要求高的场景
更多推荐


所有评论(0)