决策树原理及公式推导


一、决策树基本原理

决策树是一种基于树形结构的监督学习算法,通过递归划分特征空间实现分类或回归任务。其核心机制如下:

  1. 树形结构定义

    • 内部节点(Non-leaf Node)‌:表示对某一特征的判定规则(例如“特征Xj≤θXj​≤θ”)。
    • 分支(Branch)‌:对应特征判定的输出结果(例如“真”或“假”)。
    • 叶节点(Leaf Node)‌:输出最终预测值(分类任务的类别标签或回归任务的连续值)。
  2. 划分目标
    每次划分选择使子节点数据“纯度”最大化的特征及阈值。纯度的量化指标包括信息熵、基尼不纯度或均方误差(回归任务)。


二、数学原理与公式推导
  1. 信息熵(Information Entropy)
    信息熵度量数据集DD中类别分布的不确定性。

    Ent(D)=−∑k=1Kpklog⁡2pkEnt(D)=−k=1∑K​pk​log2​pk​
    • pkpk​:数据集中第kk类样本的比例。
    • 当所有样本属于同一类时,Ent(D)=0Ent(D)=0;类别分布越均匀,熵值越大。
  2. 信息增益(Information Gain)
    信息增益衡量使用特征XX划分数据集后纯度的提升量,用于ID3算法。

    Gain(D,X)=Ent(D)−∑i=1V∣Di∣∣D∣⋅Ent(Di)Gain(D,X)=Ent(D)−i=1∑V​∣D∣∣Di​∣​⋅Ent(Di​)
    • VV:特征XX的可能取值数。
    • ∣Di∣∣Di​∣:特征XX取第ii个值时对应的子集样本数。
    • 缺点:倾向选择取值较多的特征(需通过信息增益率修正)。
  3. 基尼不纯度(Gini Impurity)
    基尼不纯度衡量数据集DD中随机抽样样本类别不一致的概率,用于CART算法。

    Gini(D)=1−∑k=1Kpk2Gini(D)=1−k=1∑K​pk2​
    • 基尼不纯度越小,数据纯度越高。
    • 与信息熵相比,计算效率更高(无需对数运算)。
  4. CART回归树的划分准则
    回归任务中,使用均方误差(MSE)或平均绝对误差(MAE)作为划分指标。

    MSE(D)=1∣D∣∑i∈D(yi−yˉ)2MSE(D)=∣D∣1​i∈D∑​(yi​−yˉ​)2
    • yˉyˉ​:数据集DD中样本输出的均值。

三、算法流程与关键技术
  1. 特征选择策略

    • ID3算法‌:最大化信息增益,仅支持分类任务。
    • C4.5算法‌:引入信息增益率(Gain_Ratio=Gain(D,X)IV(X)Gain_Ratio=IV(X)Gain(D,X)​),其中IV(X)IV(X)为特征XX的固有值(Intrinsic Value),缓解对多值特征的偏好。
    • CART算法‌:
      • 分类任务:最小化基尼不纯度。
      • 回归任务:最小化均方误差。
      • 支持二叉树结构,递归生成左右子树。
  2. 递归终止条件

    • 节点中样本数小于预设阈值。
    • 节点纯度达到要求(如基尼不纯度低于阈值)。
    • 达到树的最大深度限制。
  3. 剪枝优化(Pruning)

    • 预剪枝(Pre-pruning)‌:在树生成过程中限制分裂(如限制深度、最小样本数)。
    • 后剪枝(Post-pruning)‌:生成完整树后,通过验证集删除对泛化性能无益的分支(如代价复杂度剪枝)。

四、算法对比与分析
算法 划分准则 任务类型 树结构 优缺点
ID3 信息增益 分类 多叉树

易过拟合,

不支持连续特征

C4.5 信息增益率 分类 多叉树

解决ID3的多值特征偏好,

计算复杂度高

CART 基尼不纯度 / MSE 分类 + 回归 二叉树

支持回归任务,

可能生成深树需剪枝


五、示例推导(ID3算法)

数据集‌:14个样本,特征“天气”分为晴(5例,3正类)、阴(4例,4正类)、雨(5例,2正类)。

  1. 初始信息熵‌:

    Ent(D)=−914log⁡2914−514log⁡2514≈0.940Ent(D)=−149​log2​149​−145​log2​145​≈0.940
  2. 按特征“天气”划分后的加权熵‌:

    Ent加权=514×0.971+414×0+514×0.971≈0.693Ent加权​=145​×0.971+144​×0+145​×0.971≈0.693

六、Python代码实现

用Sklearn快速实现决策树分类:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

# 加载鸢尾花数据集
data = load_iris()
X, y = data.data, data.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树模型(使用基尼不纯度)
clf = DecisionTreeClassifier(criterion='gini', max_depth=3)
clf.fit(X_train, y_train)

# 评估模型
print("测试集准确率:", clf.score(X_test, y_test))

项目实战

一、分类任务:鸢尾花分类(Iris Dataset)

 1. 数据集说明

  • 目标‌:根据花瓣和萼片的长度/宽度,预测鸢尾花种类(Setosa, Versicolor, Virginica)。
  • 特征‌:4个数值型特征(萼片长、萼片宽、花瓣长、花瓣宽)。
  • 标签‌:3个类别(0, 1, 2)

2.代码实现

# 导入库
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# 加载数据集
data = load_iris()
X, y = data.data, data.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树模型(使用基尼不纯度,限制最大深度为3)
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# 预测并评估
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"测试集准确率: {accuracy:.4f}")  # 输出: 测试集准确率: 0.9556

# 可视化决策树
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=data.feature_names, class_names=data.target_names, filled=True)
plt.title("鸢尾花分类决策树")
plt.show()

# 输出特征重要性
print("\n特征重要性排序:")
for name, importance in zip(data.feature_names, clf.feature_importances_):
    print(f"{name}: {importance:.4f}")
"""
特征重要性排序:
sepal length (cm): 0.0000
sepal width (cm): 0.0000
petal length (cm): 0.0130
petal width (cm): 0.9870
"""

3.输出结果

  • 决策树模型在测试集上准确率为 ‌95.56%‌。
  • 特征重要性显示‌花瓣宽度(petal width)‌是分类的关键特征。

二、回归任务:糖尿病预测(Diabetes Dataset)

1.数据集说明

  • 目标‌:根据患者的生理指标(如BMI、血压等),预测糖尿病进展指标(数值型)。
  • 特征‌:10个数值型特征。
  • 标签‌:连续值(0~300)。

2.代码实现

# 导入库
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# 加载数据集
data = load_diabetes()
X, y = data.data, data.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树回归模型
reg = DecisionTreeRegressor(max_depth=3, random_state=42)
reg.fit(X_train, y_train)

# 预测并评估
y_pred = reg.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"均方误差 (MSE): {mse:.2f}")  # 输出: 均方误差 (MSE): 3793.11
print(f"R²分数: {r2:.4f}")          # 输出: R²分数: 0.2720

# 输出特征重要性
print("\n特征重要性排序:")
for name, importance in zip(data.feature_names, reg.feature_importances_):
    print(f"{name}: {importance:.4f}")
"""
特征重要性排序:
age: 0.0165
sex: 0.0000
bmi: 0.5877
bp: 0.1911
s1: 0.0000
s2: 0.0000
s3: 0.0000
s4: 0.0000
s5: 0.2057
s6: 0.0000
"""

3. 输出结果

  • 均方误差(MSE)为 ‌3793.11‌,R²分数为 ‌0.2720‌(值越接近1,模型越好)。
  • 关键特征为 ‌BMI(bmi)‌和 ‌血压(bp)‌。

三、关键代码解释

  1. 模型参数

    • criterion='gini':分类任务使用基尼不纯度。
    • max_depth=3:限制树的最大深度,防止过拟合。
    • random_state=42:固定随机种子,确保结果可复现。
  2. 可视化依赖

    plot_tree 函数依赖 graphviz 库,安装命令:
    • pip install graphviz
      
  3. 特征重要性

    feature_importances_ 属性表示每个特征对模型预测的贡献度,总和为1。

四、实战应用场景

  1. 分类任务‌:用户流失预测、垃圾邮件识别、疾病诊断。
  2. 回归任务‌:房价预测、销量预测、股票价格趋势分析。

通过调整 max_depthmin_samples_split 等参数,优化模型性能。

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐