MapReduce实现KNN

来源:互联网 发布:淘宝明通数码科技公司 编辑:程序博客网 时间:2024/06/03 10:30

不正之处,欢迎指正。

        KNN算法即K近邻(K-Nearest Neighbor)分类算法,是最简单的分类器之一:它从训练集中找出与测试数据距离最近的K条记录,然后根据这K条记录的标记投票决定测试实例的最终标记。MapReduce作为一种大数据环境下的计算模型,在分布式计算中具有其独特的优势,本文主要在hadoop框架下实现KNN算法。

        实验环境:centos6.5+hadoop2.2.0

实验步骤:

         MapReduce的关键之处在于实现用户自定义的map和reduce函数。在本实例中,我们首先在mapper类的setup函数中读取所有的训练数据,并用一个List进行存储。在map阶段,逐行读取每一个测试实例,计算测试实例和每条训练数据之间的距离,找到距离最近的k条训练数据所对应的标记。在reduce阶段,统计map阶段输出的标记信息,出现次数最多的标记就是该测试实例的最终标记。实验代码如下:

package org.apache.hadoop.knn;import java.io.BufferedReader;import java.io.IOException;import java.io.InputStreamReader;import java.net.URI;import java.util.ArrayList;import java.util.HashMap;import java.util.Iterator;import java.util.Set;import org.apache.hadoop.conf.Configuration;import org.apache.hadoop.fs.FSDataInputStream;import org.apache.hadoop.fs.FileSystem;import org.apache.hadoop.fs.Path;import org.apache.hadoop.io.LongWritable;import org.apache.hadoop.io.NullWritable;import org.apache.hadoop.io.Text;import org.apache.hadoop.mapreduce.Job;import org.apache.hadoop.mapreduce.Mapper;import org.apache.hadoop.mapreduce.Reducer;import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;public class Knn {public static class KnnMap extends Mapper<LongWritable, Text, Text, Text> {public ArrayList<Instance> train = new ArrayList<Instance>();       //存储训练集public int k = 5;@Override//读取训练集protected void setup(Mapper<LongWritable, Text, Text, Text>.Context context)throws IOException, InterruptedException {// TODO Auto-generated method stub// super.setup(context);FileSystem fs = null;try {fs = FileSystem.get(new URI("hdfs://192.168.1.119:9000"), new Configuration());} catch (Exception e) {}FSDataInputStream fi = fs.open(new Path("hdfs://192.168.1.119:9000/data/traindata.txt"));BufferedReader bf = new BufferedReader(new InputStreamReader(fi));String line = bf.readLine();while (line != null) {Instance sample = new Instance(line);train.add(sample);line = bf.readLine();}}@Overrideprotected void map(LongWritable key, Text value, Context context)throws IOException, InterruptedException {// TODO Auto-generated method stub// super.map(key, value, context);ArrayList<Double> distance = new ArrayList<Double>(k);ArrayList<String> trainlabel = new ArrayList<String>(k);for (int i = 0; i < k; i++) {distance.add(Double.MAX_VALUE);trainlabel.add(String.valueOf("-1.0"));}Instance test = new Instance(value.toString());for 
(int i = 0; i < train.size(); i++) {double dis = Distance(train.get(i).getFeatures(),test.getFeatures());for (int j = 0; j < k; j++) {if (dis < (Double) distance.get(j)) {distance.set(j, dis);trainlabel.set(j, train.get(i).getLabel() + "");break;}}}for (int i = 0; i < k; i++) {context.write(new Text(value.toString()),new Text(trainlabel.get(i) + ""));}}private double Distance(double[] a, double[] b) {// TODO Auto-generated method stubdouble sum = 0.0;for (int i = 0; i < a.length; i++) {sum += Math.pow(a[i] - b[i], 2);}return Math.sqrt(sum);}}public static class KnnReducer extendsReducer<Text, Text, Text, NullWritable> {@Overrideprotected void reduce(Text k, Iterable<Text> values, Context context)throws IOException, InterruptedException {// TODO Auto-generated method stub// super.reduce(arg0, arg1, arg2);ArrayList<String> l = new ArrayList<String>();for (Text t : values) {l.add(t.toString());}String predict = Predict(l);context.write(new Text(k.toString() + "\t" + predict),NullWritable.get());}private String Predict(ArrayList<String> arr) {// TODO Auto-generated method stubHashMap<String, Double> tmp = new HashMap<String, Double>();for (int i = 0; i < arr.size(); i++) {if (tmp.containsKey(arr.get(i))) {double frequence = tmp.get(arr.get(i)) + 1;tmp.remove(arr.get(i));tmp.put((String) arr.get(i), frequence);} elsetmp.put((String) arr.get(i), new Double(1));}Set<String> s = tmp.keySet();Iterator it = s.iterator();double lablemax = Double.MIN_VALUE;String predictlable = null;while (it.hasNext()) {String key = (String) it.next();Double lablenum = tmp.get(key);if (lablenum > lablemax) {lablemax = lablenum;predictlable = key;}}return predictlable;}}public static void main(String[] args) throws IOException,ClassNotFoundException, InterruptedException {FileSystem fs = FileSystem.get(new Configuration());Job job = new Job(new Configuration());job.setJarByClass(Knn.class);FileInputFormat.setInputPaths(job, new 
Path(args[0]));job.setMapperClass(KnnMap.class);job.setMapOutputKeyClass(Text.class);job.setMapOutputValueClass(Text.class);FileOutputFormat.setOutputPath(job, new Path(args[1]));job.setReducerClass(KnnReducer.class);job.setOutputKeyClass(Text.class);job.setOutputValueClass(NullWritable.class);job.waitForCompletion(true);}}



0 0
原创粉丝点击