Svm算法理解以及MLlib实现

来源:互联网 发布:淘宝指纹支付怎么设置 编辑:程序博客网 时间:2024/05/21 12:48

首先SVM算法它也是一种分类算法,类似于贝叶斯分类算法,但是在底层的实现还是不同,它可以用更少的样本,训练出更高精度的模型。

支持向量机(Support Vector Machine)是Cortes和Vapnik于1995年首先提出的,
它在解决小样本、非线性及高维模式识别中表现出许多特有的优势,并能够推广应用到函数拟合等其他机器学习问题中
支持向量机方法是建立在统计学习理论的VC 维理论和结构风险最小原理基础上的,根据有限的样本信息在模型的复杂性(
即对特定训练样本的学习精度,Accuracy)
和学习能力(即无错误地识别任意样本的能力)之间寻求最佳折衷,以期获得最好的推广能力(或称泛化能力)。

svm特点:  SVM正是努力最小化结构风险的算法。  SVM其他的特点就比较容易理解了。  小样本,并不是说样本的绝对数量少(实际上,对任何算法来说,更多的样本几乎总是能带来更好的效果),      而是说与问题的复杂度比起来,SVM算法要求的样本数是相对比较少的。 非线性,是指SVM擅长应付样本数据线性不可分的情况,主要通过松弛变量(也有人叫惩罚变量)和核函数技术来实现,     这一部分是SVM的精髓,以后会详细讨论。多说一句,关于文本分类这个问题究竟是不是线性可分的,尚没有定论,     因此不能简单的认为它是线性可分的而作简化处理,在水落石出之前,只好先当它是线性不可分的     (反正线性可分也不过是线性不可分的一种特例而已,我们向来不怕方法过于通用)。高维模式识别是指样本维数很高,例如文本的向量表示,如果没有经过另一系列文章(《文本分类入门》)    中提到过的降维处理,出现几万维的情况很正常,其他算法基本就没有能力应付了,    SVM却可以,主要是因为SVM 产生的分类器很简洁,用到的样本信息很少(仅仅用到那些称之为“支持向量”的样本,此为后话),    使得即使样本维数很高,也不会给存储和计算带来大麻烦(相对照而言,kNN算法在分类时就要用到所有样本,样本数巨大,    每个样本维数再一高,这日子就没法过了……)。

调用MLlib包中的svm算法:
1,训练数据:

    1,1,1    1,1,1    0,1,0    0,0,1    1,0,1    1,1,0    0,0,0    0,0,0    1,0,0    1,1,1

第一个数,是代表数据的类别,第二,三个则是代表其特征。

2,预测数据:

        1,1        0,1        1,0        1,1        0,0

这两数,都是代表其特征。

Scala调用代码:

import org.apache.spark.mllib.linalg.Vectorsimport org.apache.spark.mllib.regression.LabeledPointimport org.apache.spark.{SparkContext, SparkConf}import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}import org.apache.spark.mllib.evaluation.BinaryClassificationMetricsimport org.apache.spark.mllib.util.MLUtilsimport org.apache.spark.mllib.regression.LinearRegressionModelimport org.apache.spark.mllib.regression.LinearRegressionWithSGD/** * Created by Administrator on 2016/8/8. */object MySVM {  def main(args: Array[String]) {  val conf =new SparkConf().setAppName("TestSVM").setMaster("local");  val sc = new SparkContext(conf)    val data = sc.textFile("file///F:/1/newtest3.txt")  //训练数据    val parsedData = data.map { line =>      val parts = line.split(',')      val row1=parts(0).toDouble      val row2=parts.drop(1).map(_.toDouble)      println("row1=="+row1)      LabeledPoint(row1, Vectors.dense(row2))    }    val splits=parsedData.randomSplit(Array(0.9,0.1), seed = 11L)    val training = splits(0).cache()    val test = splits(1)    //设置迭代次数    val numIterations = 100    //将训练数据训练成模型    val model = SVMWithSGD.train(training, numIterations)    // 清空模型入口    model.clearThreshold()    // Compute raw scores on the test set.    val scoreAndLabels = test.map { point =>      val score = model.predict(point.features)      (score, point.label)    }    // 计算原始测试集的成绩。    val metrics = new BinaryClassificationMetrics(scoreAndLabels)    val auROC = metrics.areaUnderROC()    println("Area under ROC = " + auROC)    val forecastdata = sc.textFile("file///F:/1/newtest2.txt")    //需要预测的分类数据    val forecastresult=forecastdata.map{line=>      val parts=line.split(",")      Vectors.dense(parts.map(_.toDouble))    }    val forecastresult1= model.predict(forecastresult)//使用模型计算    println("---------------------结果来了-------------------------------")    forecastresult1.foreach(println)  }}
0 0
原创粉丝点击