R机器学习算法系列——KNN

来源:互联网 发布:中国石油大学 知乎 编辑:程序博客网 时间:2024/06/02 05:36

K近邻算法原理

下图中,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

  K 最近邻 (k-Nearest Neighbor,KNN) 分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一,1968年由 Cover 和 Hart 提出。该方法的思路是:如果一个样本在特征空间中的 k 个最相似即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN 算法中,所选择的邻居都是已经正确分类的对象。该方法在分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 K 最近邻 (k-Nearest Neighbor,KNN) 分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的 k 个最相似即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN 算法中,所选择的邻居都是已经正确分类的对象。该方法在分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

  KNN 算法本身简单有效,它是一种 lazy-learning 算法,分类器不需要使用训练集进行训练,训练时间复杂度为0。KNN 分类的计算复杂度和训练集中的文档数目成正比,也就是说,如果训练集中文档总数为 n,那么 KNN 的分类时间复杂度为O(n)。

  KNN 方法虽然从原理上也依赖于极限定理,但在类别决策时,只与极少量的相邻样本有关。由于 KNN 方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN 方法较其他方法更为适合。

  K 近邻算法使用的模型实际上对应于对特征空间的划分。K 值的选择,距离度量和分类决策规则是该算法的三个基本要素:

  1. K 值的选择会对算法的结果产生重大影响。K值较小意味着只有与输入实例较近的训练实例才会对预测结果起作用,但容易发生过拟合;如果 K 值较大,优点是可以减少学习的估计误差,但缺点是学习的近似误差增大,这时与输入实例较远的训练实例也会对预测起作用,是预测发生错误。在实际应用中,K 值一般选择一个较小的数值,通常采用交叉验证的方法来选择最有的 K 值。随着训练实例数目趋向于无穷和 K=1 时,误差率不会超过贝叶斯误差率的2倍,如果K也趋向于无穷,则误差率趋向于贝叶斯误差率。
  2. 该算法中的分类决策规则往往是多数表决,即由输入实例的 K 个最临近的训练实例中的多数类决定输入实例的类别
  3. 距离度量一般采用 Lp 距离,当p=2时,即为欧氏距离,在度量之前,应该将每个属性的值规范化,这样有助于防止具有较大初始值域的属性比具有较小初始值域的属性的权重过大。

  该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的 K 个邻居中大容量类的样本占多数。 该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的 K 个邻居中大容量类的样本占多数。 

  该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行结果。可以采用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的 K 个最近邻点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行结果。可以采用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的 K 个最近邻点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。

  KNN 算法不仅可以用于分类,还可以用于回归。通过找出一个样本的 K 个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成正比。

  实现 K 近邻算法时,主要考虑的问题是如何对训练数据进行快速  K 近邻搜索,这在特征空间维数大及训练数据容量大时非常必要。

以上摘自:http://baike.so.com/doc/867876-917614.html

-----R实现-----

# 导入数据wbcd <- read.csv("wisc_bc_data.csv", stringsAsFactors = FALSE)# examine the structure of the wbcd data framestr(wbcd)## 'data.frame':    569 obs. of  32 variables:##  $ id               : int  87139402 8910251 905520 868871 9012568 906539 925291 87880 862989 89827 ...##  $ diagnosis        : chr  "B" "B" "B" "B" ...##  $ radius_mean      : num  12.3 10.6 11 11.3 15.2 ...##  $ texture_mean     : num  12.4 18.9 16.8 13.4 13.2 ...##  $ perimeter_mean   : num  78.8 69.3 70.9 73 97.7 ...##  $ area_mean        : num  464 346 373 385 712 ...##  $ smoothness_mean  : num  0.1028 0.0969 0.1077 0.1164 0.0796 ...##  $ compactness_mean : num  0.0698 0.1147 0.078 0.1136 0.0693 ...##  $ concavity_mean   : num  0.0399 0.0639 0.0305 0.0464 0.0339 ...##  $ points_mean      : num  0.037 0.0264 0.0248 0.048 0.0266 ...##  $ symmetry_mean    : num  0.196 0.192 0.171 0.177 0.172 ...##  $ dimension_mean   : num  0.0595 0.0649 0.0634 0.0607 0.0554 ...##  $ radius_se        : num  0.236 0.451 0.197 0.338 0.178 ...##  $ texture_se       : num  0.666 1.197 1.387 1.343 0.412 ...##  $ perimeter_se     : num  1.67 3.43 1.34 1.85 1.34 ...##  $ area_se          : num  17.4 27.1 13.5 26.3 17.7 ...##  $ smoothness_se    : num  0.00805 0.00747 0.00516 0.01127 0.00501 ...##  $ compactness_se   : num  0.0118 0.03581 0.00936 0.03498 0.01485 ...##  $ concavity_se     : num  0.0168 0.0335 0.0106 0.0219 0.0155 ...##  $ points_se        : num  0.01241 0.01365 0.00748 0.01965 0.00915 ...##  $ symmetry_se      : num  0.0192 0.035 0.0172 0.0158 0.0165 ...##  $ dimension_se     : num  0.00225 0.00332 0.0022 0.00344 0.00177 ...##  $ radius_worst     : num  13.5 11.9 12.4 11.9 16.2 ...##  $ texture_worst    : num  15.6 22.9 26.4 15.8 15.7 ...##  $ perimeter_worst  : num  87 78.3 79.9 76.5 104.5 ...##  $ area_worst       : num  549 425 471 434 819 ...##  $ smoothness_worst : num  0.139 0.121 0.137 0.137 0.113 ...##  $ compactness_worst: num  0.127 0.252 0.148 0.182 0.174 ...##  $ concavity_worst  : num  0.1242 0.1916 0.1067 0.0867 0.1362 ...##  $ points_worst     : num  0.0939 0.0793 0.0743 0.0861 0.0818 ...##  $ symmetry_worst   : num  0.283 0.294 0.3 0.21 0.249 ...##  $ dimension_worst  : num  0.0677 0.0759 0.0788 0.0678 0.0677 ...# 舍弃分类特征wbcd <- wbcd[-1]# table of diagnosistable(wbcd$diagnosis)## ##   B   M ## 357 212# recode diagnosis as a factorwbcd$diagnosis <- factor(wbcd$diagnosis, levels = c("B", "M"),                         labels = c("Benign", "Malignant"))# table or proportions with more informative labelsround(prop.table(table(wbcd$diagnosis)) * 100, digits = 1)## ##    Benign Malignant ##      62.7      37.3# summarize three numeric featuressummary(wbcd[c("radius_mean", "area_mean", "smoothness_mean")])##   radius_mean       area_mean      smoothness_mean  ##  Min.   : 6.981   Min.   : 143.5   Min.   :0.05263  ##  1st Qu.:11.700   1st Qu.: 420.3   1st Qu.:0.08637  ##  Median :13.370   Median : 551.1   Median :0.09587  ##  Mean   :14.127   Mean   : 654.9   Mean   :0.09636  ##  3rd Qu.:15.780   3rd Qu.: 782.7   3rd Qu.:0.10530  ##  Max.   :28.110   Max.   :2501.0   Max.   :0.16340# 创建规范化函数normalize <- function(x) {  return ((x - min(x)) / (max(x) - min(x)))}# test normalization function - result should be identicalnormalize(c(1, 2, 3, 4, 5))## [1] 0.00 0.25 0.50 0.75 1.00normalize(c(10, 20, 30, 40, 50))## [1] 0.00 0.25 0.50 0.75 1.00# 规范化数据集wbcd_n <- as.data.frame(lapply(wbcd[2:31], normalize))# confirm that normalization workedsummary(wbcd_n$area_mean)##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. ##  0.0000  0.1174  0.1729  0.2169  0.2711  1.0000#创建训练集和测试集(469:100)wbcd_train <- wbcd_n[1:469, ]wbcd_test <- wbcd_n[470:569, ]# create labels for training and test datawbcd_train_labels <- wbcd[1:469, 1]wbcd_test_labels <- wbcd[470:569, 1]## Step 3: 训练模型数据 ----# 加载class包library(class)##测试数据分类结果wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test,                      cl = wbcd_train_labels, k = 21)## Step 4: Evaluating model performance ----# 加载<span style="font-family: 微软雅黑; font-style: inherit;">gmodels包</span>library(gmodels)## Warning: package 'gmodels' was built under R version 3.3.3# Create the cross tabulation of predicted vs. actualCrossTable(x = wbcd_test_labels, y = wbcd_test_pred,           prop.chisq = FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        61 |         0 |        61 | ##                  |     1.000 |     0.000 |     0.610 | ##                  |     0.968 |     0.000 |           | ##                  |     0.610 |     0.000 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         2 |        37 |        39 | ##                  |     0.051 |     0.949 |     0.390 | ##                  |     0.032 |     1.000 |           | ##                  |     0.020 |     0.370 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        63 |        37 |       100 | ##                  |     0.630 |     0.370 |           | ## -----------------|-----------|-----------|-----------|## ## ## Step 5: 提升模型 ----# 标准化数据集wbcd_z <- as.data.frame(scale(wbcd[-1]))# confirm that the transformation was applied correctlysummary(wbcd_z$area_mean)##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. ## -1.4530 -0.6666 -0.2949  0.0000  0.3632  5.2460# create training and test datasetswbcd_train <- wbcd_z[1:469, ]wbcd_test <- wbcd_z[470:569, ]# 重新测试分类wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test,                      cl = wbcd_train_labels, k = 21)# Create the cross tabulation of predicted vs. actualCrossTable(x = wbcd_test_labels, y = wbcd_test_pred,           prop.chisq = FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        61 |         0 |        61 | ##                  |     1.000 |     0.000 |     0.610 | ##                  |     0.924 |     0.000 |           | ##                  |     0.610 |     0.000 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         5 |        34 |        39 | ##                  |     0.128 |     0.872 |     0.390 | ##                  |     0.076 |     1.000 |           | ##                  |     0.050 |     0.340 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        66 |        34 |       100 | ##                  |     0.660 |     0.340 |           | ## -----------------|-----------|-----------|-----------|## ## # 尝试使用其他的K值wbcd_train <- wbcd_n[1:469, ]wbcd_test <- wbcd_n[470:569, ]wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test, cl = wbcd_train_labels, k=1)CrossTable(x = wbcd_test_labels, y = wbcd_test_pred, prop.chisq=FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        58 |         3 |        61 | ##                  |     0.951 |     0.049 |     0.610 | ##                  |     0.983 |     0.073 |           | ##                  |     0.580 |     0.030 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         1 |        38 |        39 | ##                  |     0.026 |     0.974 |     0.390 | ##                  |     0.017 |     0.927 |           | ##                  |     0.010 |     0.380 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        59 |        41 |       100 | ##                  |     0.590 |     0.410 |           | ## -----------------|-----------|-----------|-----------|## ## wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test, cl = wbcd_train_labels, k=5)CrossTable(x = wbcd_test_labels, y = wbcd_test_pred, prop.chisq=FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        61 |         0 |        61 | ##                  |     1.000 |     0.000 |     0.610 | ##                  |     0.968 |     0.000 |           | ##                  |     0.610 |     0.000 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         2 |        37 |        39 | ##                  |     0.051 |     0.949 |     0.390 | ##                  |     0.032 |     1.000 |           | ##                  |     0.020 |     0.370 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        63 |        37 |       100 | ##                  |     0.630 |     0.370 |           | ## -----------------|-----------|-----------|-----------|## ## wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test, cl = wbcd_train_labels, k=11)CrossTable(x = wbcd_test_labels, y = wbcd_test_pred, prop.chisq=FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        61 |         0 |        61 | ##                  |     1.000 |     0.000 |     0.610 | ##                  |     0.953 |     0.000 |           | ##                  |     0.610 |     0.000 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         3 |        36 |        39 | ##                  |     0.077 |     0.923 |     0.390 | ##                  |     0.047 |     1.000 |           | ##                  |     0.030 |     0.360 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        64 |        36 |       100 | ##                  |     0.640 |     0.360 |           | ## -----------------|-----------|-----------|-----------|## ## wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test, cl = wbcd_train_labels, k=15)CrossTable(x = wbcd_test_labels, y = wbcd_test_pred, prop.chisq=FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        61 |         0 |        61 | ##                  |     1.000 |     0.000 |     0.610 | ##                  |     0.953 |     0.000 |           | ##                  |     0.610 |     0.000 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         3 |        36 |        39 | ##                  |     0.077 |     0.923 |     0.390 | ##                  |     0.047 |     1.000 |           | ##                  |     0.030 |     0.360 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        64 |        36 |       100 | ##                  |     0.640 |     0.360 |           | ## -----------------|-----------|-----------|-----------|## ## wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test, cl = wbcd_train_labels, k=21)CrossTable(x = wbcd_test_labels, y = wbcd_test_pred, prop.chisq=FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        61 |         0 |        61 | ##                  |     1.000 |     0.000 |     0.610 | ##                  |     0.968 |     0.000 |           | ##                  |     0.610 |     0.000 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         2 |        37 |        39 | ##                  |     0.051 |     0.949 |     0.390 | ##                  |     0.032 |     1.000 |           | ##                  |     0.020 |     0.370 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        63 |        37 |       100 | ##                  |     0.630 |     0.370 |           | ## -----------------|-----------|-----------|-----------|## ## wbcd_test_pred <- knn(train = wbcd_train, test = wbcd_test, cl = wbcd_train_labels, k=27)CrossTable(x = wbcd_test_labels, y = wbcd_test_pred, prop.chisq=FALSE)## ##  ##    Cell Contents## |-------------------------|## |                       N |## |           N / Row Total |## |           N / Col Total |## |         N / Table Total |## |-------------------------|## ##  ## Total Observations in Table:  100 ## ##  ##                  | wbcd_test_pred ## wbcd_test_labels |    Benign | Malignant | Row Total | ## -----------------|-----------|-----------|-----------|##           Benign |        61 |         0 |        61 | ##                  |     1.000 |     0.000 |     0.610 | ##                  |     0.938 |     0.000 |           | ##                  |     0.610 |     0.000 |           | ## -----------------|-----------|-----------|-----------|##        Malignant |         4 |        35 |        39 | ##                  |     0.103 |     0.897 |     0.390 | ##                  |     0.062 |     1.000 |           | ##                  |     0.040 |     0.350 |           | ## -----------------|-----------|-----------|-----------|##     Column Total |        65 |        35 |       100 |##                  |     0.650 |     0.350 |           |## -----------------|-----------|-----------|-----------|

总结:

         1.采用不同的数据清洗规则或不同的数据处理方式对结果产生不同的影响;
          2.采用不同的K值,结果也会不尽相同。



0 0