决策树简介
决策树是一种特殊的树形结构,一般由节点和有向边组成。其中,节点表示特征、属性或者一个类。而有向边包含有判断条件。如图所示,决策树从根节点开始延伸,经过不同的判断条件后,到达不同的子节点。而上层子节点又可以作为父节点被进一步划分为下层子节点。一般情况下,我们从根节点输入数据,经过多次判断后,这些数据就会被分为不同的类别。这就构成了一颗简单的分类决策树。
特征选择
特征选择是建立决策树之前十分重要的一步。如果是随机地选择特征,那么所建立决策树的学习效率将会大打折扣。举例来讲,银行采用决策树来解决信用卡审批问题,判断是否向某人发放信用卡可以根据其年龄、工作单位、是否有不动产、历史信贷情况等特征决定。而选择不同的特征,后续生成的决策树就会不一致,这种不一致最终会影响到决策树的分类效率。
通常我们在选择特征时,会考虑到两种不同的指标,分别为:信息增益和信息增益比。要想弄清楚这两个概念,我们就不得不提到信息论中的另一个十分常见的名词 —— 熵。
熵(Entropy)是表示随机变量不确定性的度量。简单来讲,熵越大,随机变量的不确定性就越大。而特征 A 对于某一训练集 D 的信息增益 g(D, A) 定义为集合 D 的熵 H(D) 与特征 A 在给定条件下 D 的熵 H(D/A) 之差。
实现代码
from sklearn import datasets import matplotlib.pyplot as plt import numpy as np from sklearn import tree # Iris数据集是常用的分类实验数据集, # 由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集, # 是一类多重变量分析的数据集。数据集包含150个数据集, # 分为3类,每类50个数据,每个数据包含4个属性。 # 可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。 #载入数据集 iris = datasets.load_iris() iris_data=iris['data'] iris_label=iris['target'] iris_target_name=iris['target_names'] X=np.array(iris_data) Y=np.array(iris_label) #训练 clf=tree.DecisionTreeClassifier(max_depth=3) clf.fit(X,Y) #这里预测当前输入的值的所属分类 print('类别是',iris_target_name[clf.predict([[7,1,1,1]])[0]])