kmeans集群算法(cluster-reuters)

来源:互联网 发布:java局域网循环聊天 编辑:程序博客网 时间:2024/06/05 18:18



原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: kmeans集群算法(cluster-reuters)

理论分析

集群中心点计算

1 随机从待聚类的向量中选出20个(对应命令中的 -k 20)作为20个集群的初始中心。
2 对所有的点,计算其和每个中心的距离,距离最小者为当前点的集群归属。
3 重新对每个集群计算新的中心,并计算新的中心和老的中心的距离,判断其是否收敛。
4 如果所有集群都收敛或者达到用户指定的条件,则集群完成。否则,从2开始下一轮计算。

集群数据

对所有的点,计算其和每个中心的距离,距离最小者为当前点的集群归属。

代码分析

  # Run Mahout k-means over the TF-IDF vectors using cosine distance.
  # -i input vectors, -c initial clusters path, -o output path, -dm distance measure class.
  # -x 10 / -k 20: presumably max iterations and number of clusters (the article picks 20 centres).
  # Each option sits on its own continued line; a backslash must be the last character on its line.
  $MAHOUT kmeans \
    -i ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/tfidf-vectors/ \
    -c ${WORK_DIR}/reuters-kmeans-clusters \
    -o ${WORK_DIR}/reuters-kmeans \
    -dm org.apache.mahout.common.distance.CosineDistanceMeasure \
    -x 10 -k 20 -ow --clustering

在这之前同样需要调用seqdirectory和seq2sparse,请参考贝叶斯分类(classify-20newsgroups)

    // When a cluster count was passed on the command line, replace the initial
    // clusters with that many seeds produced by RandomSeedGenerator.
    if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
      String requestedCount = getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION);
      clusters = RandomSeedGenerator.buildRandom(
          getConf(), input, clusters, Integer.parseInt(requestedCount), measure);
    }

随机从待集群的文章中选取20篇文章作为20个集群的中心。

    if (runSequential) {      ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations);    } else {      ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations);    }   public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations)    throws IOException, InterruptedException, ClassNotFoundException {    ClusteringPolicy policy = ClusterClassifier.readPolicy(priorPath);    Path clustersOut = null;    int iteration = 1;    /* 直到等于迭代次数或isConverged收敛*/    while (iteration <= numIterations) {      conf.set(PRIOR_PATH_KEY, priorPath.toString());       String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath;      Job job = new Job(conf, jobName);      job.setMapOutputKeyClass(IntWritable.class);      job.setMapOutputValueClass(ClusterWritable.class);      job.setOutputKeyClass(IntWritable.class);      job.setOutputValueClass(ClusterWritable.class);       job.setInputFormatClass(SequenceFileInputFormat.class);      job.setOutputFormatClass(SequenceFileOutputFormat.class);      job.setMapperClass(CIMapper.class);      job.setReducerClass(CIReducer.class);       FileInputFormat.addInputPath(job, inPath);      clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);      priorPath = clustersOut;      FileOutputFormat.setOutputPath(job, clustersOut);       job.setJarByClass(ClusterIterator.class);      if (!job.waitForCompletion(true)) {        throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath);      }      ClusterClassifier.writePolicy(policy, clustersOut);      FileSystem fs = FileSystem.get(outPath.toUri(), conf);      iteration++;      /* 计算每个Cluster的当前的中心点和本次重新计算出来的中心点的距离,如果都小于给定的convergenceDelta,则本次集群计算收敛*/      if (isConverged(clustersOut, conf, fs)) {        break;      }    }    Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + 
Cluster.FINAL_ITERATION_SUFFIX);    FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn);  }   /* CIMapper中的map方法*/  @Override  protected void map(WritableComparable<?> key, VectorWritable value, Context context) throws IOException,      InterruptedException {    /* 使用ClusterClassifier对当前文章进行分类*/    Vector probabilities = classifier.classify(value.get());    Vector selections = policy.select(probabilities);    for (Element el : selections.nonZeroes()) {      classifier.train(el.index(), value.get(), el.get());    }  }   /* ClusterClassifier中classify方法 */  @Override  public Vector classify(Vector instance) {    return policy.classify(instance, this);  }   /* AbstractClusteringPolicy中的classify方法 */  @Override  public Vector classify(Vector data, ClusterClassifier prior) {    List<Cluster> models = prior.getModels();    int i = 0;    Vector pdfs = new DenseVector(models.size());    /* 用20个集群中心模型对当前文章进行分类并存储在pdfs里面*/    for (Cluster model : models) {      pdfs.set(i++, model.pdf(new VectorWritable(data)));    }    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());  }   /* DistanceMeasureCluster中的pdf方法 */  @Override  public double pdf(VectorWritable vw) {    return 1 / (1 + measure.distance(vw.get(), getCenter()));  }   /* CosineDistanceMeasure中的distance方法,余玄求解2个向量的夹角 */  @Override  public double distance(Vector v1, Vector v2) {    if (v1.size() != v2.size()) {      throw new CardinalityException(v1.size(), v2.size());    }    double lengthSquaredv1 = v1.getLengthSquared();    double lengthSquaredv2 = v2.getLengthSquared();     double dotProduct = v2.dot(v1);    double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2);     // correct for floating-point rounding errors    if (denominator < dotProduct) {      denominator = dotProduct;    }     // correct for zero-vector corner case    if (denominator == 0 && dotProduct == 0) {      return 0;    }     return 1.0 - dotProduct / denominator;  }   /*  
ClusterClassifier的train方法*/  public void train(int actual, Vector data, double weight) {    models.get(actual).observe(new VectorWritable(data), weight);  }   /* AbstractCluster中的observe方法,根据weight给s0计数,s1向量累加,s2向量平方后累加*/  @Override  public void observe(VectorWritable x, double weight) {    observe(x.get(), weight);  }   public void observe(Vector x, double weight) {    if (weight == 1.0) {      observe(x);    } else {      setS0(getS0() + weight);      Vector weightedX = x.times(weight);      if (getS1() == null) {        setS1(weightedX);      } else {        getS1().assign(weightedX, Functions.PLUS);      }      Vector x2 = x.times(x).times(weight);      if (getS2() == null) {        setS2(x2);      } else {        getS2().assign(x2, Functions.PLUS);      }    }  }   /* CIReducer中reduce方法,对这一轮加入集群的向量进行平均,从新计算集群中心*/  @Override  protected void reduce(IntWritable key, Iterable<ClusterWritable> values, Context context) throws IOException,      InterruptedException {    Iterator<ClusterWritable> iter = values.iterator();    Cluster first = iter.next().getValue(); // there must always be at least one    while (iter.hasNext()) {      Cluster cluster = iter.next().getValue();      first.observe(cluster);    }    List<Cluster> models = Lists.newArrayList();    models.add(first);    classifier = new ClusterClassifier(models, policy);    classifier.close();    context.write(key, new ClusterWritable(first));  }
0 0