python机器学习-交叉验证(cross-vaildation)

来源:互联网 发布:算法工程师的年薪 编辑:程序博客网 时间:2024/06/05 17:23

K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。

所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。

1.基础验证法:

import numpy as npfrom sklearn.datasets import load_iris # iris数据集from sklearn.model_selection import train_test_splitfrom sklearn.neighbors import KNeighborsClassifier # K最近邻(kNN,k-NearestNeighbor)分类算法import matplotlib.pyplot as plt#加载iris数据集iris = load_iris()X = iris.datay = iris.target#分割数据,random_state设定随机数种子X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=4)#建立模型knn = KNeighborsClassifier()#训练模型knn.fit(X_train, y_train)#将准确率打印出print(knn.score(X_test, y_test))# 0.973684210526


2.交叉验证:求平均值

import numpy as npfrom sklearn.datasets import load_iris # iris数据集from sklearn.model_selection import train_test_splitfrom sklearn.neighbors import KNeighborsClassifier # K最近邻(kNN,k-NearestNeighbor)分类算法from sklearn.model_selection import cross_val_scoreimport matplotlib.pyplot as plt#加载iris数据集iris = load_iris()X = iris.datay = iris.targetknn = KNeighborsClassifier()#对数据集进行5次随机划分为训练集和测试集,并对应5次测试,返回5次测试准确率accuracy的列表scores=cross_val_score(knn,X,y,cv=5,scoring='accuracy')#5组的准确率print(scores)#输出平均值#[ 0.96666667  1.          0.93333333  0.96666667  1.        ]print(scores.mean())#0.973333333333

利用交叉验证调节参数:测试选择合适的k值使得准确率最大

import numpy as npfrom sklearn.datasets import load_iris # iris数据集from sklearn.model_selection import train_test_splitfrom sklearn.neighbors import KNeighborsClassifier # K最近邻(kNN,k-NearestNeighbor)分类算法from sklearn.model_selection import cross_val_scoreimport matplotlib.pyplot as plt#加载iris数据集iris = load_iris()X = iris.datay = iris.target#建立测试参数K数据集k_range = range(1, 31)k_scores = []#由迭代的方式来计算不同k对模型的影响,并返回交叉验证后的平均准确率for k in k_range:    #采用k个最近邻    knn = KNeighborsClassifier(n_neighbors=k)    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')    k_scores.append(scores.mean())#可视化数据plt.plot(k_range, k_scores)plt.xlabel('Value of K for KNN')plt.ylabel('Cross-Validated Accuracy')plt.show()

从上图可以看出,k在12~18最好,超过18之后,准确率下降是因为过拟合的出现





阅读全文
0 0
原创粉丝点击