数据挖掘笔记-分类-KNN-原理与简单实现

来源:互联网 发布:上瘾网络剧下载种子 编辑:程序博客网 时间:2024/05/16 19:50

原文地址:http://blog.csdn.net/fighting_one_piece/article/details/39030969

KNN算法又称为k近邻分类(k-nearest neighbor classification)算法。最简单平凡的分类器。KNN算法则是从训练集中找到和新数据最接近的k条记录,然后根据他们的主要分类来决定新数据的类别。该算法涉及3个主要因素:训练集、距离、k的大小。可用于客户流失预测、欺诈侦测等(更适合于稀有事件的分类问题)。 


计算步骤如下:
    1)计算距离:给定测试对象,计算它与训练集中的每个对象的距离。距离计算可以选择欧氏距离
曼哈顿距离、余弦距离等。计算距离之前最好对数据进行规范化处理,以便于更好的计算。
    2)寻找邻居:圈定距离最近的k个训练对象,作为测试对象的近邻。K值的选择可以通过若干试验,选取分类误差最小的K值
    3)判断分类:根据这k个近邻归属的主要类别,来对测试对象分类。判定方式主要是投票决定,少数服从多数,近邻中哪个类别的点最多就分为该类。也可以通过加权投票方法来决定。


优点
简单,易于理解,易于实现,无需估计参数,无需训练
适合对稀有事件进行分类(例如当流失率很低时,比如低于0.5%,构造流失预测模型)
特别适合于多分类问题(multi-modal,对象具有多个类别标签),例如根据基因特征来判断其功能分类,KNN比SVM的表现要好

缺点
计算开销大,需要有效的存储技术和并行硬件的支撑。
可解释性较差,无法给出决策树那样的规则。

下面是基于MapReduce简单实现KNN的代码:
[java] view plaincopy
  1. public class KNNClassifier {  
  2.       
  3.     private static void configureJob(Job job) {  
  4.         job.setJarByClass(KNNClassifier.class);  
  5.           
  6.         job.setMapperClass(KNNMapper.class);  
  7.         job.setMapOutputKeyClass(Text.class);  
  8.         job.setMapOutputValueClass(PointWritable.class);  
  9.           
  10.         job.setReducerClass(KNNReducer.class);  
  11.         job.setOutputKeyClass(Text.class);  
  12.         job.setOutputValueClass(Text.class);  
  13.           
  14.         job.setInputFormatClass(TextInputFormat.class);  
  15.         job.setOutputFormatClass(TextOutputFormat.class);  
  16.     }  
  17.       
  18.     public static void main(String[] args) {  
  19.         long start = System.currentTimeMillis();  
  20.         Configuration configuration = new Configuration();  
  21.         try {  
  22.             String[] inputArgs = new GenericOptionsParser(  
  23.                         configuration, args).getRemainingArgs();  
  24.             if (inputArgs.length != 4) {  
  25.                 System.out.println("error, please input three path.");  
  26.                 System.out.println("1 train set path.");  
  27.                 System.out.println("2 test set path.");  
  28.                 System.out.println("3 output path.");  
  29.                 System.out.println("4 k value.");  
  30.                 System.exit(2);  
  31.             }  
  32.             DistributedCache.addCacheFile(new Path(inputArgs[0]).toUri(), configuration);  
  33.               
  34.             configuration.set("k", inputArgs[3]);  
  35.             Job job = new Job(configuration, "KNN Classifier");  
  36.               
  37.             FileInputFormat.setInputPaths(job, new Path(inputArgs[1]));  
  38.             FileOutputFormat.setOutputPath(job, new Path(inputArgs[2]));  
  39.               
  40.             configureJob(job);  
  41.               
  42.             System.out.println(job.waitForCompletion(true) ? 0 : 1);  
  43.             long end = System.currentTimeMillis();  
  44.             System.out.println("spend time: " + (end - start) / 1000);  
  45.         } catch (Exception e) {  
  46.             e.printStackTrace();  
  47.         }  
  48.     }  
  49.   
  50. }  
  51.   
  52. class KNNMapper extends Mapper<LongWritable, Text, Text, PointWritable> {  
  53.       
  54.     private List<Point> trainPoints = new ArrayList<Point>();  
  55.       
  56.     @Override  
  57.     protected void setup(Context context) throws IOException, InterruptedException {  
  58.         super.setup(context);  
  59.         Configuration conf = context.getConfiguration();  
  60.         FileSystem fs = FileSystem.get(conf);  
  61.         URI[] uris = DistributedCache.getCacheFiles(conf);  
  62.         Path[] paths = HDFSUtils.getPathFiles(fs, new Path(uris[0]));  
  63.         for(Path path : paths) {  
  64.             FSDataInputStream in = fs.open(path);  
  65.             BufferedReader reader = new BufferedReader(new InputStreamReader(in));  
  66.             String line = reader.readLine();  
  67.             while (null != line && !"".equals(line)) {  
  68.                 String[] datas = line.split(" ");  
  69.                 trainPoints.add(new Point(Double.parseDouble(datas[0]),   
  70.                         Double.parseDouble(datas[1]), datas[2]));  
  71.                 line = reader.readLine();  
  72.             }  
  73.             IOUtils.closeQuietly(in);  
  74.             IOUtils.closeQuietly(reader);  
  75.         }  
  76.     }  
  77.   
  78.     @Override  
  79.     protected void map(LongWritable key, Text value, Context context)  
  80.             throws IOException, InterruptedException {  
  81.         String line = value.toString();  
  82.         String[] datas = line.split(" ");  
  83.         double x = Double.parseDouble(datas[0]);  
  84.         double y = Double.parseDouble(datas[1]);  
  85.         Point testPoint = new Point(x, y);  
  86.         String outputKey = x + "-" + y;  
  87.         for (Point trainPoint : trainPoints) {  
  88.             double distance = distance(testPoint, trainPoint);  
  89.             context.write(new Text(outputKey), new PointWritable(trainPoint, distance));  
  90.         }  
  91.     }  
  92.       
  93.     public double distance(Point point1, Point point2) {  
  94.         return Math.sqrt(Math.pow((point1.getX() - point2.getX()), 2)  
  95.                 + Math.pow((point1.getY() - point2.getY()), 2));  
  96.     }  
  97.   
  98.     @Override  
  99.     protected void cleanup(Context context) throws IOException, InterruptedException {  
  100.         super.cleanup(context);  
  101.     }  
  102. }  
  103.   
  104. class KNNReducer extends Reducer<Text, PointWritable, Text, Text> {  
  105.   
  106.     private int k = 0;  
  107.       
  108.     @Override  
  109.     protected void setup(Context context) throws IOException, InterruptedException {  
  110.         super.setup(context);  
  111.         Configuration conf = context.getConfiguration();  
  112.         k = Integer.parseInt(conf.get("k""0"));  
  113.     }  
  114.   
  115.     @Override  
  116.     protected void reduce(Text key, Iterable<PointWritable> values,  
  117.             Context context) throws IOException, InterruptedException {  
  118.         System.out.println(key);  
  119.         List<PointWritable> points = new ArrayList<PointWritable>();  
  120.         for (PointWritable point : values) {  
  121.             points.add(new PointWritable(point));  
  122.         }  
  123.         Collections.sort(points, new Comparator<PointWritable>() {  
  124.             @Override  
  125.             public int compare(PointWritable o1, PointWritable o2) {  
  126.                 return o1.getDistance().compareTo(o2.getDistance());  
  127.             }  
  128.         });  
  129.         Map<String, Integer> map = new HashMap<String, Integer>();  
  130.         k = points.size() < k ? points.size() : k;  
  131.         for (int i = 0; i < k; i++) {  
  132.             PointWritable point = points.get(i);  
  133.             String category = point.getCategory().toString();  
  134.             Integer count = map.get(category);  
  135.             map.put(category, null == count ? 1 : count + 1);  
  136.         }  
  137.         List<Map.Entry<String, Integer>> list =   
  138.                 new ArrayList<Map.Entry<String, Integer>>(map.entrySet());  
  139.         Collections.sort(list, new Comparator<Map.Entry<String, Integer>>(){  
  140.             @Override  
  141.             public int compare(Entry<String, Integer> o1,  
  142.                     Entry<String, Integer> o2) {  
  143.                 return o2.getValue().compareTo(o1.getValue());  
  144.             }  
  145.         });  
  146.         context.write(key, new Text(list.get(0).getKey()));  
  147.     }  
  148.   
  149.     @Override  
  150.     protected void cleanup(Context context) throws IOException,  
  151.             InterruptedException {  
  152.         super.cleanup(context);  
  153.     }  
  154.   
  155. }  

附:Spark学习笔记-KNN算法实现

代码托管:https://github.com/fighting-one-piece/repository-datamining.git
0 0