KNN算法及R语言实现

来源:互联网 发布:java实验指导书答案 编辑:程序博客网 时间:2024/04/25 01:59

        KNN(k-Nearest Neighbor)分类算法是数据挖掘分类技术中较简单的方法之一。所谓k最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。


        例如,上图中,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

        KNN分类算法,是一个理论上比较成熟的方法,也是较简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 KNN方法虽然从原理上也依赖于极限定理,但在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成反比。


KNN算法流程: 

1. 准备数据,对数据进行预处理

2. 选用合适的数据结构存储训练数据和测试元组

3. 设定参数,如k

4. 维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列

5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax

6. 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。

7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。

8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。


KNN算法优点:

1. 简单,易于理解,易于实现,无需估计参数,无需训练;

2. 适合对稀有事件进行分类;

3. 特别适合于多分类问题(multi-modal,对象具有多个类别标签), kNN比SVM的表现要好;


KNN算法缺点:

1. 当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。 该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行结果;

2. 计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点;

3. 可理解性差,无法给出像决策树那样的规则;


R语言中有kknn的package实现了weighted k-nearest neighbor,用法如下:

kknn(formula = formula(train), train, test, na.action = na.omit(), k = 7, distance = 2, kernel = "optimal", ykernel = NULL, scale=TRUE, contrasts = c('unordered' = "contr.dummy", ordered = "contr.ordinal"))


参数:

  • formula                            A formula object.
  • train                                 Matrix or data frame of training set cases.
  • test                                   Matrix or data frame of test set cases.
  • na.action                         A function which indicates what should happen when the data contain ’NA’s.
  • k                                       Number of neighbors considered.
  • distance                          Parameter of Minkowski distance.
  • kernel                              Kernel to use. Possible choices are "rectangular" (which is standard unweighted knn), "triangular", "epanechnikov" (or beta(2,2)), "biweight" (or beta(3,3)), "triweight" (or beta(4,4)), "cos", "inv", "gaussian", "rank" and "optimal".
  • ykernel                            Window width of an y-kernel, especially for prediction of ordinal classes.
  • scale                                Logical, scale variable to have equal sd.
  • contrasts                         A vector containing the ’unordered’ and ’ordered’ contrasts to use
kknn的返回值如下:
  • fitted.values              Vector of predictions.
  • CL                              Matrix of classes of the k nearest neighbors.
  • W                                Matrix of weights of the k nearest neighbors.
  • D                                 Matrix of distances of the k nearest neighbors.
  • C                                 Matrix of indices of the k nearest neighbors.
  • prob                            Matrix of predicted class probabilities.
  • response                   Type of response variable, one of continuous, nominal or ordinal.
  • distance                     Parameter of Minkowski distance.
  • call                              The matched call.
  • terms                          The ’terms’ object used.

实例:

library(kknn)
</pre><pre name="code" class="plain">#iris数据结构如下
#    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species#1            5.1         3.5          1.4         0.2     setosa#2            4.9         3.0          1.4         0.2     setosa#3            4.7         3.2          1.3         0.2     setosa
#...
#51           7.0         3.2          4.7         1.4     versicolor
#...
data(iris)
</pre><pre name="code" class="plain"># 将iris的行数赋给mm <- dim(iris)[1]
# 选取采样数据val <- sample(1:m, size = round(m/3), replace = FALSE, prob = rep(1/m, m))
# 建立训练数据iris.learn <- iris[-val,]
# 建立测试数据iris.valid <- iris[val,]
# 调用kknn,formula Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Widthiris.kknn <- kknn(Species~., iris.learn, iris.valid, distance = 1, kernel = "triangular")
summary(iris.kknn)
# 获取fitted.valuesfit <- fitted(iris.kknn)
# 建立表格检验判类准确性table(iris.valid$Species, fit)
#             setosa versicolor virginica#  setosa         15          0         0#  versicolor      0         15         1#  virginica       0          3        16
</pre><pre name="code" class="plain"># 绘画散点图,k-nearest neighbor用红色高亮显示
pcol <- as.character(as.numeric(iris.valid$Species))pairs(iris.valid[1:4], pch = pcol, col = c("green3", "red")[(iris.valid$Species != fit)+1])




源码Github:https://github.com/bigdata-william/R_Algorithm/blob/master/kknn.R

微信公众号:威廉的大数据实验室

2 0