DBSCAN聚类算法C++实现

来源:互联网 发布:知乎 判断腿长 编辑:程序博客网 时间:2024/05/14 18:42

这几天由于工作需要,对DBSCAN聚类算法进行了C++的实现。时间复杂度O(n^2),主要花在算每个点领域内的点上。算法很简单,现共享大家参考,也希望有更多交流。

 数据点类型描述如下:

复制代码
 1 #include <vector> 2  3 using namespace std; 4  5 const int DIME_NUM=2;        //数据维度为2,全局常量 6  7 //数据点类型 8 class DataPoint 9 {10 private:11     unsigned long dpID;                //数据点ID12     double dimension[DIME_NUM];        //维度数据13     long clusterId;                    //所属聚类ID14     bool isKey;                        //是否核心对象15     bool visited;                    //是否已访问16     vector<unsigned long> arrivalPoints;    //领域数据点id列表17 public:18     DataPoint();                                                    //默认构造函数19     DataPoint(unsigned long dpID,double* dimension , bool isKey);    //构造函数20 21     unsigned long GetDpId();                //GetDpId方法22     void SetDpId(unsigned long dpID);        //SetDpId方法23     double* GetDimension();                    //GetDimension方法24     void SetDimension(double* dimension);    //SetDimension方法25     bool IsKey();                            //GetIsKey方法26     void SetKey(bool isKey);                //SetKey方法27     bool isVisited();                        //GetIsVisited方法28     void SetVisited(bool visited);            //SetIsVisited方法29     long GetClusterId();                    //GetClusterId方法30     void SetClusterId(long classId);        //SetClusterId方法31     vector<unsigned long>& GetArrivalPoints();    //GetArrivalPoints方法32 };
复制代码

这是实现:

复制代码
 1 #include "DataPoint.h" 2  3 //默认构造函数 4 DataPoint::DataPoint() 5 { 6 } 7  8 //构造函数 9 DataPoint::DataPoint(unsigned long dpID,double* dimension , bool isKey):isKey(isKey),dpID(dpID)10 {11     //传递每维的维度数据12     for(int i=0; i<DIME_NUM;i++)13     {14         this->dimension[i]=dimension[i];15     }16 }17 18 //设置维度数据19 void DataPoint::SetDimension(double* dimension)20 {21     for(int i=0; i<DIME_NUM;i++)22     {23         this->dimension[i]=dimension[i];24     }25 }26 27 //获取维度数据28 double* DataPoint::GetDimension()29 {30     return this->dimension;31 }32 33 //获取是否为核心对象34 bool DataPoint::IsKey()35 {36     return this->isKey;37 }38 39 //设置核心对象标志40 void DataPoint::SetKey(bool isKey)41 {42     this->isKey = isKey;43 }44 45 //获取DpId方法46 unsigned long DataPoint::GetDpId()47 {48     return this->dpID;49 }50 51 //设置DpId方法52 void DataPoint::SetDpId(unsigned long dpID)53 {54     this->dpID = dpID;55 }56 57 //GetIsVisited方法58 bool DataPoint::isVisited()59 {60     return this->visited;61 }62 63 64 //SetIsVisited方法65 void DataPoint::SetVisited( bool visited )66 {67     this->visited = visited;68 }69 70 //GetClusterId方法71 long DataPoint::GetClusterId()72 {73     return this->clusterId;74 }75 76 //GetClusterId方法77 void DataPoint::SetClusterId( long clusterId )78 {79     this->clusterId = clusterId;80 }81 82 //GetArrivalPoints方法83 vector<unsigned long>& DataPoint::GetArrivalPoints()84 {85     return arrivalPoints;86 }
复制代码

DBSCAN算法类型描述:

复制代码
 1 #include <iostream> 2 #include <cmath> 4  5 using namespace std; 6  7 //聚类分析类型 8 class ClusterAnalysis 9 {10 private:11     vector<DataPoint> dadaSets;        //数据集合12     unsigned int dimNum;            //维度13     double radius;                    //半径14     unsigned int dataNum;            //数据数量15     unsigned int minPTs;            //邻域最小数据个数16 17     double GetDistance(DataPoint& dp1, DataPoint& dp2);                    //距离函数18     void SetArrivalPoints(DataPoint& dp);                                //设置数据点的领域点列表19     void KeyPointCluster( unsigned long i, unsigned long clusterId );    //对数据点领域内的点执行聚类操作20 public:21 22     ClusterAnalysis(){}                    //默认构造函数23     bool Init(char* fileName, double radius, int minPTs);    //初始化操作24     bool DoDBSCANRecursive();            //DBSCAN递归算法25     bool WriteToFile(char* fileName);    //将聚类结果写入文件26 };
复制代码

 聚类实现:

复制代码
  1 #include "ClusterAnalysis.h"  2 #include <fstream>  3 #include <iosfwd>  4 #include <math.h>  5   6 /*  7 函数:聚类初始化操作  8 说明:将数据文件名,半径,领域最小数据个数信息写入聚类算法类,读取文件,把数据信息读入写进算法类数据集合中  9 参数: 10 char* fileName;    //文件名 11 double radius;    //半径 12 int minPTs;        //领域最小数据个数   13 返回值: true;    */ 14 bool ClusterAnalysis::Init(char* fileName, double radius, int minPTs) 15 { 16     this->radius = radius;        //设置半径 17     this->minPTs = minPTs;        //设置领域最小数据个数 18     this->dimNum = DIME_NUM;    //设置数据维度 19     ifstream ifs(fileName);        //打开文件 20     if (! ifs.is_open())                //若文件已经被打开,报错误信息 21     { 22         cout << "Error opening file";    //输出错误信息 23         exit (-1);                        //程序退出 24     } 25  26     unsigned long i=0;            //数据个数统计 27     while (! ifs.eof() )                //从文件中读取POI信息,将POI信息写入POI列表中 28     { 29         DataPoint tempDP;                //临时数据点对象 30         double tempDimData[DIME_NUM];    //临时数据点维度信息 31         for(int j=0; j<DIME_NUM; j++)    //读文件,读取每一维数据 32         { 33             ifs>>tempDimData[j]; 34         } 35         tempDP.SetDimension(tempDimData);    //将维度信息存入数据点对象内 36  37 //char date[20]=""; 38 //char time[20]=""; 39         ////double type;    //无用信息 40         //ifs >> date; 41 //ifs >> time;    //无用信息读入 42  43         tempDP.SetDpId(i);                    //将数据点对象ID设置为i 44         tempDP.SetVisited(false);            //数据点对象isVisited设置为false 45         tempDP.SetClusterId(-1);            //设置默认簇ID为-1 46         dadaSets.push_back(tempDP);            //将对象压入数据集合容器 47         i++;        //计数+1 48     } 49     ifs.close();        //关闭文件流 50     dataNum =i;            //设置数据对象集合大小为i 51     for(unsigned long i=0; i<dataNum;i++) 52     { 53         SetArrivalPoints(dadaSets[i]);            //计算数据点领域内对象 54     } 55     return true;    //返回 56 } 57  58 /* 59 函数:将已经过聚类算法处理的数据集合写回文件 60 说明:将已经过聚类结果写回文件 61 参数: 62 char* fileName;    //要写入的文件名 63 返回值: true    */ 64 bool ClusterAnalysis::WriteToFile(char* fileName ) 65 { 66     ofstream of1(fileName);                                //初始化文件输出流 67     for(unsigned long i=0; i<dataNum;i++)                //对处理过的每个数据点写入文件 68     { 69         for(int d=0; d<DIME_NUM ; d++)                    //将维度信息写入文件 70             of1<<dadaSets[i].GetDimension()[d]<<'\t'; 71         of1 << dadaSets[i].GetClusterId() <<endl;        //将所属簇ID写入文件 72     } 73     of1.close();    //关闭输出文件流 74     return true;    //返回 75 } 76  77 /* 78 函数:设置数据点的领域点列表 79 说明:设置数据点的领域点列表 80 参数: 81 返回值: true;    */ 82 void ClusterAnalysis::SetArrivalPoints(DataPoint& dp) 83 { 84     for(unsigned long i=0; i<dataNum; i++)                //对每个数据点执行 85     { 86         double distance =GetDistance(dadaSets[i], dp);    //获取与特定点之间的距离 87         if(distance <= radius && i!=dp.GetDpId())        //若距离小于半径,并且特定点的id与dp的id不同执行 88             dp.GetArrivalPoints().push_back(i);            //将特定点id压力dp的领域列表中 89     } 90     if(dp.GetArrivalPoints().size() >= minPTs)            //若dp领域内数据点数据量> minPTs执行 91     { 92         dp.SetKey(true);    //将dp核心对象标志位设为true 93         return;                //返回 94     } 95     dp.SetKey(false);    //若非核心对象,则将dp核心对象标志位设为false 96 } 97  98  99 /*100 函数:执行聚类操作101 说明:执行聚类操作102 参数:103 返回值: true;    */104 bool ClusterAnalysis::DoDBSCANRecursive()105 {106     unsigned long clusterId=0;                        //聚类id计数,初始化为0107     for(unsigned long i=0; i<dataNum;i++)            //对每一个数据点执行108     {109         DataPoint& dp=dadaSets[i];                    //取到第i个数据点对象110         if(!dp.isVisited() && dp.IsKey())            //若对象没被访问过,并且是核心对象执行111         {112             dp.SetClusterId(clusterId);                //设置该对象所属簇ID为clusterId113             dp.SetVisited(true);                    //设置该对象已被访问过114             KeyPointCluster(i,clusterId);            //对该对象领域内点进行聚类115             clusterId++;                            //clusterId自增1116         }117         //cout << "孤立点\T" << i << endl;118     }119 120     cout <<"共聚类" <<clusterId<<""<< endl;        //算法完成后,输出聚类个数121     return true;    //返回122 }123 124 /*125 函数:对数据点领域内的点执行聚类操作126 说明:采用递归的方法,深度优先聚类数据127 参数:128 unsigned long dpID;            //数据点id129 unsigned long clusterId;    //数据点所属簇id130 返回值: void;    */131 void ClusterAnalysis::KeyPointCluster(unsigned long dpID, unsigned long clusterId )132 {133     DataPoint& srcDp = dadaSets[dpID];        //获取数据点对象134     if(!srcDp.IsKey())    return;135     vector<unsigned long>& arrvalPoints = srcDp.GetArrivalPoints();        //获取对象领域内点ID列表136     for(unsigned long i=0; i<arrvalPoints.size(); i++)137     {138         DataPoint& desDp = dadaSets[arrvalPoints[i]];    //获取领域内点数据点139         if(!desDp.isVisited())                            //若该对象没有被访问过执行140         {141             //cout << "数据点\t"<< desDp.GetDpId()<<"聚类ID为\t" <<clusterId << endl;142             desDp.SetClusterId(clusterId);        //设置该对象所属簇的ID为clusterId,即将该对象吸入簇中143             desDp.SetVisited(true);                //设置该对象已被访问144             if(desDp.IsKey())                    //若该对象是核心对象145             {146                 KeyPointCluster(desDp.GetDpId(),clusterId);    //递归地对该领域点数据的领域内的点执行聚类操作,采用深度优先方法147             }148         }149     }150 }151 152 //两数据点之间距离153 /*154 函数:获取两数据点之间距离155 说明:获取两数据点之间的欧式距离156 参数:157 DataPoint& dp1;        //数据点1158 DataPoint& dp2;        //数据点2159 返回值: double;    //两点之间的距离        */160 double ClusterAnalysis::GetDistance(DataPoint& dp1, DataPoint& dp2)161 {162     double distance =0;        //初始化距离为0163     for(int i=0; i<DIME_NUM;i++)    //对数据每一维数据执行164     {165         distance += pow(dp1.GetDimension()[i] - dp2.GetDimension()[i],2);    //距离+每一维差的平方166     }167     return pow(distance,0.5);        //开方并返回距离168 }
复制代码

算法调用就简单了:

复制代码
 1 #include "ClusterAnalysis.h" 2 #include <cstdio> 3  4 using namespace std; 5  6 int main() 7 { 8     ClusterAnalysis myClusterAnalysis;                        //聚类算法对象声明 9     myClusterAnalysis.Init("D:\\1108\\XY.txt",500,9);        //算法初始化操作,指定半径为15,领域内最小数据点个数为3,(在程序中已指定数据维度为2)10     myClusterAnalysis.DoDBSCANRecursive();                    //执行聚类算法11     myClusterAnalysis.WriteToFile("D:\\1108\\XYResult.txt");//写执行后的结果写入文件12 13     system("pause");    //显示结果14     return 0;            //返回15 }
复制代码
0 0