MapReduce的Kmeans聚类算法

来源:互联网 发布:剑网3军娘捏脸数据 编辑:程序博客网 时间:2024/06/06 07:30

最近在网上查看用MapReduce实现的Kmeans算法,例子是不错,http://blog.csdn.net/jshayzf/article/details/22739063

但注释太少了,而且参数太多,如果新手学习的话不太好理解。所以自己按照个人的理解写了一个简单的例子并添加了详细的注释。

大致的步骤是:

1,Map每读取一条数据就与中心做对比,求出该条记录对应的中心,然后以中心的ID为Key,该条数据为value将数据输出。

2,利用reduce的归并功能将相同的Key归并到一起,集中与该Key对应的数据,再求出这些数据的平均值,输出平均值。

3,对比reduce求出的平均值与原来的中心,如果不相同,则将清空原中心的数据文件,将reduce的结果写到中心文件中。(中心的值存在一个HDFS的文件中)

     删掉reduce的输出目录以便下次输出。

     继续运行任务。

4,对比reduce求出的平均值与原来的中心,如果相同,则删掉reduce的输出目录,运行一个没有reduce的任务将中心ID与值对应输出。


package MyKmeans;

import java.io.IOException;
import java.util.ArrayList;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

/**
 * Iterative k-means driver built from one mapper and one reducer.
 *
 * <p>Each iteration assigns every record to its nearest center (map) and
 * averages the records of each center into a new center (reduce). The driver
 * in {@link #main(String[])} repeats this until {@code Utils.compareCenters}
 * reports that the centers stopped changing, then runs one last pass without
 * the averaging reducer so the output holds "centerId record" pairs.
 */
public class MapReduce {

    /**
     * Assigns each input record to the nearest center and emits
     * (1-based center index, original record).
     */
    public static class Map extends Mapper<LongWritable, Text, IntWritable, Text> {

        // Current centers, one list of coordinates per center.
        ArrayList<ArrayList<Double>> centers = null;

        // Number of centers.
        int k = 0;

        /**
         * Loads the centers from the file named by the "centersPath"
         * configuration entry.
         */
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            centers = Utils.getCentersFromHDFS(context.getConfiguration().get("centersPath"), false);
            k = centers.size();
        }

        /**
         * Parses one comma-separated record, finds the closest center and
         * writes the record unchanged under that center's 1-based index.
         */
        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            ArrayList<Double> fields = Utils.textToArray(value);

            double minDistance = Double.MAX_VALUE;
            int centerIndex = 0;
            for (int i = 0; i < k; i++) {
                double currentDistance = normalizedDistance(centers.get(i), fields);
                // Strict '<' keeps the first center on ties.
                if (currentDistance < minDistance) {
                    minDistance = currentDistance;
                    centerIndex = i;
                }
            }

            // Center IDs are the 1-based position of the center in the centers file.
            context.write(new IntWritable(centerIndex + 1), value);
        }

        /**
         * Sum over all coordinates of ((|c| - |r|) / (|c| + |r|))^2.
         *
         * <p>A coordinate that is 0 in both vectors contributes 0. Without this
         * guard the term would be 0/0 = NaN, the whole distance would become
         * NaN, and the center could never be selected.
         */
        private static double normalizedDistance(ArrayList<Double> center, ArrayList<Double> record) {
            double distance = 0;
            int size = record.size();
            for (int j = 0; j < size; j++) {
                double centerPoint = Math.abs(center.get(j));
                double field = Math.abs(record.get(j));
                double denominator = centerPoint + field;
                if (denominator != 0) {
                    distance += Math.pow((centerPoint - field) / denominator, 2);
                }
            }
            return distance;
        }
    }

    /**
     * Averages all records that were assigned to the same center and emits the
     * averaged vector as the new center.
     */
    public static class Reduce extends Reducer<IntWritable, Text, NullWritable, Text> {

        /**
         * Computes the column-wise mean of the records of one center.
         *
         * <p>The number of columns is taken from the first record. Sums are
         * accumulated while streaming through the records, so a cluster is
         * never held in memory as a whole.
         */
        @Override
        protected void reduce(IntWritable key, Iterable<Text> value, Context context)
                throws IOException, InterruptedException {
            double[] sums = null;
            long count = 0;
            for (Text record : value) {
                ArrayList<Double> fields = Utils.textToArray(record);
                if (sums == null) {
                    sums = new double[fields.size()];
                }
                for (int i = 0; i < sums.length; i++) {
                    sums[i] += fields.get(i);
                }
                count++;
            }
            if (sums == null) {
                // No records for this key: nothing to average.
                return;
            }

            double[] avg = new double[sums.length];
            for (int i = 0; i < sums.length; i++) {
                avg[i] = sums[i] / count;
            }
            // Arrays.toString yields "[a, b, c]"; strip the brackets to get "a, b, c".
            context.write(NullWritable.get(),
                    new Text(Arrays.toString(avg).replace("[", "").replace("]", "")));
        }
    }

    /**
     * Configures and runs one job.
     *
     * @param centerPath    file holding the current centers
     * @param dataPath      input data to cluster
     * @param newCenterPath output directory of the job
     * @param runReduce     {@code true} to average records into new centers;
     *                      {@code false} for the final pass that only labels records
     * @throws IOException if the job cannot be submitted or does not complete successfully
     */
    public static void run(String centerPath, String dataPath, String newCenterPath, boolean runReduce)
            throws IOException, ClassNotFoundException, InterruptedException {
        Configuration conf = new Configuration();
        conf.set("centersPath", centerPath);

        Job job = Job.getInstance(conf, "mykmeans");
        job.setJarByClass(MapReduce.class);
        job.setMapperClass(Map.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Text.class);

        if (runReduce) {
            job.setReducerClass(Reduce.class);
            // A center's ID is its position in the centers file, so the new centers
            // must come out as one file sorted by key. Do not depend on the cluster
            // default for the number of reducers.
            job.setNumReduceTasks(1);
            job.setOutputKeyClass(NullWritable.class);
            job.setOutputValueClass(Text.class);
        }

        FileInputFormat.addInputPath(job, new Path(dataPath));
        FileOutputFormat.setOutputPath(job, new Path(newCenterPath));

        boolean succeeded = job.waitForCompletion(true);
        System.out.println(succeeded);
        if (!succeeded) {
            // Continuing would compare against output that was never written.
            throw new IOException("mykmeans job failed, output path: " + newCenterPath);
        }
    }

    /**
     * Entry point. Expects three arguments: centers file name, data file name
     * and result directory.
     */
    public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
        if (args.length < 3) {
            throw new IllegalArgumentException("需要3个参数,储存centers数据的文件名,存储元数据的文件名,结果目录");
        }
        String centerPath = args[0];
        String dataPath = args[1];
        String newCenterPath = args[2];

        // Upload both local files next to the result directory and use the uploaded paths.
        centerPath = FileUtil.loadFile(newCenterPath, "MyKmeans", centerPath);
        dataPath = FileUtil.loadFile(newCenterPath, "MyKmeans", dataPath);
        FileUtil.deleteFile(newCenterPath);

        int count = 0;
        while (true) {
            run(centerPath, dataPath, newCenterPath, true);
            System.out.println(" 第 " + ++count + " 次计算 ");
            if (Utils.compareCenters(centerPath, newCenterPath)) {
                // Centers are stable: label every record with its final center.
                run(centerPath, dataPath, newCenterPath, false);
                break;
            }
        }
    }
}

package MyKmeans;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;

/**
 * Helpers for reading, comparing and replacing k-means center files.
 */
public class Utils {

    /**
     * Reads centers, one comma-separated vector per line.
     *
     * @param centersPath file to read, or a directory whose files are read in listing order
     * @param isDirectory {@code true} when {@code centersPath} is a directory
     * @return the parsed centers; blank lines are skipped, as are directory
     *         entries whose name starts with '_' or '.'
     * @throws IOException if a file cannot be listed or read
     */
    public static ArrayList<ArrayList<Double>> getCentersFromHDFS(String centersPath, boolean isDirectory)
            throws IOException {
        ArrayList<ArrayList<Double>> result = new ArrayList<ArrayList<Double>>();
        Path path = new Path(centersPath);
        Configuration conf = new Configuration();
        FileSystem fileSystem = path.getFileSystem(conf);

        if (isDirectory) {
            FileStatus[] listFile = fileSystem.listStatus(path);
            for (int i = 0; i < listFile.length; i++) {
                Path child = listFile[i].getPath();
                if (isHiddenFile(child)) {
                    continue;
                }
                result.addAll(getCentersFromHDFS(child.toString(), false));
            }
            return result;
        }

        LineReader lineReader = new LineReader(fileSystem.open(path), conf);
        try {
            Text line = new Text();
            // readLine returns the number of bytes consumed; 0 means end of file.
            while (lineReader.readLine(line) > 0) {
                if (line.toString().trim().isEmpty()) {
                    // A blank line would fail in Double.parseDouble("").
                    continue;
                }
                result.add(textToArray(line));
            }
        } finally {
            // Closing the reader also closes the underlying stream.
            lineReader.close();
        }
        return result;
    }

    /**
     * Recursively deletes the given path.
     */
    public static void deletePath(String pathStr) throws IOException {
        Configuration conf = new Configuration();
        Path path = new Path(pathStr);
        FileSystem hdfs = path.getFileSystem(conf);
        hdfs.delete(path, true);
    }

    /**
     * Parses a comma-separated line into a list of doubles.
     *
     * @throws NumberFormatException if a token is not a valid double
     */
    public static ArrayList<Double> textToArray(Text text) {
        ArrayList<Double> list = new ArrayList<Double>();
        String[] tokens = text.toString().split(",");
        for (int i = 0; i < tokens.length; i++) {
            list.add(Double.parseDouble(tokens[i]));
        }
        return list;
    }

    /**
     * Compares the current centers with the newly computed ones.
     *
     * <p>When they are equal the new-centers directory is deleted and
     * {@code true} is returned. Otherwise the centers file is replaced with the
     * content of every data file in the new-centers directory, the directory is
     * deleted, and {@code false} is returned.
     *
     * @param centerPath file holding the current centers
     * @param newPath    directory holding the newly computed centers
     * @return {@code true} when the centers did not change
     */
    public static boolean compareCenters(String centerPath, String newPath) throws IOException {
        List<ArrayList<Double>> oldCenters = Utils.getCentersFromHDFS(centerPath, false);
        List<ArrayList<Double>> newCenters = Utils.getCentersFromHDFS(newPath, true);

        if (centersUnchanged(oldCenters, newCenters)) {
            // Remove the new centers so the final labelling job can reuse the output path.
            Utils.deletePath(newPath);
            return true;
        }

        Configuration conf = new Configuration();
        Path outPath = new Path(centerPath);
        Path inPath = new Path(newPath);
        FileSystem outFs = outPath.getFileSystem(conf);
        FileSystem inFs = inPath.getFileSystem(conf);

        // Open the centers file once (truncating it) and append every part file.
        // Re-creating it per part file would keep only the last part.
        FSDataOutputStream out = outFs.create(outPath, true);
        try {
            FileStatus[] listFiles = inFs.listStatus(inPath);
            for (int i = 0; i < listFiles.length; i++) {
                Path part = listFiles[i].getPath();
                if (isHiddenFile(part)) {
                    continue;
                }
                FSDataInputStream in = inFs.open(part);
                try {
                    IOUtils.copyBytes(in, out, 4096, false);
                } finally {
                    in.close();
                }
            }
        } finally {
            out.close();
        }

        // Remove the new centers so the next iteration can write to the same path.
        Utils.deletePath(newPath);
        return false;
    }

    /**
     * Returns whether both center lists describe the same centers, using the
     * sum of ((|a| - |b|) / (|a| + |b|))^2 over all coordinates.
     *
     * <p>Lists of different shape count as changed. A coordinate that is 0 in
     * both lists contributes 0 instead of 0/0 = NaN, which would otherwise make
     * the sum NaN and the comparison with 0 permanently false.
     */
    private static boolean centersUnchanged(List<ArrayList<Double>> oldCenters,
            List<ArrayList<Double>> newCenters) {
        if (oldCenters.size() != newCenters.size()) {
            return false;
        }
        double distance = 0;
        for (int i = 0; i < oldCenters.size(); i++) {
            ArrayList<Double> oldCenter = oldCenters.get(i);
            ArrayList<Double> newCenter = newCenters.get(i);
            if (oldCenter.size() != newCenter.size()) {
                return false;
            }
            for (int j = 0; j < oldCenter.size(); j++) {
                double t1 = Math.abs(oldCenter.get(j));
                double t2 = Math.abs(newCenter.get(j));
                double denominator = t1 + t2;
                if (denominator != 0) {
                    distance += Math.pow((t1 - t2) / denominator, 2);
                }
            }
        }
        return distance == 0.0;
    }

    /**
     * Returns whether the last path component starts with '_' or '.', which
     * covers job markers such as _SUCCESS.
     */
    private static boolean isHiddenFile(Path path) {
        String name = path.getName();
        return name.startsWith("_") || name.startsWith(".");
    }
}

package MyKmeans;

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

/**
 * File helpers for uploading job input and cleaning up job paths.
 *
 * @author zx
 */
public class FileUtil {

    /**
     * Copies a local data file into the directory that contains {@code inputPath}.
     *
     * <p>The local file is looked up relative to the location this class was
     * loaded from: {@code <code source location>/<folder>/<fileName>}. An
     * existing destination is deleted first.
     *
     * @param inputPath path whose parent directory receives the file
     * @param folder    sub-folder below the code source location; may be null or empty
     * @param fileName  name of the file to upload
     * @return the destination path of the uploaded file
     * @throws IOException if the delete or the copy fails
     */
    public static String loadFile(String inputPath, String folder, String fileName) throws IOException {
        // A null folder must not be concatenated as the text "null".
        String prefix = (folder == null || "".equals(folder)) ? "" : folder + "/";
        String srcPathDir = FileUtil.class.getProtectionDomain().getCodeSource().getLocation().getFile()
                + prefix + fileName;

        String destination = getJobRootPath(inputPath) + fileName;
        Path srcPath = new Path("file:///" + srcPathDir);
        Path dstPath = new Path(destination);

        Configuration conf = new Configuration();
        FileSystem fs = dstPath.getFileSystem(conf);
        fs.delete(dstPath, true);
        fs.copyFromLocalFile(srcPath, dstPath);
        // The FileSystem is deliberately not closed: getFileSystem may hand out a
        // cached instance shared with other code, and closing it would break them.
        return destination;
    }

    /**
     * Returns the parent directory of {@code path}, including its trailing "/".
     *
     * <p>One trailing "/" on the input is ignored, so "a/b/c" and "a/b/c/" both
     * yield "a/b/". A path without any parent (for example "c", "/" or "")
     * yields the empty string.
     *
     * @param path a '/'-separated path
     * @return the parent directory with a trailing "/", or "" if there is none
     */
    public static String getJobRootPath(String path) {
        String trimmed = path.endsWith("/") ? path.substring(0, path.length() - 1) : path;
        return trimmed.substring(0, trimmed.lastIndexOf("/") + 1);
    }

    /**
     * Recursively deletes every given path.
     *
     * @param filePath paths to delete
     * @throws IOException if a delete fails
     */
    public static void deleteFile(String... filePath) throws IOException {
        Configuration conf = new Configuration();
        for (String file : filePath) {
            Path path = new Path(file);
            FileSystem fs = path.getFileSystem(conf);
            fs.delete(path, true);
        }
    }
}

数据集   http://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data

运行结果可以与 http://blog.csdn.net/jshayzf/article/details/22739063的结果做对比(前提是初始的中心相同)


1 0
原创粉丝点击