决策树全解析:从基础到实战(Python代码+案例演示)
用Python实现机器学习决策树算法
·
决策树原理及公式推导
一、决策树基本原理
决策树是一种基于树形结构的监督学习算法,通过递归划分特征空间实现分类或回归任务。其核心机制如下:
-
树形结构定义
- 内部节点(Non-leaf Node):表示对某一特征的判定规则(例如“特征Xj≤θXj≤θ”)。
- 分支(Branch):对应特征判定的输出结果(例如“真”或“假”)。
- 叶节点(Leaf Node):输出最终预测值(分类任务的类别标签或回归任务的连续值)。
-
划分目标
每次划分选择使子节点数据“纯度”最大化的特征及阈值。纯度的量化指标包括信息熵、基尼不纯度或均方误差(回归任务)。
二、数学原理与公式推导
-
信息熵(Information Entropy)
Ent(D)=−∑k=1Kpklog2pkEnt(D)=−k=1∑Kpklog2pk
信息熵度量数据集DD中类别分布的不确定性。- pkpk:数据集中第kk类样本的比例。
- 当所有样本属于同一类时,Ent(D)=0Ent(D)=0;类别分布越均匀,熵值越大。
-
信息增益(Information Gain)
Gain(D,X)=Ent(D)−∑i=1V∣Di∣∣D∣⋅Ent(Di)Gain(D,X)=Ent(D)−i=1∑V∣D∣∣Di∣⋅Ent(Di)
信息增益衡量使用特征XX划分数据集后纯度的提升量,用于ID3算法。- VV:特征XX的可能取值数。
- ∣Di∣∣Di∣:特征XX取第ii个值时对应的子集样本数。
- 缺点:倾向选择取值较多的特征(需通过信息增益率修正)。
-
基尼不纯度(Gini Impurity)
Gini(D)=1−∑k=1Kpk2Gini(D)=1−k=1∑Kpk2
基尼不纯度衡量数据集DD中随机抽样样本类别不一致的概率,用于CART算法。- 基尼不纯度越小,数据纯度越高。
- 与信息熵相比,计算效率更高(无需对数运算)。
-
CART回归树的划分准则
MSE(D)=1∣D∣∑i∈D(yi−yˉ)2MSE(D)=∣D∣1i∈D∑(yi−yˉ)2
回归任务中,使用均方误差(MSE)或平均绝对误差(MAE)作为划分指标。- yˉyˉ:数据集DD中样本输出的均值。
三、算法流程与关键技术
-
特征选择策略
- ID3算法:最大化信息增益,仅支持分类任务。
- C4.5算法:引入信息增益率(Gain_Ratio=Gain(D,X)IV(X)Gain_Ratio=IV(X)Gain(D,X)),其中IV(X)IV(X)为特征XX的固有值(Intrinsic Value),缓解对多值特征的偏好。
- CART算法:
- 分类任务:最小化基尼不纯度。
- 回归任务:最小化均方误差。
- 支持二叉树结构,递归生成左右子树。
-
递归终止条件
- 节点中样本数小于预设阈值。
- 节点纯度达到要求(如基尼不纯度低于阈值)。
- 达到树的最大深度限制。
-
剪枝优化(Pruning)
- 预剪枝(Pre-pruning):在树生成过程中限制分裂(如限制深度、最小样本数)。
- 后剪枝(Post-pruning):生成完整树后,通过验证集删除对泛化性能无益的分支(如代价复杂度剪枝)。
四、算法对比与分析
| 算法 | 划分准则 | 任务类型 | 树结构 | 优缺点 |
|---|---|---|---|---|
| ID3 | 信息增益 | 分类 | 多叉树 |
易过拟合, 不支持连续特征 |
| C4.5 | 信息增益率 | 分类 | 多叉树 |
解决ID3的多值特征偏好, 计算复杂度高 |
| CART | 基尼不纯度 / MSE | 分类 + 回归 | 二叉树 |
支持回归任务, 可能生成深树需剪枝 |
五、示例推导(ID3算法)
数据集:14个样本,特征“天气”分为晴(5例,3正类)、阴(4例,4正类)、雨(5例,2正类)。
-
初始信息熵:
Ent(D)=−914log2914−514log2514≈0.940Ent(D)=−149log2149−145log2145≈0.940 -
按特征“天气”划分后的加权熵:
Ent加权=514×0.971+414×0+514×0.971≈0.693Ent加权=145×0.971+144×0+145×0.971≈0.693
六、Python代码实现
用Sklearn快速实现决策树分类:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
# 加载鸢尾花数据集
data = load_iris()
X, y = data.data, data.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建决策树模型(使用基尼不纯度)
clf = DecisionTreeClassifier(criterion='gini', max_depth=3)
clf.fit(X_train, y_train)
# 评估模型
print("测试集准确率:", clf.score(X_test, y_test))
项目实战
一、分类任务:鸢尾花分类(Iris Dataset)
1. 数据集说明
- 目标:根据花瓣和萼片的长度/宽度,预测鸢尾花种类(Setosa, Versicolor, Virginica)。
- 特征:4个数值型特征(萼片长、萼片宽、花瓣长、花瓣宽)。
- 标签:3个类别(0, 1, 2)
2.代码实现
# 导入库
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# 加载数据集
data = load_iris()
X, y = data.data, data.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建决策树模型(使用基尼不纯度,限制最大深度为3)
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
clf.fit(X_train, y_train)
# 预测并评估
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"测试集准确率: {accuracy:.4f}") # 输出: 测试集准确率: 0.9556
# 可视化决策树
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=data.feature_names, class_names=data.target_names, filled=True)
plt.title("鸢尾花分类决策树")
plt.show()
# 输出特征重要性
print("\n特征重要性排序:")
for name, importance in zip(data.feature_names, clf.feature_importances_):
print(f"{name}: {importance:.4f}")
"""
特征重要性排序:
sepal length (cm): 0.0000
sepal width (cm): 0.0000
petal length (cm): 0.0130
petal width (cm): 0.9870
"""
3.输出结果
- 决策树模型在测试集上准确率为 95.56%。
- 特征重要性显示花瓣宽度(petal width)是分类的关键特征。
二、回归任务:糖尿病预测(Diabetes Dataset)
1.数据集说明
- 目标:根据患者的生理指标(如BMI、血压等),预测糖尿病进展指标(数值型)。
- 特征:10个数值型特征。
- 标签:连续值(0~300)。
2.代码实现
# 导入库
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
# 加载数据集
data = load_diabetes()
X, y = data.data, data.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建决策树回归模型
reg = DecisionTreeRegressor(max_depth=3, random_state=42)
reg.fit(X_train, y_train)
# 预测并评估
y_pred = reg.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"均方误差 (MSE): {mse:.2f}") # 输出: 均方误差 (MSE): 3793.11
print(f"R²分数: {r2:.4f}") # 输出: R²分数: 0.2720
# 输出特征重要性
print("\n特征重要性排序:")
for name, importance in zip(data.feature_names, reg.feature_importances_):
print(f"{name}: {importance:.4f}")
"""
特征重要性排序:
age: 0.0165
sex: 0.0000
bmi: 0.5877
bp: 0.1911
s1: 0.0000
s2: 0.0000
s3: 0.0000
s4: 0.0000
s5: 0.2057
s6: 0.0000
"""
3. 输出结果
- 均方误差(MSE)为 3793.11,R²分数为 0.2720(值越接近1,模型越好)。
- 关键特征为 BMI(bmi)和 血压(bp)。
三、关键代码解释
-
模型参数
criterion='gini':分类任务使用基尼不纯度。max_depth=3:限制树的最大深度,防止过拟合。random_state=42:固定随机种子,确保结果可复现。
-
可视化依赖
plot_tree函数依赖graphviz库,安装命令:-
pip install graphviz
-
-
特征重要性
feature_importances_属性表示每个特征对模型预测的贡献度,总和为1。
四、实战应用场景
- 分类任务:用户流失预测、垃圾邮件识别、疾病诊断。
- 回归任务:房价预测、销量预测、股票价格趋势分析。
通过调整 max_depth、min_samples_split 等参数,优化模型性能。
更多推荐



所有评论(0)