// K-Means的简单实现 (A simple implementation of K-Means clustering)

// 来源:互联网 发布:淘宝号怎么升心 编辑:程序博客网 时间:2024/05/16 05:34
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <utility>
using namespace std;

// Capacities of the statically allocated buffers below.
const int MAXN(100010);  // maximum number of points
const int MAXK(100010);  // maximum number of clusters
const int MAXC(5);       // maximum vector dimension

// NOTE(review): despite its name this keeps the *smaller* value: when a > b it
// assigns a = b and returns true. It is unused in this file and left as-is.
template<typename T>
inline bool checkmax(T &a, const T &b){
    return a > b? ((a = b), true): false;
}

// Absolute value of a.
template<typename T>
inline T ABS(T a){
    return a < 0? -a: a;
}

// A point: storage for MAXC coordinates, of which only the first C are used.
struct PO{
    double co[MAXC];
} po[MAXN];                // input points

PO reg[MAXK];              // cluster centres ("reference points")
int qua[MAXK];             // number of points assigned to each cluster
int bel[MAXN];             // bel[i]: index of the cluster point i belongs to
int ind[MAXN];             // scratch permutation used by ini_reg

// Squared Euclidean distance between a and b over the first C coordinates.
double dis(const PO &a, const PO &b, int C){
    double ret = 0;
    for(int i = 0; i < C; ++i){
        const double d = a.co[i] - b.co[i];
        ret += d * d;
    }
    return ret;
}

// Choose K distinct points of po[0..n-1] as the initial centres using a
// partial Fisher-Yates shuffle. The fixed seed makes runs reproducible.
// Precondition: 1 <= K <= n (validated in main).
void ini_reg(int n, int K){
    srand(2357);
    for(int i = 0; i < n; ++i) ind[i] = i;
    for(int i = 0; i < K; ++i){
        // Scale by RAND_MAX + 1 so the offset lies in [0, n - i). Dividing by
        // RAND_MAX instead gives j == n whenever rand() == RAND_MAX, which
        // reads ind[n] -- an entry outside the permutation.
        int j = i + (int)((double)rand() / ((double)RAND_MAX + 1.0) * (n - i));
        swap(ind[i], ind[j]);
    }
    for(int i = 0; i < K; ++i) reg[i] = po[ind[i]];
}

// Move every centre to the mean of the points currently assigned to it
// (according to bel[]). A cluster that received no points keeps its previous
// centre; dividing by its zero count would otherwise produce a NaN centre.
void cal_reg(int n, int K, int C){
    for(int i = 0; i < K; ++i) qua[i] = 0;
    for(int i = 0; i < n; ++i) ++qua[bel[i]];

    for(int i = 0; i < K; ++i)
        if(qua[i] > 0)
            for(int j = 0; j < C; ++j) reg[i].co[j] = 0;

    // Every bel[i] refers to a cluster with qua > 0, so only reset centres
    // are accumulated into.
    for(int i = 0; i < n; ++i)
        for(int j = 0; j < C; ++j) reg[bel[i]].co[j] += po[i].co[j];

    for(int i = 0; i < K; ++i)
        if(qua[i] > 0)
            for(int j = 0; j < C; ++j) reg[i].co[j] /= qua[i];
}

// Assign each point to its nearest centre (ties go to the lowest index).
// Writes bel[], sets `changed` if any assignment differs from the previous
// contents of bel[], and returns the objective: the sum of squared distances.
static double assign_points(int n, int K, int C, bool &changed){
    changed = false;
    double F = 0;
    for(int j = 0; j < n; ++j){
        int id = 0;
        double md = dis(po[j], reg[0], C);
        for(int k = 1; k < K; ++k){
            double t = dis(po[j], reg[k], C);
            if(t < md){
                md = t;
                id = k;
            }
        }
        if(bel[j] != id){
            bel[j] = id;
            changed = true;
        }
        F += md;
    }
    return F;
}

// Lloyd's K-means iteration. Starting from K randomly chosen points, it
// alternates "recompute centres" and "reassign points". It stops when the
// objective changes by at most `err` or when no point changes cluster.
// At least one refinement step always runs. (Previously the "previous
// objective" started at 0, so an initial objective <= err stopped before
// any refinement.) The "no change" exit also guarantees termination when
// err < 0. The result is left in bel[].
void solve(int n, int K, int C, double err){
    bool changed = false;
    ini_reg(n, K);
    double F = assign_points(n, K, C, changed);
    for(;;){
        cal_reg(n, K, C);
        const double F_ = F;
        F = assign_points(n, K, C, changed);
        if(!changed || ABS(F_ - F) <= err) break;
    }
}

// Input: n K C err, then n points of C coordinates each.
// Output: a newline, then " <cluster>" for every point, in input order.
int main(){
    int n, K, C;  // number of points, number of clusters, vector dimension
    double err;   // convergence threshold on the change of the objective

    // input
    if(scanf("%d%d%d%lf", &n, &K, &C, &err) != 4){
        fprintf(stderr, "error: expected \"n K C err\" on input\n");
        return 1;
    }
    // Reject sizes the static buffers cannot hold, and K outside [1, n]
    // (ini_reg needs K distinct points).
    if(n < 1 || n > MAXN || K < 1 || K > n || K > MAXK || C < 1 || C > MAXC){
        fprintf(stderr, "error: need 1 <= K <= n <= %d and 1 <= C <= %d\n", MAXN, MAXC);
        return 1;
    }
    for(int i = 0; i < n; ++i)
        for(int j = 0; j < C; ++j)
            if(scanf("%lf", po[i].co + j) != 1){
                fprintf(stderr, "error: expected %d coordinates for %d points\n", C, n);
                return 1;
            }

    solve(n, K, C, err);

    printf("\n");
    for(int i = 0; i < n; ++i) printf(" %d", bel[i]);
    // end
    return 0;
}

/* Sample input:
20 4 2 0.2
0.1 0.1
0.1 0.2
0.2 0.1
0.5 0.5
1 1
1.1 2.2
1.1 2.3
1.1 2.4
1.1 2.5
1.5 1.5
3.1 -1.3
3.1 -1.4
3.1 -1.5
3.1 -1.6
3.5 -1.7
4.1 3.5
4.1 3.6
4.1 3.7
4.1 3.8
4.1 5.0
*/

// 0 0
// 原创粉丝点击