感知机

来源:互联网 发布:知乎页面显示不正常 编辑:程序博客网 时间:2024/05/01 09:10

感知机学习算法的对偶形式:http://blog.csdn.net/qq_29591261/article/details/77945561

本文省略推理过程,直接上算法过程和代码,想看原理推导的请参考原文:http://www.hankcs.com/ml/the-perceptron.html


概念

感知机是二分类模型,输入实例的特征向量,输出实例的±类别。


定义

假设输入空间是这里写图片描述,输出空间是这里写图片描述,x和y分属这两个空间,那么由输入空间到输出空间的如下函数:
这里写图片描述
称为感知机。其中,w和b称为感知机模型参数,这里写图片描述叫做权值或权值向量,这里写图片描述叫做偏置,w·x表示向量w和x的内积。sign是一个函数:
这里写图片描述
感知机的几何解释是,线性方程
这里写图片描述
将特征空间划分为正负两个部分:
这里写图片描述


感知机学习算法

这里写图片描述


感知机算法代码

为了理解起来更容易些,我自己加了一部分注释。

  1. # -*-coding:utf-8 -*-
  2. #Filename: train2.1.py
  3. # Authorhankcs
  4. # Date:2015/1/30 16:29
  5. import copy
  6. from matplotlib import pyplot as plt
  7. from matplotlib import animation
  8. training_set= [[(1, 2), 1], [(2, 3), 1], [(3, 1), -1], [(4, 2), -1]]    #我在这里修改了原文数据,原数据让线看起来像平移
  9. = [0, 0]    #参数初始化
  10. = 0
  11. history = []    #用来记录每次更新过后的w,b
  12. def update(item):
  13.     """
  14.     随机梯度下降更新参数
  15.     :param item: 参数是分类错误的点
  16.     :return: nothing 无返回值
  17.     """
  18.     global w, b,history    #w, b, history声明为全局变量
  19.     w[0] += 1 * item[1] * item[0][0]    #根据误分类点更新参数,这里学习效率设为1
  20.     w[1] += 1 * item[1] * item[0][1]
  21.     b += 1 * item[1]
  22.     print w, b    #输出每次更新过后的w,b以供观察
  23.     history.append([copy.copy(w), b])    #将每次更新过后的w,b记录在history数组中
  24. def cal(item):
  25.     """
  26.     计算item到超平面的距离,输出yi(w*xi+b)
  27.     (我们要根据这个结果来判断一个点是否被分类错了。如果yi(w*xi+b)>0,则分类错了)
  28.     :param item:
  29.     :return:
  30.     """
  31.     res = 0
  32.     for i in range(len(item[0])):   #迭代item的每个坐标,对于本文数据则有两个坐标x1x2
  33.         res += item[0][i] * w[i]
  34.     res += b
  35.     res *= item[1]    #这里是乘以公式中的yi
  36.     return res
  37. def check():
  38.     """
  39.     检查超平面是否已将样本正确分类
  40.     :return: true如果已正确分类则返回True
  41.     """
  42.     flag = False
  43.     for item in training_set:
  44.         if cal(item) <= 0:   #如果有分类错误的
  45.             flag = True     #flag设为True
  46.             update(item)     #用误分类点更新参数
  47.     if not flag:    #如果没有分类错误的点了
  48.         print "RESULT: w: " + str(w) + "b: " + str(b)    #输出达到正确结果时参数的值
  49.     return flag    #如果已正确分类则返回True,否则返回False
  50. if __name__ == "__main__":
  51.     for i in range(1000):   #迭代1000
  52.         if not check()break    #如果已正确分类,则结束迭代
  53.     #以下代码是将迭代过程可视化 
  54.     #首先建立我们想要做成动画的图像figure, 坐标轴axis,plot element
  55.     fig = plt.figure()
  56.     ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
  57.     line, = ax.plot([], [], 'g', lw=2)    #画一条线
  58.     label = ax.text([], [], '')
  59.     # initialization function: plot the background of each frame
  60.     def init():
  61.         line.set_data([], [])
  62.         x, y, x_, y_ = [], [], [], []
  63.         for p in training_set:
  64.             if p[1] > 0:
  65.                 x.append(p[0][0])    #存放yi=1的点的x1坐标
  66.                 y.append(p[0][1])    #存放yi=1的点的x2坐标
  67.             else:
  68.                 x_.append(p[0][0])    #存放yi=-1的点的x1坐标
  69.                 y_.append(p[0][1])    #存放yi=-1的点的x2坐标
  70.         plt.plot(x, y, 'bo', x_, y_, 'rx') #在图里yi=1的点用点表示,yi=-1的点用叉表示
  71.         plt.axis([-6, 6, -6, 6]) #横纵坐标上下限
  72.         plt.grid(True) #显示网格
  73.         plt.xlabel('x1') #这里我修改了原文表示
  74.         plt.ylabel('x2') #为了和原理中表达方式一致,横纵坐标应该是x1,x2
  75.         plt.title('Perceptron Algorithm (www.hankcs.com)') #给图一个标题:感知机算法
  76.         return line, label
  77.     # animation function.  this is called sequentially
  78.     def animate(i):
  79.         global history, ax, line, label
  80.         w = history[i][0]
  81.         b = history[i][1]
  82.         if w[1] == 0return line, label
  83.         #因为图中坐标上下限为-6~6,所以我们在横坐标为-77的两个点之间画一条线就够了,这里代码中的xi,yi其实是原理中的x1,x2
  84.         x1 = -7
  85.         y1 = -(+ w[0] * x1) / w[1]
  86.         x2 = 7
  87.         y2 = -(+ w[0] * x2) / w[1]
  88.         line.set_data([x1, x2], [y1, y2])    #设置线的两个点
  89.         x1 = 0
  90.         y1 = -(+ w[0] * x1) / w[1]
  91.         label.set_text(history[i])
  92.         label.set_position([x1, y1])
  93.         return line, label
  94.     # call the animator.  blit=true means only re-drawthe parts that have changed.
  95.     #print history    #我在这里把这行注释掉了,因为在更新参数的时候已经输出了一遍
  96.     anim = animation.FuncAnimation(fig, animate,init_func=init, frames=len(history),interval=1000, repeat=True, blit=True)
  97.     plt.show()
  98.     anim.save('perceptron.gif', fps=2, writer='imagemagick')    #这里并没有保存成功,我自己找了个软件截了图

    运行结果

    1. [1, 2] 1
    2. [-2, 1] 0
    3. [-1, 3] 1
    4. [-4, 2] 0
    5. [-3, 4] 1
    6. RESULT: w: [-3, 4] b: 1

      可视化

      这里写图片描述