决策树(Decision Tree)
大约 3 分钟
决策树(Decision Tree)
概念
- 从根节点开始一步步走到叶子节点(决策)
- 所有的数据最终都会落到叶子节点,既可以做分类也可以做回归
- 树的组成
- 根节点:第一个选择点
- 非叶子节点与分支:中间过程
- 叶子节点:最终的决策结果
特征选择
- 衡量标准-熵
$$ P(X=x_i) = p_i, i=1,2, ... , n\ H(X) = - ∑ (p_i * lgp_i), i=1,2, ... , n $$
$p_i$为特征X的概率
该函数在0-1上先增后减
熵越大,数据越不纯
举例
- A集合[1,1,1,1,1,1,1,1,2,2] B集合[1,2,3,4,5,6,7,8,9,1]
- 显然A集合的熵值要低,因为A里面只有两种类别,相对稳定一些而B中类别太多了,熵值就会大很多
根节点选择标准
- 信息增益=原熵值-特征划分后熵值
- 增益越大,分类效果越好
算法
- ID3:其核心是在决策树的各级节点上,使用信息增益方法作为属性的选择标准,来帮助确定生成每个节点时所采用的合适属性
- C4.5:相对于ID3算法的重要改进是使用信息增益率来选择节点属性。C4.5算法可以克服ID3算法存在的不足:ID3算法只适用于离散的描述属性,而C4.5算法既能够处理离散的描述属性,也可以处理连续的描述属性
- CART:基尼指数,使用GINI系数来当做衡量标准,是一种十分有效的非参数分类和回归方法,通过构建树、修剪树、评估树来构建一个二叉树。当终结点是连续变量时为回归树;当终结点是分类变量时为分类树
- GINI系数:$ GINI(X) = = 1 - ∑ (p_i^2) $
连续值的处理
除了可以直接分类的特征,还有一些特征值是连续的,如:年龄、身高、体重等,这种情况下,属性的可取值无穷多,就无法直接划分节点,需要先将连续值离散化。
- 离散化:最简单的二分法策略对连续值进行处理,C4.5算法就是使用的这个方法
缺失值的处理
剪枝
决策树过拟合风险很大,理论上可以完全分得开数据,只要树足够庞大,叶子节点就会有一条数据
- 预剪枝:边建立决策树边进行剪枝的操作(更实用)
- 限制深度,叶子节点个数,叶子节点样本数,信息增益量等
- 后剪枝:当建立完决策树后来进行剪枝操作
- 通过一定的衡量标准 $ C_α(T) = C(T) + α * |T_{leaf}| $, 叶子节点越多,损失越大
sklearn 中使用
使用泰坦尼克号的数据举例说明
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import pandas as pd
def descision():
"""
决策树对泰坦尼克号进行预测生死
:return:None
"""
titan = pd.read_csv("http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt")
# 处理数据,找出特征值和目标值
x = titan[['pclass', 'age', 'sex']]
y = titan['survived']
print(x)
# 缺失值处理
x['age'].fillna(x['age'].mean(), inplace=True)
# 分割数据集到训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25)
# 进行处理(特征工程)
dict = DictVectorizer(sparse=False)
x_train = dict.fit_transform(x_train.to_dict(orient="records"))
print(dict.get_feature_names())
x_test = dict.transform(x_test.to_dict(orient="records"))
print(x_train)
# 用决策树进行预测
dec = DecisionTreeClassifier()
dec.fit(x_train, y_train)
# 预测准确率
print("预测的准确率为:", dec.score(x_test, y_test))
# 导出决策树的结构
export_graphviz(dec, out_file="./tree.dot", feature_names=['age', 'pclass=1st', 'pclass=2nd', 'pclass=3rd', 'sex=female', 'sex=male'])
if __name__=="__main__":
descision()
