【Python学习系列八】Python实现线性可分SVM(支持向量机)

来源:互联网 发布:视频dvd刻录软件 编辑:程序博客网 时间:2024/05/30 02:23

1、运行环境:eclipse+pydev+Anaconda2-4.4.0(python2.7),含numpy、matplotlib(制图)。

2、代码:

# -*- coding: utf-8 -*-
__author__ = 'Jason.F'

from numpy import *
import operator
import time

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; training itself only needs numpy
    plt = None


# Load the data set. Each line: value1 value2 ... label (whitespace separated), e.g.
# 3.542485    1.977398    -1
# 3.018896    2.556416    1
def loadDataSet(fileName):
    """Read a whitespace-separated sample file.

    Every non-blank line holds the feature values followed by the class
    label (-1 or 1) in the last column.

    Args:
        fileName: path of the text file to read.

    Returns:
        (dataMat, labelMat): a list of feature rows (lists of float) and a
        list of float labels, both in file order.
    """
    dataMat = []
    labelMat = []
    with open(fileName) as fr:
        for line in fr:
            lineArr = line.split()
            if not lineArr:  # tolerate blank lines, e.g. a trailing newline
                continue
            dataMat.append([float(v) for v in lineArr[:-1]])
            labelMat.append(float(lineArr[-1]))
    return dataMat, labelMat


def selectJrand(i, m):
    """Return a random index in [0, m) that differs from i (needs m >= 2)."""
    j = i
    while j == i:
        # `random` here is numpy.random, brought in by `from numpy import *`.
        j = int(random.uniform(0, m))
    return j


def clipAlpha(aj, H, L):
    """Clip aj into [L, H]; L wins if the bounds cross."""
    if aj > H:
        aj = H
    if L > aj:
        aj = L
    return aj


class optStruct:
    """Mutable state shared by the SMO routines below."""

    def __init__(self, dataMatIn, classLabels, C, toler):
        self.X = dataMatIn            # (m, n) matrix of samples, as built by smoP
        self.labelMat = classLabels   # (m, 1) matrix of labels, as built by smoP
        self.C = C                    # upper bound for every alpha
        self.tol = toler              # tolerance for KKT violations
        self.m = shape(dataMatIn)[0]
        self.alphas = asmatrix(zeros((self.m, 1)))
        self.b = 0.0                  # kept as a plain float during training
        # Error cache: column 0 flags a valid entry, column 1 stores E_k.
        self.eCache = asmatrix(zeros((self.m, 2)))


def calcEk(oS, k):
    """Return the prediction error E_k = f(x_k) - y_k (linear kernel)."""
    fXk = (multiply(oS.alphas, oS.labelMat).T * (oS.X * oS.X[k, :].T)).item() + oS.b
    return fXk - oS.labelMat[k, 0]


def selectJ(i, oS, Ei):
    """Choose the second multiplier j for alpha_i.

    Picks the cached sample maximising |Ei - Ek|; falls back to a random
    j != i when no other cache entry is valid or every cached error equals Ei.

    Returns:
        (j, Ej): the chosen index and its current error.
    """
    maxK = -1
    maxDeltaE = 0
    Ej = 0
    oS.eCache[i] = [1, Ei]
    validEcacheList = nonzero(oS.eCache[:, 0].A)[0]
    if len(validEcacheList) > 1:
        for k in validEcacheList:
            if k == i:
                continue
            Ek = calcEk(oS, k)
            deltaE = abs(Ei - Ek)
            if deltaE > maxDeltaE:
                maxK = k
                maxDeltaE = deltaE
                Ej = Ek
        # maxK stays -1 when all cached errors tie with Ei; returning it
        # would silently index the last sample (possibly i itself).
        if maxK != -1:
            return maxK, Ej
    j = selectJrand(i, oS.m)
    Ej = calcEk(oS, j)
    return j, Ej


def updateEk(oS, k):
    """Recompute E_k and mark cache entry k as valid."""
    Ek = calcEk(oS, k)
    oS.eCache[k] = [1, Ek]


def innerL(i, oS):
    """Try to jointly optimise alpha_i and a second multiplier alpha_j.

    Returns:
        1 if the pair (and the threshold b) was updated, otherwise 0.
    """
    Ei = calcEk(oS, i)
    yi = oS.labelMat[i, 0]
    # Only work on alpha_i if it violates the KKT conditions beyond the tolerance.
    violatesKKT = ((yi * Ei < -oS.tol) and (oS.alphas[i, 0] < oS.C)) or \
                  ((yi * Ei > oS.tol) and (oS.alphas[i, 0] > 0))
    if not violatesKKT:
        return 0

    j, Ej = selectJ(i, oS, Ei)
    yj = oS.labelMat[j, 0]
    alphaIold = oS.alphas[i, 0]   # numpy scalars: these are independent copies
    alphaJold = oS.alphas[j, 0]
    # Bounds [L, H] for alpha_j so that both alphas can stay within [0, C].
    if yi != yj:
        L = max(0, alphaJold - alphaIold)
        H = min(oS.C, oS.C + alphaJold - alphaIold)
    else:
        L = max(0, alphaJold + alphaIold - oS.C)
        H = min(oS.C, alphaJold + alphaIold)
    if L == H:
        return 0

    Xi = oS.X[i, :]
    Xj = oS.X[j, :]
    Kii = (Xi * Xi.T).item()
    Kjj = (Xj * Xj.T).item()
    Kij = (Xi * Xj.T).item()
    # eta must be negative for the update step below to be a maximisation.
    eta = 2.0 * Kij - Kii - Kjj
    if eta >= 0:
        return 0

    oS.alphas[j, 0] = clipAlpha(alphaJold - yj * (Ei - Ej) / eta, H, L)
    updateEk(oS, j)
    if abs(oS.alphas[j, 0] - alphaJold) < 0.00001:
        return 0
    # Move alpha_i in the direction that keeps sum(alpha * y) unchanged.
    oS.alphas[i, 0] = alphaIold + yj * yi * (alphaJold - oS.alphas[j, 0])
    updateEk(oS, i)

    dI = yi * (oS.alphas[i, 0] - alphaIold)
    dJ = yj * (oS.alphas[j, 0] - alphaJold)
    # b1 makes f(x_i) = y_i, b2 makes f(x_j) = y_j, so b2 starts from Ej.
    b1 = oS.b - Ei - dI * Kii - dJ * Kij
    b2 = oS.b - Ej - dI * Kij - dJ * Kjj
    if 0 < oS.alphas[i, 0] < oS.C:
        oS.b = b1
    elif 0 < oS.alphas[j, 0] < oS.C:
        oS.b = b2
    else:
        oS.b = (b1 + b2) / 2.0
    return 1


def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup=('lin', 0)):
    """Train a linear soft-margin SVM with Platt's SMO algorithm.

    Args:
        dataMatIn: samples, one row per sample.
        classLabels: sequence of labels (-1/1), one per sample.
        C: upper bound for every alpha.
        toler: tolerance for KKT violations.
        maxIter: maximum number of sweeps over the data.
        kTup: accepted for interface compatibility but unused; only the
            linear kernel is implemented.

    Returns:
        (b, alphas): the threshold as a (1, 1) matrix and the (m, 1) matrix
        of multipliers.
    """
    oS = optStruct(asmatrix(dataMatIn), asmatrix(classLabels).transpose(), C, toler)
    iterr = 0
    entireSet = True
    alphaPairsChanged = 0
    while (iterr < maxIter) and ((alphaPairsChanged > 0) or entireSet):
        alphaPairsChanged = 0
        if entireSet:
            # Sweep over every sample.
            for i in range(oS.m):
                alphaPairsChanged += innerL(i, oS)
        else:
            # Sweep only over non-bound multipliers (0 < alpha < C).
            nonBoundIs = nonzero((oS.alphas.A > 0) * (oS.alphas.A < C))[0]
            for i in nonBoundIs:
                alphaPairsChanged += innerL(i, oS)
        iterr += 1
        # After a full sweep, do non-bound sweeps until they stop changing
        # anything, then go back to a full sweep.
        if entireSet:
            entireSet = False
        elif alphaPairsChanged == 0:
            entireSet = True
    return asmatrix(oS.b), oS.alphas


def calcWs(alphas, dataArr, classLabels):
    """Recover the weight vector w = sum_i alpha_i * y_i * x_i.

    Args:
        alphas: multipliers, one per sample (e.g. the matrix from smoP).
        dataArr: samples, one row per sample.
        classLabels: labels, one per sample.

    Returns:
        w as an (n, 1) float ndarray.
    """
    X = asmatrix(dataArr)
    labelMat = asmatrix(classLabels).transpose()
    alphaCol = asmatrix(alphas).reshape(-1, 1)
    return asarray(X.T * multiply(alphaCol, labelMat), dtype=float)


def _boundaryY(weights, b, x):
    """Return x2 on the line w0*x1 + w1*x2 + b = 0 for each x1 in x (w1 != 0)."""
    w = asarray(weights, dtype=float).ravel()
    bias = asarray(b, dtype=float).item()
    return -(bias + w[0] * asarray(x, dtype=float)) / w[1]


def plotFeature(dataMat, labelMat, weights, b):
    """Scatter 2-D samples by class and draw the learned decision boundary.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for plotFeature")
    dataArr = array(dataMat, dtype=float)
    labels = array(labelMat, dtype=float).ravel()
    isPos = array([int(v) == 1 for v in labels], dtype=bool)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(dataArr[isPos, 0], dataArr[isPos, 1], s=30, c='red', marker='s')
    ax.scatter(dataArr[~isPos, 0], dataArr[~isPos, 1], s=30, c='green')
    w = asarray(weights, dtype=float).ravel()
    if w[1] != 0:
        x = linspace(dataArr[:, 0].min(), dataArr[:, 0].max(), 100)
        ax.plot(x, _boundaryY(weights, b, x))
    elif w[0] != 0:
        # Boundary is vertical: x1 = -b / w0.
        ax.axvline(-asarray(b, dtype=float).item() / w[0])
    plt.xlabel('X1')
    plt.ylabel('X2')
    plt.show()


def main(fileName=r'D:\set.txt'):
    """Train on fileName, print w and b, then plot the samples and boundary."""
    trainDataSet, trainLabel = loadDataSet(fileName)
    b, alphas = smoP(trainDataSet, trainLabel, 0.6, 0.0001, 40)
    ws = calcWs(alphas, trainDataSet, trainLabel)
    print("ws = \n", ws)
    print("b = \n", b)
    plotFeature(trainDataSet, trainLabel, ws, b)


if __name__ == '__main__':
    # time.clock was removed in Python 3.8; use perf_counter where available.
    timer = getattr(time, 'perf_counter', None) or time.clock
    start = timer()
    main()
    end = timer()
    print('finish all in %s' % str(end - start))

3、set.txt样例数据(每行格式:特征1 特征2 类别标签,以空白分隔,标签取 -1 或 1)

3.542485    1.977398    -1
3.018896    2.556416    -1
7.551510    -1.580030    1
2.114999    -0.004466    -1
8.127113    1.274372    1
7.108772    -0.986906    1
8.610639    2.046708    1
2.326297    0.265213    -1
3.634009    1.730537    -1
0.341367    -0.894998    -1
3.125951    0.293251    -1
2.123252    -0.783563    -1
0.887835    -2.797792    -1
7.139979    -2.329896    1
1.696414    -1.212496    -1
8.117032    0.623493    1
8.497162    -0.266649    1
4.658191    3.507396    -1
8.197181    1.545132    1
1.208047    0.213100    -1
1.928486    -0.321870    -1
2.175808    -0.014527    -1
7.886608    0.461755    1
3.223038    -0.552392    -1
3.628502    2.190585    -1
7.407860    -0.121961    1
7.286357    0.251077    1
2.301095    -0.533988    -1
-0.232542    -0.547690    -1
3.457096    -0.082216    -1
3.023938    -0.057392    -1
8.015003    0.885325    1
8.991748    0.923154    1
7.916831    -1.781735    1
7.616862    -0.217958    1
2.450939    0.744967    -1
7.270337    -2.507834    1
1.749721    -0.961902    -1
1.803111    -0.176349    -1
8.804461    3.044301    1
1.231257    -0.568573    -1
2.074915    1.410550    -1
-0.743036    -1.736103    -1
3.536555    3.964960    -1
8.410143    0.025606    1
7.382988    -0.478764    1
6.960661    -0.245353    1
8.234460    0.701868    1
8.168618    -0.903835    1
1.534187    -0.622492    -1
9.229518    2.066088    1
7.886242    0.191813    1
2.893743    -1.643468    -1
1.870457    -1.040420    -1
5.286862    -2.358286    1
6.080573    0.418886    1
2.544314    1.714165    -1
6.016004    -3.753712    1
0.926310    -0.564359    -1
0.870296    -0.109952    -1
2.369345    1.375695    -1
1.363782    -0.254082    -1
7.279460    -0.189572    1
1.896005    0.515080    -1
8.102154    -0.603875    1
2.529893    0.662657    -1
1.963874    -0.365233    -1
8.132048    0.785914    1
8.245938    0.372366    1
6.543888    0.433164    1
-0.236713    -5.766721    -1
8.112593    0.295839    1
9.803425    1.495167    1
1.497407    -0.552916    -1
1.336267    -1.632889    -1
9.205805    -0.586480    1
1.966279    -1.840439    -1
8.398012    1.584918    1
7.239953    -1.764292    1
7.556201    0.241185    1
9.015509    0.345019    1
8.266085    -0.230977    1
8.545620    2.788799    1
9.295969    1.346332    1
2.404234    0.570278    -1
2.037772    0.021919    -1
1.727631    -0.453143    -1
1.979395    -0.050773    -1
8.092288    -1.372433    1
1.667645    0.239204    -1
9.854303    1.365116    1
7.921057    -1.327587    1
8.500757    1.492372    1
1.339746    -0.291183    -1
3.107511    0.758367    -1
2.609525    0.902979    -1
3.263585    1.367898    -1
2.912122    -0.202359    -1
1.731786    0.589096    -1
2.387003    1.573131    -1

4、执行结果:

('ws = \n', array([[ 0.65307162],
       [-0.17196128]]))
('b = \n', matrix([[-2.89901748]]))
finish all in 19.5581056613


阅读全文
0 0
原创粉丝点击