【Petuum 源码解析】之K-Means分布式算法源码

来源:互联网 发布:golang java性能对比 编辑:程序博客网 时间:2024/06/05 10:10

【前言】

由于最近在看分布式机器学习相关的东西,所以希望能把学习心得记录在这里。本系列主要介绍CMU的开源分布式框架Petuum源码的学习。由于本人水平很菜,希望大家不吝赐教,共同学习。
本文主要记录我在学习K-Means源码当中的收获和疑问。

【正文部分】

1.Petuum

首先看看Petuum的组成部分,主要包括Bösen(a bounded-asynchronous distributed key-value store)和Strads( a dynamic ML update scheduler.),详细的说明可以参考[官方文档][1]

2.关于K-Means可以参考Wiki的介绍:

https://en.wikipedia.org/wiki/K-means_clustering
本文中对KMeans的解读,需要弄清的问题是:
1.KMeans算法在Petuum系统是是如何实现并行的?
2.参数是如何在参数服务器(Parameter Server)与Worker之间进行更新和分配的?
3.在效率方面有哪些提升?
4.有没有数学理论上的保证?

3. 源码解读:

3.1 代码结构:

K-Means在Petuum中是用C++来实现的,源码位于bosen/app/kmeans目录下,![可以参考该截图](http://img.blog.csdn.net/20160413170722728)

3.2 核心代码:

核心代码包括kmeans.h, kmeans.cpp, kmeans_main.cpp, kmeans_methods.cpp, kmeans_methods.h, kmeans_worker.cpp, kmeans_worker.h 。主要就是这7个文件。还有几个工具文件:cluster_centers.cpp,cluster_centers.h,context.cpp, context.hpp,dataset.cpp,dataset.h,dense_vector.cpp,dense_vector.h,sparse_vector.cpp,sparse_vector.h。当然还用到了许多第三方库,包括boost标准库,以及gflags,glog等google的库。下面主要来介绍前面这几个核心文件。

3.2.1 kmeans.h

有这么几个函数需要注意:

void ReadData();//读入数据int GetTrainingDataSize();const dataset GetReferenceToDataSet();void Start();//线程入口函数void SetExamplesPerBatch(long number_of_samples);

另外几个重要的成员变量,是PS Table的变量

      // ============ PS Tables ============      petuum::Table<float> centres_;      petuum::Table<int> count_of_centers_;      petuum::Table<int> delta_centers_;      petuum::Table<float> objective_values_;
  1. ReadData:读取数据
  2. Start: 线程入口函数。首先注册thread(petuum::PSTableGroup的RegisterThread);然后通过Context获取Petuum的相关参数;然后生成k个cluster质心;然后进行循环的更新质心数据到参数服务器()
    如下所示:
for (int i = 0; i < num_centers; i++) {    petuum::UpdateBatch<float> update_batch;    for (int j = 0; j < dimensionality; j++) {      if (client_id == 0 && thread_id == 0) {        update_batch.Update(j,            initial_centres.getCenterAt(i).ValueAt(j));      } else {        update_batch.Update(j, 0);      }    }    centres_.BatchInc(i, update_batch);  }

3.2.2 kmeans_methods.h

int randInt(int num);void RandomInitializeCenters(cluster_centers* centers, const dataset& dataset1);float ComputeObjective(cluster_centers* centers, const dataset& ds);

未完~

[1]http://petuum.github.io/

0 0
原创粉丝点击