logistics回归--梯度上升算法以及改进--用于二分类

来源:互联网 发布:xplay6知乎 编辑:程序博客网 时间:2024/05/21 11:32

1.sigmoid函数应用

  • logistics回归是用来分类的,并且属于监督学习,分类也是仅限于二分类,就是结果非0即1 (这种函数通常称作跃阶函数)
  • 这个时候就出现问题了 01之间的分界点怎么处理?
  • 引入sigmoid函数 图像见下图

sigmoid函数图像

2.算法中的数学思想

举个引例:求 函数y = -x^2+3x+1 的最大值
很简单 求得导数 y’ = -2x+3
当且仅当x=1.5时函数y取得最大值
然而并不是所有的函数都可以这么一下求出来
所以 就利用了梯度上升的方法来求极值(最大值)

# y = -x^2+3x+1# y'= -2x+3def fun(inx):    return (-2*inx+3)if __name__ == '__main__':    xNew = 2    xOld = 0    eps = 0.00001    alpha = 0.01    while abs(xOld - xNew) > eps: # 梯度上升 步长*'方向'        xOld = xNew        xNew = xNew + alpha * fun(xNew)    print xNew
  • 拓展到矩阵 就有了不一样的表达式 解释logistic中的梯度上升
  • 求出最佳回归系数的目的是为了带入到sigmoid函数
  • 每个数据点都乘上一个回归系数 带入sigmoid函数 得到0-1之间的值 实现分类目的

3.算法伪代码–梯度上升算法

每个回归系数置为1
     for 0-loopNum
     带入sigmoid计算估计值
     更新回归系数的向量
返回系数矩阵

4.缺点不足

要想得到好的结果时间花费大
容易欠拟合 精度不太高

5.改进算法-随机梯度上升算法

  • 随机选取样本来更新回归系数
  • 步长时刻改变
  • 减少了运算复杂程度

6.伪代码-随机梯度上升算法

每个回归系数置为1
     for 0-loopNum
           for 0-n所有的点
                 更新步长
                 随机选取数据点带入sigmoid计算估计值
                 更新回归系数的向量
                 标记已经随机使用的点
返回系数矩阵

# -*- coding:utf-8 -*-from numpy import *def classify(inX, theta):  # 分类器0/1二分类    res = sigm(sum(theta * inX))    if res > 0.5:        return 1    else:        return 0# 加载文件 返回listdef loadFile(fileName):    txtFile = open(fileName, "rb")    dataArr = []    labelArr = []    for line in txtFile.readlines():        lineArr = line.split()        dataArr.append([2.0, float(lineArr[0]), float(lineArr[1])])        labelArr.append(int(lineArr[-1]))    txtFile.close()    return dataArr, labelArr# sigmoid函数def sigm(inX):    return 1 / (1 + exp(-inX))# 绘图 打印出点以及直线def plotPrint(theta):    import matplotlib.pyplot as plt    dataMat, labelMat = loadFile("testSet.txt")    dataArr = array(dataMat)  # list 转换为数组    n = shape(dataArr)[0]    xcord1 = []    ycord1 = []    xcord2 = []    ycord2 = []    for i in range(n):        if int(labelMat[i]) == 1:            xcord1.append(dataArr[i, 1])            ycord1.append(dataArr[i, 2])        else:            xcord2.append(dataArr[i, 1])            ycord2.append(dataArr[i, 2])    fig = plt.figure()    ax = fig.add_subplot(111)    ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')    ax.scatter(xcord2, ycord2, s=30, c='green')    x = arange(-3.5, 3.5, 0.1)    y = (-2 * theta[0] - theta[1] * x) / theta[2]    # theta[0]*2 + theta[1]*x + theta[2]*y = 0    # print theta[0], theta[1], theta[2]    ax.plot(x, y)    plt.xlabel('x')    plt.ylabel('y')    plt.show()# 梯度上升算法 初始版def stocGradAscent0(dataMat, labelMat, alpha, loopNum):    m, n = shape(dataMat)    theta = ones((n, 1))  # 初始为n*1的回归参数矩阵    for i in range(loopNum):        h = sigm(dataMat * theta)  # z = w0x0+w1x1+...+wnxn        loss = labelMat - h  # y与y拔之差 预测值与真实值的差        theta = theta + alpha * dataMat.T * loss  # alpha * 3*1 (3*100 * 100*1) 矩阵的运算 不是一步运算    return theta# 随即梯度上升算法 改进版 一维array数值运算 不涉及矩阵的运算def stocGradAscent(dataMat, labelMat, loopNum=150):    m, n = shape(dataMat)    theta = ones(n)    for i in range(loopNum):  # 循环次数        dataIndex = range(m)  # 下标存放的数组        for j in range(m):            alpha = 4 / (1.0 + j + i) + 0.01            randomIndex = int(random.uniform(0, len(dataIndex)))  # 随机下标            h = sigm(sum(dataMat[randomIndex] * theta))  # sum(3L*3L) = 3L            loss = labelMat[randomIndex] - h  # 误差            theta = theta + alpha * loss * dataMat[randomIndex]            del (dataIndex[randomIndex])  # 用过之后删除该下标    return thetaif __name__ == '__main__':    Dir = "testSet.txt"    # 步长 循环次数    alpha = 0.001    loopNum = 1000    # 数据矩阵 类别(标签)矩阵    data, label = loadFile(Dir)    data = mat(data)    label = mat(label)    # th就是回归参数矩阵    th = stocGradAscent0(data, label.T, alpha, 500)    plotPrint(th.A)  # th.getA() th.A得到ndarray数组    print "----" * 10    data, label = loadFile(Dir)    theta = stocGradAscent(array(data), label, 500)    plotPrint(theta)    print theta    testPoint = mat(ones((3, 1)))    print classify(testPoint, theta)

附上程序运行结果
梯度上升

随机梯度上升

'''----------------------------------------[ 9.37024571  1.33192559 -2.62149254]1'''

配上数据集 100*3

'dataSet.txt'"""-0.017612   14.053064   0-1.395634   4.662541    1-0.752157   6.538620    0-1.322371   7.152853    00.423363    11.054677   00.406704    7.067335    10.667394    12.741452   0-2.460150   6.866805    10.569411    9.548755    0-0.026632   10.427743   00.850433    6.920334    11.347183    13.175500   01.176813    3.167020    1-1.781871   9.097953    0-0.566606   5.749003    10.931635    1.589505    1-0.024205   6.151823    1-0.036453   2.690988    1-0.196949   0.444165    11.014459    5.754399    11.985298    3.230619    1-1.693453   -0.557540   1-0.576525   11.778922   0-0.346811   -1.678730   1-2.124484   2.672471    11.217916    9.597015    0-0.733928   9.098687    0-3.642001   -1.618087   10.315985    3.523953    11.416614    9.619232    0-0.386323   3.989286    10.556921    8.294984    11.224863    11.587360   0-1.347803   -2.406051   11.196604    4.951851    10.275221    9.543647    00.470575    9.332488    0-1.889567   9.542662    0-1.527893   12.150579   0-1.185247   11.309318   0-0.445678   3.297303    11.042222    6.105155    1-0.618787   10.320986   01.152083    0.548467    10.828534    2.676045    1-1.237728   10.549033   0-0.683565   -2.166125   10.229456    5.921938    1-0.959885   11.555336   00.492911    10.993324   00.184992    8.721488    0-0.355715   10.325976   0-0.397822   8.058397    00.824839    13.730343   01.507278    5.027866    10.099671    6.835839    1-0.344008   10.717485   01.785928    7.718645    1-0.918801   11.560217   0-0.364009   4.747300    1-0.841722   4.119083    10.490426    1.960539    1-0.007194   9.075792    00.356107    12.447863   00.342578    12.281162   0-0.810823   -1.466018   12.530777    6.476801    11.296683    11.607559   00.475487    12.040035   0-0.783277   11.009725   00.074798    11.023650   0-1.337472   0.468339    1-0.102781   13.763651   0-0.147324   2.874846    10.518389    9.887035    01.015399    7.571882    0-1.658086   -0.027255   11.319944    2.171228    12.056216    5.019981    1-0.851633   4.375691    1-1.510047   6.061992    0-1.076637   -3.181888   11.821096    10.283990   03.010150    8.401766    1-1.099458   1.688274    1-0.834872   -1.733869   1-0.846637   3.849075    11.400102    12.628781   01.752842    5.468166    10.078557    0.059736    10.089392    -0.715300   11.825662    12.693808   00.197445    9.744638    00.126117    0.922311    1-0.679797   1.220530    10.677983    2.556666    10.761349    10.693862   0-2.168791   0.143632    11.388610    9.341997    00.317029    14.739025   0"""

    • sigmoid函数应用
    • 算法中的数学思想
    • 算法伪代码梯度上升算法
    • 缺点不足
    • 改进算法-随机梯度上升算法
    • 伪代码-随机梯度上升算法