sklearn决策树

来源:互联网 发布:刷游戏的软件 编辑:程序博客网 时间:2024/05/21 12:46

sklearn 决策树

sklearn原文

目录

  • sklearn 决策树
      • 目录
      • 决策树优点
      • 决策树缺点
  • 分类问题
      • DecisionTreeRegression 使用方式类似
    • 复杂度
    • 实际使用时的小技巧
    • 决策树算法ID3 C45C50 and CART
    • 数学公式
      • 分类问题
        • Gini函数
        • 交叉熵
        • 错分类
      • 回归问题

决策树是一种非参数形式的有监督学习方法。常用于分类和回归任务,基本方法是从训练集中学习决策规则用于新样本的预测。

决策树优点

  • 结构简单易懂,树状结构便于可视化
  • 需要少量的数据准备,与其他的模型不同,不需要标准化。但是需要注意的是这个模型不能直接处理缺失值
  • The cost of using the tree is logarithmic in the number of data points used to train the tree (O(logN)的复杂度)
  • 可以处理数值型和目录型的数据。其他的模型通常只能处理单一类型的数据
  • 可以处理多输出问题
  • 白盒子 模型,解释性强
  • 可能可以使用统计检验方法验证模型,可以考虑可靠性

决策树缺点

  • 模型可能会产出多与复杂的树型结构,导致泛华能力不强。即所谓过拟合,一般使用剪枝策略或者设置最大深度来解决这种问题
  • 树的结构可能不稳定,特征上微小的差别可能导致结果不稳定,可以使用集成(Ensemble)的方式解决
  • 找到完全合适的树型结构是NP完全问题,贪心策略可能导致找到局部最优解而不是全局最优解,可以考虑使用随机森林的方式解决
  • 有一些问题无法用数型结构合适的描述出来XOR

分类问题

函数DecisonTreeClassifier 可用于多分类问题

参数为[n_samples,n_features] (可缺省)

>from sklearn import tree>X = [[0,0],[1,1]]>Y = [0,1]>clf = tree.DecisionTreeClassifier()>clf.fit(X,Y)

你和结束后,送入新的样本进行预测。

>clf.predict([[2.,2.]])array([1])#predict with probility>clf.predict_proba([[2.,2.]])array([[0.,1.]])

此函数在二分类和多分类情形均可使用。

Using the Iris dataset, we can construct a tree

>from sklearn.datasets import load_iris>from sklearn import tree>iris = load_iris()>clf = tree.DecisionTreeClassifier()>clf = clf.fit(iris.data, iris.target)

Once trained, we can export the tree in Graphviz format using the export_graphviz exporter.

>import graphviz>dot_data = tree.export_graphviz(clf, out_file =None, feature_names = iris.feature_names, class_names = iris.target_names,filled_True, rounded=True)>graph = graphviz.Source(dot_data)>graph

这里写图片描述

DecisionTreeRegression 使用方式类似


复杂度

通常,构建一个平衡的二叉树的花费是O(nsamplenfeatureslog(nsamples)) 查询时间为O(log(nsamples)).

虽然构建树时,我们总想构建平衡二叉树,但是事实上算法决定树的结构不总是平衡的。

而sklearn中的方法,根据交叉熵信息增益的方式可以让每个节点的复杂度达到O(nfeatureslog(nsamples)),从而达到理想的复杂度。


实际使用时的小技巧

  • 特征数量过多时,模型容易过拟合,需要充足的数据进行训练
  • 在直接使用决策树模型之前使用降维方法(PCA, ICA or Feature selection)更有可能找到一个好的特征
  • 可以使用export函数对树的结构进行可视化,可以先设置max_depth=3来观察树型结构是否合适与对象数据集,之后再增加深度
  • 树越深需要的数据越多,所以最好设置max_depth属性防止过拟合
  • 使用min_samples_leaf=n属性设置,必须有n个实例的叶子节点才会被保留,防止过拟合,一般初始化为min_samples_leaf =5
  • 对于带权重的样本上一条中的个数 概念可以推广为min_weight_fraction_leaf 即必须达到一定的权重和
  • 输入的数据会被拷贝为np.float32格式
  • 对于稀疏矩阵,推荐先使用csc_matrix before fitting and csr_matrix before predicting

决策树算法:ID3, C4.5,C5.0 and CART

In ENG

ID3 (Iterative Dichotomiser 3) was developed in 1986 by Ross Quinlan. The algorithm creates a multiway tree, finding for each node the categorical feature that will yield the largest information gain for categorical targets. Trees are grown ti their maximum size and then a pruning step is usually applied to improve the ability of the tree to generalize to unseen data.
(key: Information gain, entropy)

C4.5  is the successor to ID3 and removed the restriction that features must be categorical by dynamically defining a discrete attribute (based on numerical variables) that partitions the continuous attribute value into a discrete set of intervals. C4.5 converts the trained trees (i.e. the output of the ID3 algorithm) into sets of if-then rules. These accuracy of each rule is then evaluated to determine the order in which they should be applied. Pruning is done by removing a rule’s precondition if the accuracy of the rule improves without it.

C5.0 is Quinlan’s latest version release under a proprietary license. It uses less memory and builds smaller rulesets than C4.5 while being more accurate.

CART (Classification and Regression Trees) is very similar to C4.5, but it differs in that it supports numerical target variables (regression) and does not compute rule sets. CART constructs binary trees using the feature and threshold that yield the largest information gain at each node.

scikit-learn uses an optimised version of the CART algorithm.


数学公式

对于给定的训练向量xiRn,i=1,...,l  和标签向量yRl 一个决策树递归的将空间划分来让有相同标签的样本被分到一起。

我们把m节点的数据表示为Q。 对于每一个候选样本,一个分割方法θ=(j,tm) 由属性j和阈值tm组成,将样本分到Qleft(θ) 和 Qright(θ)子集:

Qleft(θ)=(x,y)|xj<=tm

Qright(θ)=QQright

在节点m处的不纯度(impurity) 用 impurity function H()来计算,其具体的形式由具体任务决定。

G(Q,θ)=nleftNmH(Qleft(θ))+nrightNmH(Qright(θ))

选择使得不纯度最小的分割方式

θ=argminθG(Q,θ)

重复一直到达到最大深度。

分类问题 :

如果是0,1…,K-1的多分类问题,对于结点m,让Rm代表有Nm个取值的集合

pmk=1/NmxiRmI(yi=k)

是在结点m被分到K类里的比例
以下是通常的impurity 度量

Gini函数

H(Xm)=kpmk(1pmk)

交叉熵

H(X_m) = - \sum_k p_{mk} \log(p_{mk})

错分类

H(Xm)=1max(pmk)

回归问题:

If the target is a continuous value, then for node m, representing a region R_m with N_m observations, common criteria to minimise as for determining locations for future splits are Mean Squared Error, which minimizes the L2 error using mean values at terminal nodes, and Mean Absolute Error, which minimizes the L1 error using median values at terminal nodes.

Mean Squared Error:

cm=1NmiNmyi
H(Xm)=1NmiNm(yicm)2

Mean Absolute Error:

ym¯=1NmiNmyi
H(Xm)=1NmiNm|yiym¯|

where X_m is the training data in node m

转载翻译自:http://scikit-learn.org/stable/modules/tree.html