Scikit-learn——SVM

来源:互联网 发布:unity3d像素游戏制作 编辑:程序博客网 时间:2024/05/17 23:19

1. Support Vector Machines

支持向量机(SVM)是一组用于分类(classification), 回归(regression)和异常值检测(outliers detection)的监督学习方法。

支持向量机的优点是:

  • 在高维空间有效。
  • 在维度数量大于样本数量的情况下仍然有效。
  • 在决策功能(称为支持向量)中使用训练点的子集,因此它也是内存有效的。
  • 多功能:可以为决策功能指定不同的内核函数。提供通用内核,但也可以指定自定义内核。

支持向量机的缺点包括:

  • 如果特征数量远远大于样本数量,则该方法可能会导致较差的性能。
  • 支持向量机不直接提供概率估计,这些是使用昂贵的五-折交叉验证计算的(参见下面的分数和概率)。

2. Kernel function

  • linear:x,x
  • polynomial:(γx,x+r)d,
  • rbf:exp(γ||xx||2),
  • sigmoid:(tanh(γx,x+r)

注:d用参数degree表示,r用参数coef0表示;γ>0用参数gamma表示;

3. Classification

在sklean的支持向量机中SVC, NuSVC和LinearSVC都可以用来(多)分类,只是在目标函数以及参数上有略微的区别。从另一方面来说,LinearsSVC是支持向量机在线性核方面的应用,也就是说它只支持线性核,所以它并不没有kernel这个参数选项,默认就是线性。

这里写图片描述

3.1 sklearn.svm.SVC

SVC(Support Vector Classificaion)是基于libsvm来实现的。其拟合的时间复杂度大于(O(n2)),当训练集大于10000时就很难计算了。同时SVC也支持多分类,其思想依旧是采用的one-vs-one。

其目标函数为:

minw,b,ξs.t.12wTw+Ci=1mξiy(i)(wTϕ(i)+b)1ξi,i=1,2,...mξi0,i=1,2,...m

下面是初始化类SVC时所需要用到的参数,也就是说这些参数决定这训练出来的模型。

SVC(C=1.0, kernel=’rbf’, degree=3, gamma=’auto’, coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=’ovr’, random_state=None)
参数表 含义 C 惩罚项;浮点型,其值为可选项,默认为1.0 Kernel 核函数类型;字符串型,其值为可选项,默认值为’rbf’ degree 多项式次数;整型,默认值为3,仅在多项式中有效 gamma 核函数的系数;浮点型,默认值为(1/n_features) coef0 常数项;浮点型,默认值为0.0,仅在多项式和sigmoid核函数中有效 decision_function_shape 分类方式;字符串型,默认值为’ovr’ 其余参数目前没有用到,故暂不深究,默认值即可


SVC类中的常用属性

属性表 含义 support_ 所有支持向量的下标 support_vectors_ 所有的支持向量 n_support_ 每个类别支持向量的个数 intercept_ 截距


SVC类中的常用方法

方法表 含义 fit(X,y) 用给定的X,y来训练模型 predit(X) 给定X,预测类别 score(X,y) 计算准确率 get_param(True) 返回参数表中的所有参数


以上所有的原型戳此处 API接口地址


3.1.1 examples

3.1.1.1 binary-class classification

这里写图片描述

如图为可视化后的数据集,下面用SVC来训练模型。

X,y = other.loadDataSet('ex6data1.txt') # 载入数据集
#可视化数据idx_0 = np.where(y_data==0) #找出所有标签值为0的下标索引p0 = plt.scatter(X_data[idx_0,0],X_data[idx_0,1],                 marker='*',color='r',label='0',s=50) #画出所有标签值为0的点idx_1 = np.where(y_data==1)p0 = plt.scatter(X_data[idx_1,0],X_data[idx_1,1],                 marker='o',color='b',label='1',s=50)plt.legend(loc = 'upper right')plt.show()
# 将数据集划分成训练集和验证集X_train,X_test,y_train,y_test = train_test_split(        X_data,y_data,test_size = 0.3) 
# 用训练集和验证集来训练超参数kernel_type = 'linear'C,gamma = other.parameter(X_train,y_train,X_test,y_test                          ,kernel_type)
# 0.3 和 1 是我已经训练好的,所以就直接写上了svc = svm.SVC(C = 0.3,gamma= 1,kernel = kernel_type) # 赋初值svc.fit(X_train,y_train)# 此时svc 表示训练好的向量机模型other.visualize(X_train,y_train,svc,kernel_type)# 可视化训练好的结果print svc.score(X_test,y_test) #准确率print svc.support_vectors_ #支持向量

这里写图片描述

可以看出,此时训练出的模型为软间隔。下图示我将C设置成300变成硬间隔的结果:

这里写图片描述

顺便补充说一句:SVM中,C的作用就类似于线性回归和逻辑回归中1λ的作用;也就是说,在支持向量机中C特别大,则模型就变成了硬间隔;

源码地址

3.1.1.2 multi-class classification

一般常见的多分类(multi-class classification)策略主要有one-vs-one和one-vs-all;先大致说一下两者的主要思想:

如图,有4个类别
这里写图片描述

(1) one-vs-all

所谓’ova’就是:假如有n类,则每次都将其中1类看做是“正类”,其余的n-1当作“负类”;通过训练n个分类器来达到分类的目的。

这里写图片描述

如上图,通过训练4个SVM分类器即可达到目的,每次将红色的作为“正类”,蓝色的作为“负类”。

(2) one-vs-one

‘ovo’就是:假如有n类,则每次选取其中的2类,然后对其训练得到一个分类器;然后在继续选择两个,以此类推,但任意两次选择的都不同,一共训练n(n-1)/2个分类器即可达到分类目的。
这里写图片描述

如上图,通过训练6个分类器即可

在SVC和Nusvc中都可以通过参数decision_function_shape来选择哪种方法进行分类。

下图为可视化的数据集,可以看到有3类

这里写图片描述

iris = datasets.load_iris()X = iris.datay = iris.target #载入数据集X_train,X_test,y_train,y_test = train_test_split(        X,y,test_size = 0.3) #将数据集分割成训练集和验证集#  通过 ovo 的方法进行训练svc_ovo = svm.SVC(decision_function_shape = 'ovo')svc_ovo.fit(X_train,y_train)print 'by one vs one: '+ str(svc_ovo.score(X_test,y_test))#  通过 ova 的方法进行训练svc_ova = svm.SVC(decision_function_shape = 'ova')svc_ova.fit(X_train,y_train)print 'by one vs all: '+ str(svc_ova.score(X_test,y_test))

源码地址

3.2 sklearn.svm.NuSVC

Nu-Support Vector Classification(NuSVC)总体上同SVC一样,都是基于libsvm库来实现的,只是多了一个控制支持向量个数的参数(并不直接是支持向量的个数)。即nu,其默认值是0.5。所以就不准备继续唠叨。

3.3 sklearn.svm.LinearSVC

Linear Support Vector Classification(LinearSVC),可以看作是SVC中将kernel设置为’linear’的版本。但区别是,LinearSVC是基于liblinear来实现的;因此,在选择惩罚项和损失函数方面有更大的灵活性,对于大规模的数据集有更强的扩展性。

注:以上都是scikit-learn 0.18.2版本中的用法,目前最新的是0.19.1,里面有些许的改动

参考:

  • scikit-learn.org
原创粉丝点击