[Machine Learning] Clustering Algorithms: The ISODATA Algorithm


The K-Means algorithm discussed previously has two major drawbacks:

      (1) the number of clusters K is a fixed value chosen in advance;

      (2) the random choice of the initial seeds can affect the result.

For drawback (2) we introduced the K-Means++ algorithm, which makes the seed selection much more reasonable and thereby improves the algorithm considerably. Drawback (1), however, has never been addressed: in K-Means the value of K is still fixed ahead of time. As mentioned back then, the ISODATA algorithm can find a suitable K on its own. This post explains how ISODATA works in detail and provides a C++ implementation.

 

Contents

 

   1. Understanding the ISODATA algorithm

   2. The parameters of ISODATA

   3. A C++ implementation of ISODATA

 

 

1. Understanding the ISODATA algorithm

 

   ISODATA stands for Iterative Self-Organizing Data Analysis Techniques Algorithm. By exposing its
   initial parameters it introduces a human-in-the-loop element, and it relies on merging and splitting:
   when the centers of two clusters are closer than a threshold, the two clusters are merged into one;
   when the standard deviation of a cluster exceeds a threshold, or its number of samples exceeds a
   threshold, the cluster is split into two; and when a cluster contains fewer samples than a threshold,
   it is discarded. Iterating in this way from the initial cluster centers, the desired number of classes
   and the other parameters eventually yields a fairly good clustering. ISODATA is a widely used
   cluster-analysis method and an unsupervised learning algorithm.
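
   To make these rules concrete, here is a minimal, self-contained sketch of the per-cluster and pairwise
   decisions, written against the thresholds thetaN, thetaS and thetaC introduced in the next section. The
   names ClusterStats, decide and shouldMerge are illustrative only; the full implementation in Section 3
   also checks a cluster's average distance and sample count before splitting it.

#include <cmath>
#include <cstdio>

// Summary statistics of one cluster: center (cx, cy), per-dimension
// standard deviation (sx, sy) and the number of samples n.
struct ClusterStats { double cx, cy, sx, sy; int n; };

enum Action { KEEP, DISCARD, SPLIT };

// Per-cluster decision: discard clusters with fewer than thetaN samples,
// split clusters whose largest standard-deviation component exceeds thetaS.
Action decide(const ClusterStats &c, int thetaN, double thetaS)
{
    if (c.n < thetaN) return DISCARD;
    double maxSigma = c.sx > c.sy ? c.sx : c.sy;
    if (maxSigma > thetaS) return SPLIT;
    return KEEP;
}

// Pairwise decision: merge two clusters whose centers are closer than thetaC.
bool shouldMerge(const ClusterStats &a, const ClusterStats &b, double thetaC)
{
    double dx = a.cx - b.cx, dy = a.cy - b.cy;
    return std::sqrt(dx * dx + dy * dy) < thetaC;
}

int main()
{
    ClusterStats c = { 0.0, 0.0, 4.0, 1.0, 10 };
    std::printf("action = %d\n", decide(c, 3, 3.0));        // prints 2 (SPLIT), since 4.0 > thetaS
    ClusterStats d = { 0.0005, 0.0, 1.0, 1.0, 10 };
    std::printf("merge  = %d\n", shouldMerge(c, d, 0.001)); // prints 1: centers closer than thetaC
    return 0;
}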

 

 

2. The parameters of ISODATA

 

   The general idea of ISODATA was outlined above; the algorithm relies on six important parameters.

 

    expClusters  the desired number of cluster centers

    thetaN       the minimum number of samples a cluster must contain; a cluster with fewer samples is not kept as an independent cluster

    thetaS       the threshold on the standard deviation of the samples within a cluster (split threshold)

    thetaC       the threshold on the distance between cluster centers, i.e. the merge coefficient; two clusters whose centers are closer than this are merged

    maxIts       the maximum number of iterations allowed

    combL        the maximum number of cluster pairs that may be merged in a single iteration

   

   With these parameters in place, the iteration can begin; the concrete values used for the test run at the end of this post are shown in the snippet right below.
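
   For concreteness, the small self-contained snippet below mirrors the Args struct used in the
   implementation in Section 3 and fills it with the same values that setArgs() assigns there; the
   struct redefinition here is purely illustrative.

#include <cstdio>

// Mirrors struct Args from the implementation in Section 3.
struct Args
{
    int expClusters;   // desired number of clusters
    int thetaN;        // minimum samples per cluster
    int maxIts;        // maximum number of iterations
    int combL;         // maximum cluster pairs merged per iteration
    double thetaS;     // standard-deviation (split) threshold
    double thetaC;     // center-distance (merge) threshold
};

int main()
{
    // Same values that setArgs() uses in the full program below.
    Args args = { 5, 3, 10000, 10, 3.0, 0.001 };
    std::printf("expClusters=%d thetaN=%d thetaS=%g thetaC=%g\n",
                args.expClusters, args.thetaN, args.thetaS, args.thetaC);
    return 0;
}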

 

 

3. A C++ implementation of ISODATA

 

   The detailed steps of the ISODATA algorithm can be followed in the code below.

 

#include <iostream>
#include <string.h>
#include <algorithm>
#include <stdio.h>
#include <vector>
#include <assert.h>
#include <math.h>

#define iniClusters 5  // initial number of clusters

using namespace std;

// the six parameters the algorithm uses
struct Args
{
    int expClusters;   // desired number of clusters
    int thetaN;        // minimum number of samples in a cluster
    int maxIts;        // maximum number of iterations
    int combL;         // maximum number of cluster pairs merged per iteration
    double thetaS;     // standard-deviation (split) threshold
    double thetaC;     // merge threshold
} args;

// a 2-D point; two features are assumed here, but this generalizes to more dimensions
struct Point
{
    double x, y;
};

// a candidate pair of clusters to merge: their indices and the distance between them
struct MergeInfo
{
    int u, v;
    double d;          // distance between the centers of clusters u and v
};

// comparison function used to sort candidate pairs by distance
bool cmp(MergeInfo a, MergeInfo b)
{
    return a.d < b.d;
}

// Euclidean distance between two points
double dist(Point A, Point B)
{
    return sqrt((A.x - B.x) * (A.x - B.x) + (A.y - B.y) * (A.y - B.y));
}

struct Cluster
{
    int nSamples;          // number of samples in the cluster
    double avgDist;        // average distance from the samples to the center
    Point center;          // cluster center
    Point sigma;           // per-dimension standard deviation of the samples
    vector<Point*> data;   // the samples assigned to this cluster

    // compute the cluster center, i.e. the mean of its samples
    void calMean()
    {
        assert(nSamples == (int)data.size());
        center.x = center.y = 0;               // reset before accumulating
        for (int i = 0; i < nSamples; i++)
        {
            center.x += data.at(i)->x;
            center.y += data.at(i)->y;
        }
        center.x /= nSamples;
        center.y /= nSamples;
    }

    // average distance from the samples to the cluster center
    void calDist()
    {
        avgDist = 0;
        for (int i = 0; i < nSamples; i++)
            avgDist += dist(*(data.at(i)), center);
        avgDist /= nSamples;
    }

    // per-dimension standard deviation of the samples around the center
    void calStErr()
    {
        assert(nSamples == (int)data.size());
        double attr1 = 0;
        double attr2 = 0;                      // the two feature dimensions
        for (int i = 0; i < nSamples; i++)
        {
            attr1 += (data.at(i)->x - center.x) * (data.at(i)->x - center.x);
            attr2 += (data.at(i)->y - center.y) * (data.at(i)->y - center.y);
        }
        sigma.x = sqrt(attr1 / nSamples);
        sigma.y = sqrt(attr2 / nSamples);
    }
};

// read the data
void getData(Point p[], int n)
{
    cout << "getting data..." << endl;
    for (int i = 0; i < n; i++)
        scanf("%lf %lf", &p[i].x, &p[i].y);
    cout << "get data done!" << endl;
}

// set the parameter values
void setArgs()
{
    args.expClusters = 5;
    args.thetaN = 3;
    args.maxIts = 10000;
    args.combL = 10;
    args.thetaS = 3;
    args.thetaC = 0.001;
}

// index of the cluster whose center is closest to point t
int FindIdx(vector<Cluster> &c, Point &t)
{
    int nClusters = c.size();
    assert(nClusters >= 1);
    double ans = dist(c.at(0).center, t);
    int idx = 0;
    for (int i = 1; i < nClusters; i++)
    {
        double tmp = dist(c.at(i).center, t);
        if (ans > tmp)
        {
            idx = i;
            ans = tmp;
        }
    }
    return idx;
}

// binary search for the last index in info whose distance is still below thetaC
int FindPos(MergeInfo *info, int n, double thetaC)
{
    if (n <= 0) return -1;                     // no candidate pairs
    int l = 0;
    int r = n - 1;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        if (info[mid].d < thetaC)
        {
            l = mid + 1;
            if (l < n && info[l].d >= thetaC)
                return mid;
        }
        else
        {
            r = mid - 1;
            if (r >= 0 && info[r].d < thetaC)
                return r;
        }
    }
    if (info[n - 1].d < thetaC)
        return n - 1;
    else
        return -1;
}

// print every cluster and its points
void Print(const vector<Cluster> &c)
{
    int n = c.size();
    for (int i = 0; i < n; i++)
    {
        cout << "------------------------------------" << endl;
        cout << "Cluster " << i + 1 << ":" << endl;
        for (int j = 0; j < (int)c.at(i).data.size(); j++)
            cout << "(" << c[i].data[j]->x << "," << c[i].data[j]->y << ")  ";
        cout << endl;
        cout << endl;
    }
}

void ISOData(Point p[], int n)
{
    cout << "ISOData is processing......." << endl;
    vector<Cluster> c;               // the clusters
    const double split = 0.5;        // split constant in (0, 1]
    int nClusters = iniClusters;     // initial number of clusters

    // initialize nClusters clusters, seeding each center with one of the points
    for (int i = 0; i < nClusters; i++)
    {
        Cluster t;
        t.center = p[i];
        t.nSamples = 0;
        t.avgDist = 0;
        c.push_back(t);
    }
    int iter = 0;
    bool isLess = false;             // does some cluster have fewer than thetaN samples?
    while (1)
    {
        // empty every cluster
        for (int i = 0; i < nClusters; i++)
        {
            c.at(i).nSamples = 0;
            c.at(i).data.clear();
        }
        // assign every sample to the cluster with the nearest center
        for (int i = 0; i < n; i++)
        {
            int idx = FindIdx(c, p[i]);
            c.at(idx).data.push_back(&p[i]);
            c.at(idx).nSamples++;
        }
        // look for a cluster with fewer than thetaN samples
        int k = 0;                   // index of such a cluster
        for (int i = 0; i < nClusters; i++)
        {
            if ((int)c.at(i).data.size() < args.thetaN)
            {
                isLess = true;       // too few samples: this cluster should be removed
                k = i;
                break;
            }
        }
        // dissolve the undersized cluster and reassign its points
        if (isLess)
        {
            nClusters--;
            Cluster t = c.at(k);
            vector<Cluster>::iterator pos = c.begin() + k;
            c.erase(pos);
            assert(nClusters == (int)c.size());
            for (int i = 0; i < (int)t.data.size(); i++)
            {
                int idx = FindIdx(c, *(t.data.at(i)));
                c.at(idx).data.push_back(t.data.at(i));
                c.at(idx).nSamples++;
            }
            isLess = false;
        }
        // recompute the means and the average distances to the centers
        for (int i = 0; i < nClusters; i++)
        {
            c.at(i).calMean();
            c.at(i).calDist();
        }
        // overall average distance of all samples to their centers
        double totalAvgDist = 0;
        for (int i = 0; i < nClusters; i++)
            totalAvgDist += c.at(i).avgDist * c.at(i).nSamples;
        totalAvgDist /= n;
        if (iter >= args.maxIts) break;

        // split step
        if (nClusters <= args.expClusters / 2)
        {
            vector<double> maxsigma;
            for (int i = 0; i < nClusters; i++)
            {
                // standard deviation of this cluster and its largest component
                c.at(i).calStErr();
                double mt = c.at(i).sigma.x > c.at(i).sigma.y ? c.at(i).sigma.x : c.at(i).sigma.y;
                maxsigma.push_back(mt);
            }
            for (int i = 0; i < nClusters; i++)
            {
                if (maxsigma.at(i) > args.thetaS)
                {
                    if ((c.at(i).avgDist > totalAvgDist && c.at(i).nSamples > 2 * (args.thetaN + 1))
                        || (nClusters < args.expClusters / 2))
                    {
                        nClusters++;
                        Cluster newCtr;          // the new cluster
                        newCtr.nSamples = 0;
                        newCtr.avgDist = 0;
                        // place the new center on one side of the old one...
                        newCtr.center.x = c.at(i).center.x - split * c.at(i).sigma.x;
                        newCtr.center.y = c.at(i).center.y - split * c.at(i).sigma.y;
                        c.push_back(newCtr);
                        // ...and shift the old center to the other side
                        c.at(i).center.x = c.at(i).center.x + split * c.at(i).sigma.x;
                        c.at(i).center.y = c.at(i).center.y + split * c.at(i).sigma.y;
                        break;
                    }
                }
            }
        }
        // merge step
        if (nClusters >= 2 * args.expClusters || (iter & 1) == 0)
        {
            int size = nClusters * (nClusters - 1);  // upper bound on the number of candidate pairs
            int cnt = 0;                             // actual number of candidate pairs
            MergeInfo *info = new MergeInfo[size];
            for (int i = 0; i < nClusters; i++)
            {
                for (int j = i + 1; j < nClusters; j++)
                {
                    info[cnt].u = i;
                    info[cnt].v = j;
                    info[cnt].d = dist(c.at(i).center, c.at(j).center);
                    cnt++;
                }
            }
            // sort the pairs by center distance
            sort(info, info + cnt, cmp);
            // last index whose distance is still below thetaC; closer pairs are merged first
            int iPos = FindPos(info, cnt, args.thetaC);
            // flag: cluster has already taken part in a merge this iteration
            bool *flag = new bool[nClusters];
            memset(flag, false, sizeof(bool) * nClusters);
            // del: cluster has been merged away and must be deleted
            bool *del = new bool[nClusters];
            memset(del, false, sizeof(bool) * nClusters);
            int nTimes = 0;                          // number of merges performed
            for (int i = 0; i <= iPos; i++)
            {
                int u = info[i].u;
                int v = info[i].v;
                // each cluster may be merged at most once per iteration
                if (!flag[u] && !flag[v])
                {
                    // at most combL pairs may be merged per iteration
                    if (nTimes >= args.combL) break;
                    nTimes++;
                    flag[u] = flag[v] = true;
                    // merge the smaller cluster into the larger one and
                    // update the surviving center as a weighted average
                    if (c.at(u).nSamples < c.at(v).nSamples)
                    {
                        del[u] = true;
                        Cluster t = c.at(u);
                        assert(t.nSamples == (int)t.data.size());
                        for (int j = 0; j < t.nSamples; j++)
                            c.at(v).data.push_back(t.data.at(j));
                        c.at(v).center.x = c.at(v).center.x * c.at(v).nSamples + t.nSamples * t.center.x;
                        c.at(v).center.y = c.at(v).center.y * c.at(v).nSamples + t.nSamples * t.center.y;
                        c.at(v).nSamples += t.nSamples;
                        c.at(v).center.x /= c.at(v).nSamples;
                        c.at(v).center.y /= c.at(v).nSamples;
                    }
                    else
                    {
                        del[v] = true;
                        Cluster t = c.at(v);
                        assert(t.nSamples == (int)t.data.size());
                        for (int j = 0; j < t.nSamples; j++)
                            c.at(u).data.push_back(t.data.at(j));
                        c.at(u).center.x = c.at(u).center.x * c.at(u).nSamples + t.nSamples * t.center.x;
                        c.at(u).center.y = c.at(u).center.y * c.at(u).nSamples + t.nSamples * t.center.y;
                        c.at(u).nSamples += t.nSamples;
                        c.at(u).center.x /= c.at(u).nSamples;
                        c.at(u).center.y /= c.at(u).nSamples;
                    }
                }
            }
            // delete the clusters that were merged away
            vector<Cluster>::iterator id = c.begin();
            for (int i = 0; i < nClusters; i++)
            {
                if (del[i])
                    id = c.erase(id);
                else
                    id++;
            }
            // one cluster disappears per merge
            nClusters -= nTimes;
            assert(nClusters == (int)c.size());
            delete[] info;
            delete[] flag;
            delete[] del;
            info = NULL;
            flag = NULL;
            del = NULL;
        }
        if (iter >= args.maxIts) break;
        iter++;
    }
    assert(nClusters == (int)c.size());
    Print(c);
}

int main()
{
    int n;
    scanf("%d", &n);
    Point *p = new Point[n];
    getData(p, n);
    setArgs();
    ISOData(p, n);
    delete[] p;
    p = NULL;
    return 0;
}
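
Assuming the program is saved as isodata.cpp and the test data below as data.txt (both file names are only for illustration), it can be built and run with an ordinary C++ compiler, for example:

g++ -O2 -o isodata isodata.cpp
./isodata < data.txt

The program first reads the number of points, then that many coordinate pairs, and finally prints every cluster it has found.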

We again use the test data from the earlier K-Means post, shown below:

15
0 0
1 0
0 1
-1 0
0 -1
10 0
11 0
9 0
10 1
10 -1
-10 0
-11 0
-9 0
-10 1
-10 -1


 

After feeding in the data, the program prints the resulting clusters:

[Program output screenshot]

As the output shows, with appropriately chosen parameters the result is quite good; the 15 points form three well-separated groups centered near (0, 0), (10, 0) and (-10, 0).
