Machine learning for OpenCV 学习笔记 day2
来源:互联网 发布:linux oracle 是否安装 编辑:程序博客网 时间:2024/06/04 23:18
第三部分:算法实现
1.K近邻法(K-NN)
工作原理:存在一个样本数据集合,也称作训练样本集,并且样本中每个数据都存在标签,输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征作比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。
我们使用OpenCV中的 cv2.ml.KNearest_create函数就可以实现这个算法,实现该算法一般遵循以下步骤:
生成训练数据——>选取K个目标点——>分别找到离这K个目标点最近的近邻点——>将这些近邻点注上标签——>输出结果
所以首先我们要先生成训练数据:
import numpy as npimport cv2import matplotlib.pyplot as pltplt.style.use('ggplot')np.random.seed(42)# 0 - 100 之间随机生成2个整数single_data_point = np.random.randint(0,100,2)print(single_data_point)# 生成数据标签,此处标签是2个single_label = np.random.randint(0,2)print(single_label)
先加载我们要用的模块,为了测试,我们先随机生成2个0-100的整数,并且生成它们的标签,0或者是1.
# Generate the training datadef generate_data(num_samples, num_features=2): """Randomly generates a number of data points""" data_size = (num_samples,num_features) train_data = np.random.randint(0,100,size=data_size) labels_size=(num_samples,1) labels= np.random.randint(0,2,size=labels_size) return train_data.astype(np.float32),labelstrain_data, labels = generate_data(11)#print(train_data)plt.plot(train_data[:,0],train_data[:,1],'sb')plt.xlabel('x')plt.ylabel('y')plt.show()运行以上程序可以得到:
输出点的坐标为[51,92],标签为0,并且将标签为0的点都设置成蓝色正方形,最终在坐标中随机的11个点都显示出的结果为下图:
我们将上面的11个点随机分成2中不同标签,分别在坐标轴中显示出来,此外由工作原理可知,我们额外添加一个绿色的圆点作为我们需要分类的点,然后训练我们的数据,并通过学习到的模型对随机生成的新点进行预测,输出结果:
# Visualize the whole datadef plot_data(all_blue,all_red): plt.figure(figsize=(10,6)) plt.scatter(all_blue[:,0], all_blue[:,1], c='b',marker='s',s=180) plt.scatter(all_red[:, 0], all_red[:, 1], c='r', marker='^', s=180) plt.plot(newcomer[:, 0], newcomer[:, 1], 'go', markersize=14) plt.xlabel('x') plt.ylabel('y') plt.show()labels.ravel() == 0 %对标签随机分配blue = train_data[labels.ravel()==0]red = train_data[labels.ravel()==1]
# Training the classifierknn = cv2.ml.KNearest_create()knn.train(train_data, cv2.ml.ROW_SAMPLE, labels)
# Predicting the label of the new data pointnewcomer, _ =generate_data(1) %此处改变需要预测的点的个数plot_data(blue,red)ret , results, neighbor , dist = knn.findNearest(newcomer,6) print('Predicted label:\t',results)print('Neighbor\'s label:\t', neighbor)print('Distance to neighbor:\t',dist)knn.setDefaultK(6) % 可设置K的个数print(knn.predict(newcomer))输出结果为:
在使用K-NN算法的时候我们不能事先知道合适的K的个数,最好的方法就是试一系列的K的值直到找到合适的,简单问题还行,到了复杂问题就不适用了。
2.用回归模型去预测连续的输出
我们用波士顿房价的经典例子来说明回归模型的问题。
首先第一步我们还是要用scikit-learn来获得数据集
import numpy as npimport cv2import matplotlib.pyplot as pltfrom sklearn import datasetsfrom sklearn import metricsfrom sklearn import model_selectionfrom sklearn import linear_modelplt.style.use('ggplot')plt.rcParams.update({'font.size':16})#download the datasetboston = datasets.load_boston() %下载数据集boston数据集总共有506个数据点,每个数据点包含了13个特征,而我们最终要预测的是房价。下载完以后我们开准备训练数据集。
第二步是设定模型:
linreg = linear_model.LinearRegression()划分训练集和测试集
X_train, X_test, y_train, y_test = model_selection.train_test_split(boston.data, boston.target , test_size = 0.1, random_state=42)
在sklearn中train的函数是fit,而在opencv中也同样适用,所以训练模型,顺便把均方差也算了出来
linreg.fit(X_train,y_train)metrics.mean_squared_error(y_train,linreg.predict(X_train))linreg.score(X_train,y_train)
最后让我们测试训练得到的模型:
y_pred = linreg.predict(X_test)metrics.mean_squared_error(y_test,y_pred)图形可视化的程序为:
plt.figure(figsize=(10,6))plt.plot(y_test,linewidth=3,label='ground truth')plt.plot(y_pred,linewidth=3,label='predicted')plt.legend(loc='best')plt.xlabel('test the data')plt.ylabel('traget vaalue')plt.show()
# 计算并输出R2分数plt.figure(figsize=(10,6))plt.plot(y_test,y_pred,'o')plt.plot([-10,60],[-10,60],'k--')plt.axis([-10,60,-10,60])plt.xlabel('ground truth')plt.ylabel('predicted')scorestr = r'R$^2$ = %.3f' % linreg.score(X_test,y_test)errstr = 'MSE = %.3f' % metrics.mean_squared_error(y_test,y_pred)plt.text(-5, 50 , scorestr, fontsize= 12)plt.text(-5, 45 , errstr, fontsize= 12)plt.show()以上程序输出结果为:
3.过拟合问题
本节主要通过过拟合问题也会影响线性模型的表现,来引出正规化(regularization)的概念和用法。主要分2大种常见的正则化L1和L2:
L1是将所有权重W的绝对值相加,L2是将所有权重W的平方相加。
代码实现方面,主要就是将第2小节中的模型语句:
linreg = linear_model.LinearRegression()
根据使用的情况改为L1正规化:
lassoreg = linear_model.Lasso()
或者是L2正规化:
ridgereg = linear_model.RidgeRegression()其他地方均与上述相似,当使用L1正规化时输出结果为:
而使用L2正规化结果为:
可以看出L2正规化下,模型表现更好。
- Machine learning for OpenCV 学习笔记 day2
- Machine learning for openCV 学习笔记 day1
- Machine learning for OpenCV 学习笔记 day4
- Machine learning for OpenCV 学习笔记 day5
- Machine learning for OpenCV 学习笔记 day6
- 《Neural Networks for Machine Learning》学习笔记
- Machine Learning 学习笔记
- machine learning 学习笔记
- Mechine learning for OpenCV 学习笔记 day3
- 机器学习笔记-advice for applying machine learning
- Stanford 机器学习笔记 Week6 Advice for Applying Machine Learning
- 【Stanford机器学习笔记】8-Advice for Applying Machine Learning
- 机器学习课程 Neural Netword for Machine Learning笔记
- Machine Learning 课程学习笔记
- 【Machine Learning】SVM学习笔记
- 学习笔记-machine learning foundaton2
- machine learning 学习笔记<一>
- machine learning学习笔记<二>
- 进程间通讯——内存映射/文件映射形式
- Excel 技巧百例:数据透视表的排序
- 20170726Python01_Python简介和输入输出
- 1701-MySQL-JDBC-连接池使用
- Breadcrumb的显示与隐藏
- Machine learning for OpenCV 学习笔记 day2
- window下mysql导入大量数据
- 工欲善其事必先利其器---开篇
- 虚拟机类加载机制
- 利用插件制作安卓动画
- Zepto.js
- java中static{}语句块详解
- 数据库选型参考资料
- 高级装配 —— Spring profile