线性可分情形下支持向量机学习的SMO算法
来源:互联网 发布:淘宝同学app 编辑:程序博客网 时间:2024/06/09 05:15
最近一直谋划着发一篇有关SVM的帖子,为此很准备了几天时间。相对于基于决策树的分类模型来说,SVM在数学理论上
有一定难度,曾经看过一篇帖子,上面说理解SVM有三层境界,第一层是了解SVM,第二层是深入SVM,第三层是证明
SVM,要想达到第三层境界需要对泛函分析和最优化理论比较熟悉,就本人的专业背景来说,经过这几天的深入学习,应该说
自己已经达到了第二层境界。
本文主要参考李航所著的《统计学习方法》一书,下面开始进入正题:
求解上述优化模型有多种方法,最流行的是Platt提出的SMO算法,该算法有两层循环构成,外层循环遍历非边界样本或所有样本:优先遍历非边界样本,对其中违背KKT条件的样本进行调整,直至非边界样本全部满足KKT条件,若某次遍历发现没有非边界样本得到调整,就遍历所有样本,以检验是否有样本违背KKT条件,若有样本得到调整,则下次循环有必要再次遍历非边界样本。这样
外层循环在“遍历所有样本”和“遍历非边界样本”之间切换,直至所与样本都满足KKT条件。
内层循环针对违背KKT条件的样本寻找配对样本,优先选择使得|E2-E1|最大的样本作为配对样本,若上述过程失败则随机地选择非边界样本进行优化,若这样的样本找不到,则随机地选择任意样本进行优化,若仍然失败,则进入下一轮外层循环。
寻找配对样本以及参数更新的过程在《统计学习方法》一书中有详细描述,见下图
在更新参数之后需要及时更新超平面的参数b,更新公式如下:
以下SMO算法的伪代码摘自Platt的论文《Sequential Minimal Optimization - A Fast Algorithm for Training Support Vector Machines》,
基于上述伪代码,本文采用Python来编写代码,所有代码位于文件svm.py中,
from __future__ import divisionimport numpy as npclass SvmClassifier: def __init__(self, X, y, C, tol):#初始化支持<span style="font-family:SimSun;">向量机</span> self.X = X self.y = y self.C = C self.tol = tol self.m = X.shape[0] self.alphas = np.zeros((self.m, 1)) self.b = 0 self.errors = np.zeros((self.m, 1)) def getError(self, i):#<span style="font-family:SimSun;">计算误差</span> Ei = (self.alphas*self.y*(np.dot(self.X, self.X[i:i+1,:].T))).sum() + self.b - self.y[i,0] return Ei def selectAnotherAlpha(self, i):#选择配对的第二个样本 maxDelta = 0 for k in range(self.m): if k == i:continue delta = abs(self.errors[k,0] - self.errors[i,0]) if (delta>maxDelta): maxDelta = delta j = k if (maxDelta>0): return j else: j = int(np.random.randint(self.m)) while (j==i): j = int(np.random.randint(self.m)) return j def updateErrors(self, i):#当更新self.alphas和self.b之后,必须要更新<span style="font-family:SimSun;">误差</span> self.errors[i,0] = self.getError(i) def updateSvm(self, i):#选择配对样本并更新支持向量机的关键参数 Ei = self.errors[i,0] yi = self.y[i,0] if ((yi*Ei < -self.tol) and (self.alphas[i,0] < self.C)) or \ ((yi*Ei > self.tol) and (self.alphas[i,0] > 0)): j =self.selectAnotherAlpha(i) Ej = self.errors[j,0] yj = self.y[j,0] aiOld = self.alphas[i,0].copy();ajOld = self.alphas[j,0].copy() if (yi != yj): L = max(0, ajOld-aiOld) H = min(self.C, self.C+ajOld-aiOld) else: L = max(0, ajOld+aiOld-self.C) H = min(self.C, ajOld+aiOld) if (L == H): return 0 eta = (self.X[i,:]*self.X[i,:]).sum()+(self.X[j,:]*self.X[j,:]).sum()\ - 2*(self.X[i,:]*self.X[j,:]).sum() if (eta <= 0):return 0 self.alphas[j,0] += self.y[j,0]*(Ei-Ej)/eta if self.alphas[j,0]>H: self.alphas[j,0] = H elif self.alphas[j,0]<L: self.alphas[j,0] = L self.updateErrors(j) if abs(self.alphas[j,0]-ajOld)<0.00001: return 0 self.alphas[i,0] = aiOld + self.y[i,0]*self.y[j,0]*(ajOld-self.alphas[j,0]) self.updateErrors(i) b1 = self.b - self.errors[i,0] -\ self.y[i,0]*(self.X[i,:]*self.X[i,:]).sum()*(self.alphas[i,0]-aiOld) -\ self.y[j,0]*(self.X[i,:]*self.X[j,:]).sum()*(self.alphas[j,0]-ajOld) b2 = self.b - self.errors[j,0] -\ self.y[j,0]*(self.X[j,:]*self.X[j,:]).sum()*(self.alphas[j,0]-ajOld) -\ self.y[i,0]*(self.X[i,:]*self.X[j,:]).sum()*(self.alphas[i,0]-aiOld) if ((self.alphas[i,0]>0) and (self.alphas[i,0]<self.C)): self.b = b1 elif ((self.alphas[j,0]>0) and (self.alphas[j,0]<self.C)): self.b = b2 else: self.b = (b1+b2)/2 for k in range(self.m): self.updateErrors(k) return 1 else:return 0 def train(self, maxIter):#训练模型 for k in range(self.m): self.updateErrors(k) k = 0 examineAll = True pairsChanged = 0 while (k<maxIter) and ((pairsChanged>0) or (examineAll)): pairsChanged = 0 if examineAll: for i in range(self.m): pairsChanged += self.updateSvm(i) k += 1 else: nonBoundIs = np.nonzero((self.alphas>0)&(self.alphas<self.C))[0] for i in nonBoundIs: pairsChanged += self.updateSvm(i) k += 1 if examineAll: examineAll = False elif (pairsChanged ==0): examineAll = True def predict(self, observation):#预测未知样本的类别 w = np.zeros((1, self.X.shape[1])) for i in range(self.m): w = w +self.alphas[i,0]*self.y[i,0]*self.X[i:i+1,:] label = np.sign((w*observation).sum()+self.b) return label为了检验上述代码的运行效果,本文使用了《机器学习实战》第六章使用的一个近似线性可分的数据集,原始数据位于一个文本文件
testSet.txt中,导入数据的代码如下:
dataSet = [] labels = [] fileIn = open('/home/liujun/workspace/machine_learning/calssification/testSet.txt') for line in fileIn.readlines(): lineArr = line.strip().split('\t') dataSet.append([float(lineArr[0]), float(lineArr[1])]) labels.append(float(lineArr[2]))rowRange = range(100)np.random.shuffle(rowRange)#将<span style="font-family:SimSun;">训练数据索引按照均匀分布原则随机排列</span>training_data = [data[i] for i in rowRange[0:80]]#取80%的数据为训练数据testing_data = [data[i] for i in rowRange[80:100]]#取<span style="font-family:SimSun;">2</span>0%的数据为训练数据training_data_labels = [labels[i] for i in rowRange[0:80]]testing_data_labels = [labels[i] for i in rowRange[80:100]]X = np.array(training_data)y = np.array(training_data_labels).reshape((80,1))X_test = np.array(testing_data)y_test = np.array(testing_data_labels).reshape((20,1))由于输入观测向量是二维的,可以直观看一下训练数据X是不是近似线性可分的,画图代码如下
import matplotlib.pyplot as pltfor i in range(X.shape[0]): if y[i,0] == -1: plt.plot(X[i,0], X[i,1], 'or') else: plt.plot(X[i,0], X[i,1], 'Db')其中红色点标签为-1,蓝色点标签为+1,见下图:
以下代码计算分类器在测试数据集X_test上预测类别,并和y_test比较计算错判率,代码如下:
y_predict = np.zeros((20,1))for i in range(y_predict.shape[0]): y_predict[i,0] = classifier.predict(X_test[i:i+1,:])errorVector = np.zeros((20,1))errorVector[y_predict != y_test] = 1<span style="font-family:SimSun;">errorRate = </span>errorVector.sum()/errorVector.shape[0]#计算错判率运行结果为错判率errorRate的值为0.1,效果不错.
到此,本文基本结束了。希望接下来有时间将本文代码改进一下以推广到训练数据非线性可分情形。
- 线性可分情形下支持向量机学习的SMO算法
- 线性可分情形下支持向量机学习的SMO算法
- python 机器学习 支持向量机 线性可分
- OpenCv学习笔记--支持向量机SVM线性可分情况下的OpenCv实现的超详细注释(2)
- SVM-支持向量机学习(1):线性可分SVM的基本型
- SVM-支持向量机学习(2):线性可分SVM的对偶型
- 支持向量机SMO算法
- 支持向量机smo算法
- 支持向量机 smo算法
- 深入浅出机器学习之支持向量机SVM(SMO算法)
- SVM 支持向量机(1)--- 完全线性可分
- SVM 支持向量机(2)---不完全线性可分
- 深入浅出机器学习之支持向量机SVM2(线性可分部分)
- 【Python学习系列八】Python实现线性可分SVM(支持向量机)
- SVM支持向量机(SMO算法)的R实现
- 支持向量机(SVM)的SMO算法详解
- 机器学习笔记3:支持向量机的SMO高效优化算法
- 支持向量机(五)SMO算法
- 阿里靠增发股份用钱或难挡离职潮!
- 杭电4324 Triangle LOVE(拓扑排序)
- A35号上机作业
- 新的开始,学习记录
- 利用mkcd制作AIX系统恢复光盘
- 线性可分情形下支持向量机学习的SMO算法
- 上海二手挖掘机市场,二手挖掘机价格优惠
- 很久没回来写博客了。
- ubuntu server设置时区和更新时间
- #pragma once 与 #ifndef 解析
- 织梦dedecms网站六大SEO优化技巧分享
- UVA - 10602 Editor Nottoobad
- OpenMP 作业
- Android USB tethering相关代码