OpenCV源码笔记——Decision Tree决策树

来源:互联网 发布:高清直播网络电视 编辑:程序博客网 时间:2024/06/18 13:45

来自OpenCV2.3.1 samples/c/mushroom.cpp

 

1.首先读入agaricus-lepiota.data的训练样本。

   样本中第一项是标志位:e代表无毒(edible),p代表有毒(poisonous);其他是特征,可以把每个样本看做一个特征向量;

   cvSeqPush( seq, el_ptr );读入序列seq中,每一项都存储一个样本即特征向量;

   之后,把特征向量与标志位分别读入CvMat* data与CvMat* responses中

   还有一个CvMat* missing,用来标记缺失值的位置:数据文件中的'?'读入时记为-1,missing中对应位置(即特征值小于0的位置)置为1;

 

2.训练样本

    dtree = new CvDTree;    dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,    CvDTreeParams( 8, // max depth    10, // min sample count 样本数小于10时,停止分裂     0, // regression accuracy: N/A here;回归树的限制精度    true, // compute surrogate split, as we have missing data;;为真时,计算missing data和变量的重要性    15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义    10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds    true, // use 1SE rule => smaller tree;If true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确    true, // throw away the pruned tree branches    priors //错分类的代价我们判断的:有毒VS无毒 错误的代价比 the array of priors, the bigger p_weight, the more attention    // to the poisonous mushrooms    // (a mushroom will be judjed to be poisonous with bigger chance)    ));


 

3.预测样本

double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;


4.interactive_classification通过人工输入特征来判断。

 

#include "opencv2/core/core_c.h"
#include "opencv2/ml/ml.hpp"

#include <assert.h>
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

/* Prints the usage text and a short description of the sample. */
void help()
{
    printf("\nThis program demonstrated the use of OpenCV's decision tree function for learning and predicting data\n"
           "Usage :\n"
           "./mushroom <path to agaricus-lepiota.data>\n"
           "\n"
           "The sample demonstrates how to build a decision tree for classifying mushrooms.\n"
           "It uses the sample base agaricus-lepiota.data from UCI Repository, here is the link:\n"
           "\n"
           "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
           "UCI Repository of machine learning databases\n"
           "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
           "Irvine, CA: University of California, Department of Information and Computer Science.\n"
           "\n"
           "// loads the mushroom database, which is a text file, containing\n"
           "// one training sample per row, all the input variables and the output variable are categorical,\n"
           "// the values are encoded by characters.\n\n");
}

/*
 * Loads the mushroom database from a text file.
 *
 * Each row is read as single characters separated by commas: the first
 * character is the response (class label), the remaining var_count characters
 * are the input variables. A '?' is stored as -1 and flagged in *missing.
 *
 * On success allocates three matrices that the caller must release:
 *   *data      - (samples x var_count) CV_32F, the character codes of the inputs
 *   *missing   - (samples x var_count) CV_8U, 1 where the input value is negative
 *   *responses - (samples x 1) CV_32F, the character code of the label
 *
 * Returns 1 on success, 0 if the file cannot be opened or is empty.
 */
int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )
{
    const int M = 1024;
    FILE* f = fopen( filename, "rt" );
    CvMemStorage* storage;
    CvSeq* seq;
    char buf[M+2], *ptr;
    float* el_ptr;
    CvSeqReader reader;
    int i, j, var_count = 0;

    if( !f )
        return 0;

    // read the first line and determine the number of variables
    if( !fgets( buf, M, f ))
    {
        fclose(f);
        return 0;
    }

    // Count the commas: a row holds var_count+1 fields (label + var_count inputs).
    for( ptr = buf; *ptr != '\0'; ptr++ )
        var_count += *ptr == ',';
    // Every field is expected to be one character followed by ',' or '\n'.
    assert( ptr - buf == (var_count+1)*2 );

    // create temporary memory storage to store the whole database:
    // each sequence element is one row of var_count+1 floats.
    el_ptr = new float[var_count+1];
    storage = cvCreateMemStorage();
    seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );

    for(;;)
    {
        // Fields sit at even offsets of the line; odd offsets are separators.
        for( i = 0; i <= var_count; i++ )
        {
            int c = buf[i*2];
            el_ptr[i] = c == '?' ? -1.f : (float)c;
        }
        cvSeqPush( seq, el_ptr );
        // Stop at end of file or at the first line that contains no comma.
        if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
            break;
    }
    fclose(f);

    // allocate the output matrices and copy the base there
    *data = cvCreateMat( seq->total, var_count, CV_32F );    // rows: samples, cols: variables
    *missing = cvCreateMat( seq->total, var_count, CV_8U );
    *responses = cvCreateMat( seq->total, 1, CV_32F );       // one label per sample

    cvStartReadSeq( seq, &reader );

    for( i = 0; i < seq->total; i++ )
    {
        // Skip element 0 (the label) so that sdata[j] is the j-th input variable.
        const float* sdata = (float*)reader.ptr + 1;
        float* ddata = data[0]->data.fl + var_count*i;
        float* dr = responses[0]->data.fl + i;
        uchar* dm = missing[0]->data.ptr + var_count*i;

        for( j = 0; j < var_count; j++ )
        {
            ddata[j] = sdata[j];
            dm[j] = sdata[j] < 0;
        }
        *dr = sdata[-1];    // the first field of the row is the label
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    cvReleaseMemStorage( &storage );
    delete[] el_ptr;    // allocated with new[], so it must be released with delete[]
    return 1;
}

/*
 * Trains a decision tree on the given samples and prints the hit-rate
 * measured on the same training data.
 *
 * data, missing, responses - matrices in the layout produced by
 *                            mushroom_read_database()
 * p_weight                 - second entry of the priors array {1, p_weight};
 *                            the bigger it is, the more attention is paid to
 *                            the poisonous mushrooms
 *
 * Returns the trained tree; the caller owns it and must delete it.
 */
CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,
                                const CvMat* responses, float p_weight )
{
    CvDTree* dtree;
    CvMat* var_type;
    int i, hr1 = 0, hr2 = 0, p_total = 0;
    float priors[] = { 1, p_weight };

    // One type entry per input variable plus one for the response.
    var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
    cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical

    dtree = new CvDTree;

    dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
                  CvDTreeParams( 8,     // max depth
                                 10,    // min sample count: a node with fewer samples is not split
                                 0,     // regression accuracy: N/A here
                                 true,  // compute surrogate split, as we have missing data
                                 15,    // max number of categories (use sub-optimal algorithm for larger numbers)
                                 10,    // the number of cross-validation folds used for pruning
                                 true,  // use 1SE rule => smaller tree
                                 true,  // throw away the pruned tree branches
                                 priors // the array of priors, the bigger p_weight, the more attention
                                        // to the poisonous mushrooms
                                        // (a mushroom will be judged to be poisonous with bigger chance)
                                 ));

    // compute hit-rate on the training database, demonstrates predict usage.
    for( i = 0; i < data->rows; i++ )
    {
        CvMat sample, mask;
        cvGetRow( data, &sample, i );
        cvGetRow( missing, &mask, i );
        // predict() returns the reached tree node; its value field is compared
        // against the stored label below.
        double r = dtree->predict( &sample, &mask )->value;
        // A difference of at least FLT_EPSILON counts as a mis-prediction.
        int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;
        if( d )
        {
            if( r != 'p' )
                hr1++;      // wrong and not predicted 'p': a poisonous sample was missed
            else
                hr2++;      // wrong and predicted 'p': false alarm
        }
        p_total += responses->data.fl[i] == 'p';
    }

    printf( "Results on the training database:\n"
            "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"
            "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,
            hr2, (double)hr2*100/(data->rows - p_total) );

    cvReleaseMat( &var_type );

    return dtree;
}

/* Human-readable description of every input variable, terminated by 0. */
static const char* var_desc[] =
{
    "cap shape (bell=b,conical=c,convex=x,flat=f)",
    "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",
    "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",
    "bruises? (bruises=t,no=f)",
    "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",
    "gill attachment (attached=a,descending=d,free=f,notched=n)",
    "gill spacing (close=c,crowded=w,distant=d)",
    "gill size (broad=b,narrow=n)",
    "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",
    "stalk shape (enlarging=e,tapering=t)",
    "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",
    "stalk surface above ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    "stalk surface below ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    "veil type (partial=p,universal=u)",
    "veil color (brown=n,orange=o,white=w,yellow=y)",
    "ring number (none=n,one=o,two=t)",
    "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",
    "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",
    "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",
    "habitat (grasses=g,leaves=l,meadows=m,paths=p\n\turban=u,waste=w,woods=d)",
    0
};

/*
 * Asks the user whether to print the variable importance and, if confirmed,
 * prints one "<name>: <percent>%" line per variable.
 *
 * var_desc may be null, in which case variables are printed as "var #<i>".
 * Otherwise the part of each description before " (" is used as the name.
 */
void print_variable_importance( CvDTree* dtree, const char** var_desc )
{
    const CvMat* var_importance = dtree->get_var_importance();
    int i;
    char input[1000];

    if( !var_importance )
    {
        printf( "Error: Variable importance can not be retrieved\n" );
        return;
    }

    printf( "Print variable importance information? (y/n) " );
    // Treat end of input like a "no" instead of reading an unset buffer.
    if( scanf( "%1s", input ) != 1 )
        return;
    if( input[0] != 'y' && input[0] != 'Y' )
        return;

    for( i = 0; i < var_importance->cols*var_importance->rows; i++ )
    {
        double val = var_importance->data.db[i];
        if( var_desc )
        {
            // Print the description up to (not including) the " (" that starts
            // the value list; fall back to the whole text if there is no '('.
            const char* desc = var_desc[i];
            const char* paren = strchr( desc, '(' );
            int len = paren ? (int)(paren - desc) - 1 : (int)strlen( desc );
            if( len < 0 )
                len = 0;
            printf( "%.*s", len, desc );
        }
        else
            printf( "var #%d", i );
        printf( ": %g%%\n", val*100. );
    }
}

/*
 * Classifies mushrooms interactively: walks the tree from the root and asks
 * the user for the value of each variable the current node splits on.
 * Answering '?' marks the value as unknown and moves on to the node's next
 * split. The session ends when the user declines to continue or the input
 * stream ends.
 */
void interactive_classification( CvDTree* dtree, const char** var_desc )
{
    char input[1000];
    const CvDTreeNode* root;
    CvDTreeTrainData* data;

    if( !dtree )
        return;

    root = dtree->get_root();
    data = dtree->get_data();

    for(;;)
    {
        const CvDTreeNode* node;

        printf( "Start/Proceed with interactive mushroom classification (y/n): " );
        // Stop on end of input instead of testing an unset buffer.
        if( scanf( "%1s", input ) != 1 )
            break;
        if( input[0] != 'y' && input[0] != 'Y' )
            break;
        printf( "Enter 1-letter answers, '?' for missing/unknown value...\n" );

        // custom version of predict
        node = root;
        for(;;)
        {
            CvDTreeSplit* split = node->split;
            int dir = 0;

            // Stop at a node that has no children, is cut off by pruning,
            // or has no split.
            if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )
                break;

            for( ; split != 0; )
            {
                int vi = split->var_idx, j;
                int count = data->cat_count->data.i[vi];
                const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];

                printf( "%s: ", var_desc[vi] );
                // Without this check an exhausted input would re-prompt forever.
                if( scanf( "%1s", input ) != 1 )
                    return;

                if( input[0] == '?' )
                {
                    split = split->next;
                    continue;
                }

                // convert the input character to the normalized value of the variable
                for( j = 0; j < count; j++ )
                    if( map[j] == input[0] )
                        break;
                if( j < count )
                {
                    // Bit j of the subset mask selects the branch direction.
                    dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;
                    if( split->inversed )
                        dir = -dir;
                    break;
                }
                else
                    printf( "Error: unrecognized value\n" );
            }

            // No split produced a direction.
            if( !dir )
            {
                printf( "Impossible to classify the sample\n");
                node = 0;
                break;
            }
            node = dir < 0 ? node->left : node->right;
        }

        if( node )
            printf( "Prediction result: the mushroom is %s\n",
                    node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );
        printf( "\n-----------------------------\n" );
    }
}

/*
 * Entry point: loads the database (path from argv[1], or
 * "agaricus-lepiota.data" by default), trains the tree, then runs the
 * variable-importance and interactive-classification dialogs.
 * Returns 0 on success, -1 if the database cannot be loaded.
 */
int main( int argc, char** argv )
{
    CvMat *data = 0, *missing = 0, *responses = 0;
    CvDTree* dtree;
    const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";

    help();

    if( !mushroom_read_database( base_path, &data, &missing, &responses ) )
    {
        printf( "\nUnable to load the training database\n\n");
        help();
        return -1;
    }

    dtree = mushroom_create_dtree( data, missing, responses,
        10 // poisonous mushrooms will have 10x higher weight in the decision tree
        );
    cvReleaseMat( &data );
    cvReleaseMat( &missing );
    cvReleaseMat( &responses );

    print_variable_importance( dtree, var_desc );
    interactive_classification( dtree, var_desc );
    delete dtree;

    return 0;
}


 

原创粉丝点击