C++ 版 ID3 决策树实现

来源:互联网 发布:db2 分页 sql 编辑:程序博客网 时间:2024/04/29 06:39

上一篇文章用 Python 实现了决策树,本文借用其算法思路实现了 C++ 版本。数据结构是我自己设计实现的,肯定有很多不好的地方,希望各位高手能给出些建议——这是我第一次用 C++ 来实现这么大的程序。程序中用到的数据借用自网上的一位哥们儿。

DataSet.h

#ifndef SAMPLE_H
#define SAMPLE_H

#include <iosfwd>  // declares std::ifstream, which ReadData's signature names
#include <set>
#include <string>
#include <vector>

// NOTE(review): a using-directive in a header leaks into every file that
// includes it. It is kept because DesctionTree.h relies on unqualified std
// names (e.g. map) coming from here.
using namespace std;

/// A table of labelled samples used to train and test the ID3 tree.
///
/// The file layout is defined by ReadData: the first line is a header.
/// Every following line is one sample. Fields are separated by tabs.
/// The first column is discarded (a row id) and the last column is the
/// target value. Every column in between is a feature.
class DataSet
{
private:
    void ReadData(ifstream &in);                  // load header and samples from an open stream
    vector<string> SplitLine(const string &str);  // split one input line into fields on tab runs
    double Entropy(const vector<int> &v);         // base-2 Shannon entropy of a list of class counts

public:
    /// One data row: feature values aligned with attributesNames, plus its class label.
    struct Sample
    {
        vector<string> attributes;
        string targetAttributes;
    };

    vector<string> attributesNames;  // feature (column) names, excluding the id and target columns
    vector<Sample> dataSet;          // all samples

    DataSet();                         // empty data set
    DataSet(const string &fileName);   // load a data set from a text file
    ~DataSet();

    void Print();                                // dump the table to stdout
    double Gain(const string &featureName);      // information gain of splitting on featureName
    /// Samples whose featureName column equals featureValue, with that column removed.
    DataSet SplitDataSet(const string &featureName, const string &featureValue);
    string BestSplitFeature();                   // feature with the highest information gain
    string CommomTargetValue();                  // most frequent target value (majority class)
    bool IsSameTarget();                         // true when every sample has the same target
    int GetAttributeIndex(const string &attributeName);  // column index of a feature name
};

#endif

DataSet.cpp

#include<vector>#include<iostream>#include<fstream>#include<string>#include<map>#include<set>#include<numeric>#include"DataSet.h"using namespace std;//默认构造函数DataSet::DataSet(){}//从文件构建数据集DataSet::DataSet(const string &fileName){ifstream in(fileName);if (!in){cout << "文件打开失败";}else{ReadData(in);}}//析构函数DataSet::~DataSet(){if (!attributesNames.empty()){ dataSet.clear(); attributesNames.clear();}}//读取数据void DataSet::ReadData(ifstream &in){string str;vector<string> tempV;getline(in, str);tempV= SplitLine(str);//调用splitLine;attributesNames.assign(tempV.begin() + 1, tempV.end() - 1);while (getline(in, str)){Sample s;tempV = SplitLine(str);s.attributes.assign(tempV.begin() + 1, tempV.end() - 1);s.targetAttributes = tempV[tempV.size() - 1];dataSet.push_back(s);}}//切分从文件读取的行vector<string> DataSet::SplitLine(const string &str){vector<string> v;bool isFirstBlank = true;string::size_type pos=0;for (string::size_type i = 0; i != str.size(); ++i){if (str[i] == '\t'&&isFirstBlank){v.push_back(string(str,pos,i-pos));isFirstBlank = false;}if (str[i] != '\t' &&isFirstBlank == false){isFirstBlank = true;pos = i;}}if (isFirstBlank){v.push_back(string(str, pos, str.size() - pos));}return v;}//打印数据集void DataSet::Print(){if (attributesNames.size() == 0){cout << "数据集为空" << endl;return;}for (vector<string>::iterator nIter = attributesNames.begin(); nIter != attributesNames.end();++nIter){cout << *nIter<<"\t";}cout << endl;for (vector<Sample>::iterator sIter = dataSet.begin(); sIter != dataSet.end(); ++sIter){for (vector<string>::iterator aIter = sIter->attributes.begin(); aIter != sIter->attributes.end(); ++aIter){cout << *aIter << "\t";}cout << sIter->targetAttributes;cout << endl;}}//信息增益计算double DataSet::Gain(const string &featureName){vector<string>::iterator findIter = find(attributesNames.begin(), attributesNames.end(), featureName);if (findIter == attributesNames.end()) throw "参数错误";vector<string>::size_type index;//数据集的列标签for (index = 0; index != attributesNames.size(); ++index){if 
(attributesNames[index] == featureName){break;}}map<string, int> targetMap;//统计各个属性的样本数map<string,map<string, int> > featureMap;//键为for (vector<Sample>::iterator sIter = dataSet.begin(); sIter != dataSet.end(); ++sIter){++targetMap[sIter->targetAttributes];++featureMap[sIter->attributes[index]][sIter->targetAttributes];}vector<int> targetCount;//统计目标变量for (map<string, int>::iterator tarIter = targetMap.begin(); tarIter != targetMap.end(); ++tarIter){targetCount.push_back(tarIter->second);}double gain = Entropy(targetCount);//总的熵int sTotal = dataSet.size();//总的样本数for (map<string, map<string, int> >::iterator featIter = featureMap.begin(); featIter != featureMap.end(); ++featIter){vector<int> featureCout;for (map<string, int>::iterator featIter2 = featIter->second.begin(); featIter2 != featIter->second.end(); ++featIter2){featureCout.push_back(featIter2->second);}int s = accumulate(featureCout.begin(), featureCout.end(),0);//特征出现的总次数gain -= 1.0 * s / sTotal*Entropy(featureCout);}return gain;}//计算信息熵double DataSet::Entropy(const vector<int> &v){double entropy=0.0;int totalNum = accumulate(v.begin(),v.end(),0);for (vector<int>::size_type i = 0; i != v.size(); ++i){int temp = v[i];double p =1.0* temp/totalNum;  //注意类型转化entropy -= p*log2(p);}return entropy;}//根据属性和属性的值为featureValue的子集DataSet DataSet::SplitDataSet(const string &featureName, const string &featureValue){vector<string>::iterator fIter = find(attributesNames.begin(), attributesNames.end(), featureName);if (attributesNames.size() == 0|| fIter==attributesNames.end()) throw "参数错误";DataSet children;              //数据集的子集vector<string>::size_type index;for (index = 0; index != attributesNames.size(); ++index)//找到属性标签的序号{if (attributesNames[index] == featureName){break;}}vector<Sample>::iterator dIter;for (dIter = dataSet.begin(); dIter != dataSet.end(); ++dIter){if (dIter->attributes[index] == featureValue){children.dataSet.push_back(*dIter);//把满足条件的样本放入childrenSet}}children.attributesNames = 
this->attributesNames;//去除childrenSet已经使用过的属性vector<string>::iterator eNiter = find(children.attributesNames.begin(), children.attributesNames.end(), featureName);if (eNiter!=children.attributesNames.end()){children.attributesNames.erase(eNiter);}vector<Sample>::iterator chilDataIter = children.dataSet.begin();for (; chilDataIter != children.dataSet.end(); ++chilDataIter){vector<string>::iterator eFiter = find(chilDataIter->attributes.begin(), chilDataIter->attributes.end(), featureValue);if (eFiter != chilDataIter->attributes.end()){chilDataIter->attributes.erase(eFiter);}}return children;}//选择子集string DataSet::BestSplitFeature(){double maxGain = 0.0;//最大信息增益string bestSplit;    //具有最大信息增益的属性vector<string>::iterator tIter = attributesNames.begin();for (; tIter != attributesNames.end(); ++tIter){if (maxGain < Gain(*tIter)){maxGain = Gain(*tIter);bestSplit = *tIter;}}return bestSplit;}//判断目标属性是否唯一bool DataSet::IsSameTarget(){set<string> targetContine;vector<Sample>::iterator sIter = dataSet.begin();for (; sIter != dataSet.end(); ++sIter){targetContine.insert(sIter->targetAttributes);}if (targetContine.size() == 1){return true;}else{return false;}}//返回最长出现的属性string DataSet::CommomTargetValue(){map<string, int> targetValue;vector<Sample>::iterator dIter = dataSet.begin();for (; dIter != dataSet.end(); ++dIter){++targetValue[dIter->targetAttributes];}map<string, int>::iterator mIter = targetValue.begin();string commomTarget;int maxValue = 0;for (; mIter != targetValue.end(); ++mIter){if (maxValue < mIter->second){maxValue = mIter->second;commomTarget = mIter->first;}}return commomTarget;}int DataSet::GetAttributeIndex(const string &attributeName){vector<string>::size_type index;for (index = 0; index != attributesNames.size(); ++index){if (attributesNames[index] == attributeName){return index;}}if (index == attributesNames.size()){throw "属性不存在";}}

DesctionTree.h

#ifndef DESCITION_H#define DESCITION_H#include<string>#include<vector>#include<map>#include"DataSet.h"using  std::string;using std::vector;class DesctionTree{private:struct Node{string value;map<string, Node* > children;};Node *root;Node* CreateNode(DataSet &trainSet);void ShowNode(Node *rNode,int level);string ClassOneSample(Node* rNode,vector<string> &v,vector<string> &attributeNames);public:DesctionTree();//~DesctionTree();void CreateTree(DataSet &trainSet);vector<string> ClassTest(DataSet &examples);void ShowTree();void ClearTree(Node* rNode);};#endif

Desction.cpp

#include<vector>#include<string>#include<iostream>#include"DesctionTree.h"using namespace std;DesctionTree::DesctionTree(){root = NULL;}void DesctionTree::CreateTree(DataSet &trainSet){root = new Node;root = CreateNode(trainSet);}void DesctionTree::ShowTree(){if (root != NULL){ShowNode(root,1);}}DesctionTree::Node* DesctionTree::CreateNode(DataSet &trainSet){if (trainSet.dataSet.size() == 0){cout << "数据为空" << endl;return NULL;}Node* rootNode= new Node;if (trainSet.attributesNames.size() == 0){rootNode->value = trainSet.CommomTargetValue();return rootNode;}if (trainSet.IsSameTarget()){vector<DataSet::Sample>::iterator sIter = trainSet.dataSet.begin();rootNode->value = sIter->targetAttributes;return rootNode;}string bestSplitAttributes = trainSet.BestSplitFeature();rootNode->value = bestSplitAttributes;vector<string>::size_type index;for (index = 0; index != trainSet.attributesNames.size();++index){if (trainSet.attributesNames[index] == bestSplitAttributes){break;}}set<string> valueSet;vector<DataSet::Sample>::iterator sIter2 = trainSet.dataSet.begin();for (; sIter2 != trainSet.dataSet.end(); ++sIter2){valueSet.insert(sIter2->attributes[index]);}set<string>::iterator setIter = valueSet.begin();for (; setIter != valueSet.end(); ++setIter){DataSet childSet = trainSet.SplitDataSet(bestSplitAttributes,*setIter);rootNode->children[*setIter] = CreateNode(childSet);}return rootNode;}void DesctionTree::ShowNode(Node *rNode,int level){cout << rNode->value << '\n';if (rNode->children.size() == 0){return;}map<string, Node*>::iterator mIter = rNode->children.begin();for (; mIter != rNode->children.end(); ++mIter){for (int j = 0; j < level; ++j){cout << '\t';}cout << mIter->first << "->";ShowNode(mIter->second,level+1);}}vector<string> DesctionTree::ClassTest(DataSet &examples){vector<string> retDesction;//决策属性vector<DataSet::Sample>::iterator sIter=examples.dataSet.begin();for (; sIter != examples.dataSet.end(); ++sIter){retDesction.push_back(ClassOneSample(root, 
sIter->attributes, examples.attributesNames));}return retDesction;}string DesctionTree::ClassOneSample(Node* rNode,vector<string> &v, vector<string> &attributeNames){vector<string>::size_type index;if (rNode->children.size() == 0){return rNode->value;}vector<string>::iterator result = find(attributeNames.begin(), attributeNames.end(), rNode->value);if (result == attributeNames.end()){return " ";}index = result - attributeNames.begin();map<string, Node*>::iterator mIter = rNode->children.begin();for (; mIter != rNode->children.end(); ++mIter){if (mIter->first == v[index]){vector<string> tempV(v);vector<string> tempAttributeNames(attributeNames);tempV.erase(tempV.begin()+index);tempAttributeNames.erase(tempAttributeNames.begin() + index);return ClassOneSample(mIter->second, tempV, tempAttributeNames);}}}/*void DesctionTree::ClearTree(Node *rNode){if (rNode->children.empty()){delete rNode;rNode = NULL;}for (map<string, Node*>::iterator mIter = rNode->children.begin(); mIter != rNode->children.end(); ++mIter){ClearTree(mIter->second);}}*/

main.cpp

/************************************ id3决策树** @author:郑午** @time:2014-06-19***********************************/#include<iostream>#include<vector>#include<set>#include<iterator>#include"DataSet.h"#include"DesctionTree.h"using namespace std;int main(){//训练DataSet dataMat("data.txt");DesctionTree id3Tree;id3Tree.CreateTree(dataMat);cout << "树结构:"<<endl;id3Tree.ShowTree();//测试DataSet testSet("test.txt");vector<string> result=id3Tree.ClassTest(testSet);cout << "测试结果:" << endl<<endl;copy(result.begin(), result.end(), ostream_iterator<string>(cout, " "));cout << endl<<endl;system("pause");}



数据文件(data.txt;程序按制表符 '\t' 切分各列:第一列为编号,会被丢弃;最后一列为类别;中间各列为特征)

Day Outlook Temperate Humidity Wind PlayTennis
D1 Sunny Hot High Weak No
D2 Sunny Hot High Strong No
D3 Overcast Hot High Weak Yes
D4 Rain Mild High Weak Yes
D5 Rain Cool Normal Weak Yes
D6 Rain Cool Normal Strong No
D7 Overcast Cool Normal Strong Yes
D8 Sunny Mild High Weak No
D9 Sunny Cool Normal Weak Yes
D10 Rain Mild Normal Weak Yes
D11 Sunny Mild Normal Strong Yes
D12 Overcast Mild High Strong Yes
D13 Overcast Hot Normal Weak Yes
D14 Rain Mild High Strong No


测试数据(test.txt)是训练数据的一个子集,文件格式与 data.txt 相同。

运行结果:(原文的运行结果未能保留)


0 0
原创粉丝点击