CART分类算法

来源:互联网 发布:k8陀螺仪软件下载 编辑:程序博客网 时间:2024/06/05 14:30

统计学习方法是一本好书,可惜例子太少。找到一个好的CART算法的例子

谢谢原文作者了

http://www.cnblogs.com/zhangchaoyang  作者:Orisun

分类回归树(CART,Classification And Regression Tree)也属于一种决策树,上回文我们介绍了基于ID3算法的决策树。作为上篇,这里只介绍CART是怎样用于分类的。

分类回归树是一棵二叉树,且每个非叶子节点都有两个孩子,所以对于第一棵子树其叶子节点数比非叶子节点数多1。

表1

名称体温表面覆盖胎生产蛋能飞水生有腿冬眠类标记人恒温毛发是否否否是否哺乳类巨蟒冷血鳞片否是否否否是爬行类鲑鱼冷血鳞片否是否是否否鱼类鲸恒温毛发是否否是否否哺乳类蛙冷血无否是否有时是是两栖类巨蜥冷血鳞片否是否否是否爬行类蝙蝠恒温毛发是否是否是否哺乳类猫恒温皮是否否否是否哺乳类豹纹鲨冷血鳞片是否否是否否鱼类海龟冷血鳞片否是否有时是否爬行类豪猪恒温刚毛是否否否是是哺乳类鳗冷血鳞片否是否是否否鱼类蝾螈冷血无否是否有时是是两栖类

上例是属性有8个,每个属性又有多少离散的值可取。在决策树的每一个节点上我们可以按任一个属性的任一个值进行划分。比如最开始我们按:

1)表面覆盖为毛发和非毛发

2)表面覆盖为鳞片和非鳞片

3)体温为恒温和非恒温

等等产生当前节点的左右两个孩子。按哪种划分最好呢?有3个标准可以用来衡量划分的好坏:GINI指数、双化指数、有序双化指数。下面我们只讲GINI指数。

GINI指数

总体内包含的类别越杂乱,GINI指数就越大(跟熵的概念很相似)。比如体温为恒温时包含哺乳类5个、鸟类2个,则:

体温为非恒温时包含爬行类3个、鱼类3个、两栖类2个,则

所以如果按照“体温为恒温和非恒温”进行划分的话,我们得到GINI的增益(类比信息增益):

最好的划分就是使得GINI_Gain最小的划分。

终止条件

一个节点产生左右孩子后,递归地对左右孩子进行划分即可产生分类回归树。这里的终止条件是什么?什么时候节点就可以停止分裂了?直观的情况,当节点包含的数据记录都属于同一个类别时就可以终止分裂了。这只是一个特例,更一般的情况我们计算χ2值来判断分类条件和类别的相关程度,当χ2很小时说明分类条件和类别是独立的,即按照该分类条件进行分类是没有道理的,此时节点停止分裂。注意这里的“分类条件”是指按照GINI_Gain最小原则得到的“分类条件”。

假如在构造分类回归树的第一步我们得到的“分类条件”是:体温为恒温和非恒温。此时:

 哺乳类爬行类鱼类鸟类两栖类恒温50020非恒温03302

我在《独立性检验》中讲述了χ2的计算方法。当选定置信水平后查表可得“体温”与动物类别是否相互独立。

还有一种方式就是,如果某一分支覆盖的样本的个数如果小于一个阈值,那么也可产生叶子节点,从而终止Tree-Growth。

剪枝

当分类回归树划分得太细时,会对噪声数据产生过拟合作用。因此我们要通过剪枝来解决。剪枝又分为前剪枝和后剪枝:前剪枝是指在构造树的过程中就知道哪些节点可以剪掉,于是干脆不对这些节点进行分裂,在N皇后问题和背包问题中用的都是前剪枝,上面的χ2方法也可以认为是一种前剪枝;后剪枝是指构造出完整的决策树之后再来考查哪些子树可以剪掉。

在分类回归树中可以使用的后剪枝方法有多种,比如:代价复杂性剪枝、最小误差剪枝、悲观误差剪枝等等。这里我们只介绍代价复杂性剪枝法。

对于分类回归树中的每一个非叶子节点计算它的表面误差率增益值α。

是子树中包含的叶子节点个数;

是节点t的误差代价,如果该节点被剪枝;

r(t)是节点t的误差率;

p(t)是节点t上的数据占所有数据的比例。

是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。

比如有个非叶子节点t4如图所示:

已知所有的数据总共有60条,则节点t4的节点误差代价为:

子树误差代价为:

以t4为根节点的子树上叶子节点有3个,最终:

找到α值最小的非叶子节点,令其左右孩子为NULL。当多个非叶子节点的α值同时达到最小时,取最大的进行剪枝。

[cpp] view plaincopy
  1. #include<iostream>  
  2. #include<fstream>  
  3. #include<sstream>  
  4. #include<string>  
  5. #include<map>  
  6. #include<list>  
  7. #include<set>  
  8. #include<queue>  
  9. #include<utility>  
  10. #include<vector>  
  11. #include<cmath>  
  12.    
  13. using namespace std;  
  14.    
  15. //置信水平取0.95时的卡方表  
  16. const double CHI[18]={0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962};  
  17. /*根据多维数组计算卡方值*/  
  18. template<typename Comparable>  
  19. double cal_chi(Comparable **arr,int row,int col){  
  20.     vector<Comparable> rowsum(row);  
  21.     vector<Comparable> colsum(col);  
  22.     Comparable totalsum=static_cast<Comparable>(0);  
  23.     //cout<<"observation"<<endl;  
  24.     for(int i=0;i<row;++i){  
  25.         for(int j=0;j<col;++j){  
  26.             //cout<<arr[i][j]<<"\t";  
  27.             totalsum+=arr[i][j];  
  28.             rowsum[i]+=arr[i][j];  
  29.             colsum[j]+=arr[i][j];  
  30.         }  
  31.         //cout<<endl;  
  32.     }  
  33.     double rect=0.0;  
  34.     //cout<<"exception"<<endl;  
  35.     for(int i=0;i<row;++i){  
  36.         for(int j=0;j<col;++j){  
  37.             double excep=1.0*rowsum[i]*colsum[j]/totalsum;  
  38.             //cout<<excep<<"\t";  
  39.             if(excep!=0)  
  40.                 rect+=pow(arr[i][j]-excep,2.0)/excep;  
  41.         }  
  42.         //cout<<endl;  
  43.     }  
  44.     return rect;  
  45. }  
  46.    
  47. class MyTriple{  
  48. public:  
  49.     double first;  
  50.     int second;  
  51.     int third;  
  52.     MyTriple(){  
  53.         first=0.0;  
  54.         second=0;  
  55.         third=0;  
  56.     }  
  57.     MyTriple(double f,int s,int t):first(f),second(s),third(t){}  
  58.     bool operator< (const MyTriple &obj) const{  
  59.         int cmp=this->first-obj.first;  
  60.         if(cmp>0)  
  61.             return false;  
  62.         else if(cmp<0)  
  63.             return true;  
  64.         else{  
  65.             cmp=obj.second-this->second;  
  66.             if(cmp<0)  
  67.                 return true;  
  68.             else  
  69.                 return false;  
  70.         }  
  71.     }  
  72. };  
  73.    
  74. typedef map<string,int> MAP_REST_COUNT;  
  75. typedef map<string,MAP_REST_COUNT> MAP_ATTR_REST;  
  76. typedef vector<MAP_ATTR_REST> VEC_STATI;  
  77.    
  78. const int ATTR_NUM=8;       //自变量的维度  
  79. vector<string> X(ATTR_NUM);  
  80. int rest_number;        //因变量的种类数,即类别数  
  81. vector<pair<string,int> > classes;      //把类别、对应的记录数存放在一个数组中  
  82. int total_record_number;        //总的记录数  
  83. vector<vector<string> > inputData;      //原始输入数据  
  84.    
  85. class node{  
  86. public:  
  87.     node* parent;       //父节点  
  88.     node* leftchild;        //左孩子节点  
  89.     node* rightchild;       //右孩子节点  
  90.     string cond;        //分枝条件  
  91.     string decision;        //在该节点上作出的类别判定  
  92.     double precision;       //判定的正确率  
  93.     int record_number;      //该节点上涵盖的记录个数  
  94.     int size;       //子树包含的叶子节点的数目  
  95.     int index;      //层次遍历树,给节点标上序号  
  96.     double alpha;   //表面误差率的增加量  
  97.     node(){  
  98.         parent=NULL;  
  99.         leftchild=NULL;  
  100.         rightchild=NULL;  
  101.         precision=0.0;  
  102.         record_number=0;  
  103.         size=1;  
  104.         index=0;  
  105.         alpha=1.0;  
  106.     }  
  107.     node(node* p){  
  108.         parent=p;  
  109.         leftchild=NULL;  
  110.         rightchild=NULL;  
  111.         precision=0.0;  
  112.         record_number=0;  
  113.         size=1;  
  114.         index=0;  
  115.         alpha=1.0;  
  116.     }  
  117.     node(node* p,string c,string d):cond(c),decision(d){  
  118.         parent=p;  
  119.         leftchild=NULL;  
  120.         rightchild=NULL;  
  121.         precision=0.0;  
  122.         record_number=0;  
  123.         size=1;  
  124.         index=0;  
  125.         alpha=1.0;  
  126.     }  
  127.     void printInfo(){  
  128.         cout<<"index:"<<index<<"\tdecisoin:"<<decision<<"\tprecision:"<<precision<<"\tcondition:"<<cond<<"\tsize:"<<size;  
  129.         if(parent!=NULL)  
  130.             cout<<"\tparent index:"<<parent->index;  
  131.         if(leftchild!=NULL)  
  132.             cout<<"\tleftchild:"<<leftchild->index<<"\trightchild:"<<rightchild->index;  
  133.         cout<<endl;  
  134.     }  
  135.     void printTree(){  
  136.         printInfo();  
  137.         if(leftchild!=NULL)  
  138.             leftchild->printTree();  
  139.         if(rightchild!=NULL)  
  140.             rightchild->printTree();  
  141.     }  
  142. };  
  143.    
  144. int readInput(string filename){  
  145.     ifstream ifs(filename.c_str());  
  146.     if(!ifs){  
  147.         cerr<<"open inputfile failed!"<<endl;  
  148.         return -1;  
  149.     }  
  150.     map<string,int> catg;  
  151.     string line;  
  152.     getline(ifs,line);  
  153.     string item;  
  154.     istringstream strstm(line);  
  155.     strstm>>item;  
  156.     for(int i=0;i<X.size();++i){  
  157.         strstm>>item;  
  158.         X[i]=item;  
  159.     }  
  160.     while(getline(ifs,line)){  
  161.         vector<string> conts(ATTR_NUM+2);  
  162.         istringstream strstm(line);  
  163.         //strstm.str(line);  
  164.         for(int i=0;i<conts.size();++i){  
  165.             strstm>>item;  
  166.             conts[i]=item;  
  167.             if(i==conts.size()-1)  
  168.                 catg[item]++;  
  169.         }  
  170.         inputData.push_back(conts);  
  171.     }  
  172.     total_record_number=inputData.size();  
  173.     ifs.close();  
  174.     map<string,int>::const_iterator itr=catg.begin();  
  175.     while(itr!=catg.end()){  
  176.         classes.push_back(make_pair(itr->first,itr->second));  
  177.         itr++;  
  178.     }  
  179.     rest_number=classes.size();  
  180.     return 0;  
  181. }  
  182.    
  183. /*根据inputData作出一个统计stati*/  
  184. void statistic(vector<vector<string> > &inputData,VEC_STATI &stati){  
  185.     for(int i=1;i<ATTR_NUM+1;++i){  
  186.         MAP_ATTR_REST attr_rest;  
  187.         for(int j=0;j<inputData.size();++j){  
  188.             string attr_value=inputData[j][i];  
  189.             string rest=inputData[j][ATTR_NUM+1];  
  190.             MAP_ATTR_REST::iterator itr=attr_rest.find(attr_value);  
  191.             if(itr==attr_rest.end()){  
  192.                 MAP_REST_COUNT rest_count;  
  193.                 rest_count[rest]=1;  
  194.                 attr_rest[attr_value]=rest_count;  
  195.             }  
  196.             else{  
  197.                 MAP_REST_COUNT::iterator iter=(itr->second).find(rest);  
  198.                 if(iter==(itr->second).end()){  
  199.                     (itr->second).insert(make_pair(rest,1));  
  200.                 }  
  201.                 else{  
  202.                     iter->second+=1;  
  203.                 }  
  204.             }  
  205.         }  
  206.         stati.push_back(attr_rest);  
  207.     }  
  208. }  
  209.    
  210. /*依据某条件作出分枝时,inputData被分成两部分*/  
  211. void splitInput(vector<vector<string> > &inputData,int fitIndex,string cond,vector<vector<string> > &LinputData,vector<vector<string> > &RinputData){  
  212.     for(int i=0;i<inputData.size();++i){  
  213.         if(inputData[i][fitIndex+1]==cond)  
  214.             LinputData.push_back(inputData[i]);  
  215.         else  
  216.             RinputData.push_back(inputData[i]);  
  217.     }  
  218. }  
  219.    
  220. void printStati(VEC_STATI &stati){  
  221.     for(int i=0;i<stati.size();i++){  
  222.         MAP_ATTR_REST::const_iterator itr=stati[i].begin();  
  223.         while(itr!=stati[i].end()){  
  224.             cout<<itr->first;  
  225.             MAP_REST_COUNT::const_iterator iter=(itr->second).begin();  
  226.             while(iter!=(itr->second).end()){  
  227.                 cout<<"\t"<<iter->first<<"\t"<<iter->second;  
  228.                 iter++;  
  229.             }  
  230.             itr++;  
  231.             cout<<endl;  
  232.         }  
  233.         cout<<endl;  
  234.     }  
  235. }  
  236.    
  237. void split(node *root,vector<vector<string> > &inputData,vector<pair<string,int> > classes){  
  238.     //root->printInfo();  
  239.     root->record_number=inputData.size();  
  240.     VEC_STATI stati;  
  241.     statistic(inputData,stati);  
  242.     //printStati(stati);  
  243.     //for(int i=0;i<rest_number;i++)  
  244.     //  cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";  
  245.     //cout<<endl;  
  246.     /*找到最大化GINI指标的划分*/  
  247.     double minGain=1.0;     //最小的GINI增益  
  248.     int fitIndex=-1;  
  249.     string fitCond;  
  250.     vector<pair<string,int> > fitleftclasses;  
  251.     vector<pair<string,int> > fitrightclasses;  
  252.     int fitleftnumber;  
  253.     int fitrightnumber;  
  254.     for(int i=0;i<stati.size();++i){     //扫描每一个自变量  
  255.         MAP_ATTR_REST::const_iterator itr=stati[i].begin();  
  256.         while(itr!=stati[i].end()){         //扫描自变量上的每一个取值  
  257.             string condition=itr->first;     //判定的条件,即到达左孩子的条件  
  258.             //cout<<"cond 为"<<X[i]+condition<<"时:";  
  259.             vector<pair<string,int> > leftclasses(classes);     //左孩子节点上类别、及对应的数目  
  260.             vector<pair<string,int> > rightclasses(classes);    //右孩子节点上类别、及对应的数目  
  261.             int leftnumber=0;       //左孩子节点上包含的类别数目  
  262.             int rightnumber=0;      //右孩子节点上包含的类别数目  
  263.             for(int j=0;j<leftclasses.size();++j){       //更新类别对应的数目  
  264.                 string rest=leftclasses[j].first;  
  265.                 MAP_REST_COUNT::const_iterator iter2;  
  266.                 iter2=(itr->second).find(rest);  
  267.                 if(iter2==(itr->second).end()){      //没找到  
  268.                     leftclasses[j].second=0;  
  269.                     rightnumber+=rightclasses[j].second;  
  270.                 }  
  271.                 else{       //找到  
  272.                     leftclasses[j].second=iter2->second;  
  273.                     leftnumber+=leftclasses[j].second;  
  274.                     rightclasses[j].second-=(iter2->second);  
  275.                     rightnumber+=rightclasses[j].second;  
  276.                 }  
  277.             }  
  278.             /**if(leftnumber==0 || rightnumber==0){ 
  279.                 cout<<"左右有一边为空"<<endl; 
  280.                   
  281.                 for(int k=0;k<rest_number;k++) 
  282.                     cout<<leftclasses[k].first<<"\t"<<leftclasses[k].second<<"\t"; 
  283.                 cout<<endl; 
  284.                 for(int k=0;k<rest_number;k++) 
  285.                     cout<<rightclasses[k].first<<"\t"<<rightclasses[k].second<<"\t"; 
  286.                 cout<<endl; 
  287.             }**/  
  288.             double gain1=1.0;           //计算GINI增益  
  289.             double gain2=1.0;  
  290.             if(leftnumber==0)  
  291.                 gain1=0.0;  
  292.             else  
  293.                 for(int j=0;j<leftclasses.size();++j)         
  294.                     gain1-=pow(1.0*leftclasses[j].second/leftnumber,2.0);  
  295.             if(rightnumber==0)  
  296.                 gain2=0.0;  
  297.             else  
  298.                 for(int j=0;j<rightclasses.size();++j)  
  299.                     gain2-=pow(1.0*rightclasses[j].second/rightnumber,2.0);  
  300.             double gain=1.0*leftnumber/(leftnumber+rightnumber)*gain1+1.0*rightnumber/(leftnumber+rightnumber)*gain2;  
  301.             //cout<<"GINI增益:"<<gain<<endl;  
  302.             if(gain<minGain){  
  303.                 //cout<<"GINI增益:"<<gain<<"\t"<<i<<"\t"<<condition<<endl;  
  304.                 fitIndex=i;  
  305.                 fitCond=condition;  
  306.                 fitleftclasses=leftclasses;  
  307.                 fitrightclasses=rightclasses;  
  308.                 fitleftnumber=leftnumber;  
  309.                 fitrightnumber=rightnumber;  
  310.                 minGain=gain;  
  311.             }  
  312.             itr++;  
  313.         }  
  314.     }  
  315.    
  316.     /*计算卡方值,看有没有必要进行分裂*/  
  317.     //cout<<"按"<<X[fitIndex]+fitCond<<"划分,计算卡方"<<endl;  
  318.     int **arr=new int*[2];  
  319.     for(int i=0;i<2;i++)  
  320.         arr[i]=new int[rest_number];  
  321.     for(int i=0;i<rest_number;i++){  
  322.         arr[0][i]=fitleftclasses[i].second;  
  323.         arr[1][i]=fitrightclasses[i].second;  
  324.     }  
  325.     double chi=cal_chi(arr,2,rest_number);  
  326.     //cout<<"chi="<<chi<<" CHI="<<CHI[rest_number-2]<<endl;  
  327.     if(chi<CHI[rest_number-2]){      //独立,没必要再分裂了  
  328.         delete []arr[0];    delete []arr[1];    delete []arr;  
  329.         return;             //不需要分裂函数就返回  
  330.     }  
  331.     delete []arr[0];    delete []arr[1];    delete []arr;  
  332.        
  333.     /*分裂*/  
  334.     root->cond=X[fitIndex]+"="+fitCond;      //root的分枝条件  
  335.     //cout<<"分类条件:"<<root->cond<<endl;  
  336.     node *travel=root;      //root及其祖先节点的size都要加1  
  337.     while(travel!=NULL){  
  338.         (travel->size)++;  
  339.         travel=travel->parent;  
  340.     }  
  341.        
  342.     node *LChild=new node(root);        //创建左右孩子  
  343.     node *RChild=new node(root);  
  344.     root->leftchild=LChild;  
  345.     root->rightchild=RChild;  
  346.     int maxLcount=0;  
  347.     int maxRcount=0;  
  348.     string Ldicision,Rdicision;  
  349.     for(int i=0;i<rest_number;++i){      //统计哪种类别出现的最多,从而作出类别判定  
  350.         if(fitleftclasses[i].second>maxLcount){  
  351.             maxLcount=fitleftclasses[i].second;  
  352.             Ldicision=fitleftclasses[i].first;  
  353.         }  
  354.         if(fitrightclasses[i].second>maxRcount){  
  355.             maxRcount=fitrightclasses[i].second;  
  356.             Rdicision=fitrightclasses[i].first;  
  357.         }  
  358.     }  
  359.     LChild->decision=Ldicision;  
  360.     RChild->decision=Rdicision;  
  361.     LChild->precision=1.0*maxLcount/fitleftnumber;  
  362.     RChild->precision=1.0*maxRcount/fitrightnumber;  
  363.        
  364.     /*递归对左右孩子进行分裂*/  
  365.     vector<vector<string> > LinputData,RinputData;  
  366.     splitInput(inputData,fitIndex,fitCond,LinputData,RinputData);  
  367.     //cout<<"左边inputData行数:"<<LinputData.size()<<endl;  
  368.     //cout<<"右边inputData行数:"<<RinputData.size()<<endl;  
  369.     split(LChild,LinputData,fitleftclasses);  
  370.     split(RChild,RinputData,fitrightclasses);  
  371. }  
  372.    
  373. /*计算子树的误差代价*/  
  374. double calR2(node *root){  
  375.     if(root->leftchild==NULL)  
  376.         return (1-root->precision)*root->record_number/total_record_number;  
  377.     else  
  378.         return calR2(root->leftchild)+calR2(root->rightchild);  
  379. }  
  380.    
  381. /*层次遍历树,给节点标上序号。同时计算alpha*/  
  382. void index(node *root,priority_queue<MyTriple> &pq){  
  383.     int i=1;  
  384.     queue<node*> que;  
  385.     que.push(root);  
  386.     while(!que.empty()){  
  387.         node* n=que.front();  
  388.         que.pop();  
  389.         n->index=i++;  
  390.         if(n->leftchild!=NULL){  
  391.             que.push(n->leftchild);  
  392.             que.push(n->rightchild);  
  393.             //计算表面误差率的增量  
  394.             double r1=(1-n->precision)*n->record_number/total_record_number;      //节点的误差代价  
  395.             double r2=calR2(n);  
  396.             n->alpha=(r1-r2)/(n->size-1);  
  397.             pq.push(MyTriple(n->alpha,n->size,n->index));  
  398.         }  
  399.     }  
  400. }  
  401.    
  402. /*剪枝*/  
  403. void prune(node *root,priority_queue<MyTriple> &pq){  
  404.     MyTriple triple=pq.top();  
  405.     int i=triple.third;  
  406.     queue<node*> que;  
  407.     que.push(root);  
  408.     while(!que.empty()){  
  409.         node* n=que.front();  
  410.         que.pop();  
  411.         if(n->index==i){  
  412.             cout<<"将要剪掉"<<i<<"的左右子树"<<endl;  
  413.             n->leftchild=NULL;  
  414.             n->rightchild=NULL;  
  415.             int s=n->size-1;  
  416.             node *trav=n;  
  417.             while(trav!=NULL){  
  418.                 trav->size-=s;  
  419.                 trav=trav->parent;  
  420.             }  
  421.             break;  
  422.         }  
  423.         else if(n->leftchild!=NULL){  
  424.             que.push(n->leftchild);  
  425.             que.push(n->rightchild);  
  426.         }  
  427.     }  
  428. }  
  429.    
  430. void test(string filename,node *root){  
  431.     ifstream ifs(filename.c_str());  
  432.     if(!ifs){  
  433.         cerr<<"open inputfile failed!"<<endl;  
  434.         return;  
  435.     }  
  436.     string line;  
  437.     getline(ifs,line);  
  438.     string item;  
  439.     istringstream strstm(line);     //跳过第一行  
  440.     map<string,string> independent;       //自变量,即分类的依据  
  441.     while(getline(ifs,line)){  
  442.         istringstream strstm(line);  
  443.         //strstm.str(line);  
  444.         strstm>>item;  
  445.         cout<<item<<"\t";  
  446.         for(int i=0;i<ATTR_NUM;++i){  
  447.             strstm>>item;  
  448.             independent[X[i]]=item;  
  449.         }  
  450.         node *trav=root;  
  451.         while(trav!=NULL){  
  452.             if(trav->leftchild==NULL){  
  453.                 cout<<(trav->decision)<<"\t置信度:"<<(trav->precision)<<endl;;  
  454.                 break;  
  455.             }  
  456.             string cond=trav->cond;  
  457.             string::size_type pos=cond.find("=");  
  458.             string pre=cond.substr(0,pos);  
  459.             string post=cond.substr(pos+1);  
  460.             if(independent[pre]==post)  
  461.                 trav=trav->leftchild;  
  462.             else  
  463.                 trav=trav->rightchild;  
  464.         }  
  465.     }  
  466.     ifs.close();  
  467. }  
  468.    
  469. int main(){  
  470.     string inputFile="animal";  
  471.     readInput(inputFile);  
  472.     VEC_STATI stati;        //最原始的统计  
  473.     statistic(inputData,stati);  
  474.        
  475. //  for(int i=0;i<classes.size();++i)  
  476. //      cout<<classes[i].first<<"\t"<<classes[i].second<<"\t";  
  477. //  cout<<endl;  
  478.     node *root=new node();  
  479.     split(root,inputData,classes);      //分裂根节点  
  480.     priority_queue<MyTriple> pq;  
  481.     index(root,pq);  
  482.     root->printTree();  
  483.     cout<<"剪枝前使用该决策树最多进行"<<root->size-1<<"次条件判断"<<endl;  
  484.     /** 
  485.     //检验一个是不是表面误差增量最小的被剪掉了 
  486.     while(!pq.empty()){ 
  487.         MyTriple triple=pq.top(); 
  488.         pq.pop(); 
  489.         cout<<triple.first<<"\t"<<triple.second<<"\t"<<triple.third<<endl; 
  490.     } 
  491.     **/  
  492.     test(inputFile,root);  
  493.        
  494.     prune(root,pq);  
  495.     cout<<"剪枝后使用该决策树最多进行"<<root->size-1<<"次条件判断"<<endl;  
  496.     test(inputFile,root);  
  497.     return 0;  
  498. }  

原创粉丝点击