感知机 @ Python

来源:互联网 发布:有道英语翻译软件下载 编辑:程序博客网 时间:2024/06/06 09:38

感知机(二分类问题) @ Python

  • M 是 误分类点的集合
  • 损失函数 :
    • minw,bL(w,b)=xiMyi(wxi+b)
  • 损失函数的梯度 :
    • wL(w,b)=xiMyixi
    • bL(w,b)=xiMyi
  • 采用随机梯度下降法, 随机选取一个误分类点对w, b进行更新
    • w=w(αyixi)
    • b=b(αyi)
# _*_ coding:utf-8 _*_import numpy as npimport matplotlib.pyplot as pltclass Perceptron:    def __init__(self, x, y=1):        self.x = x        self.y = y        self.w = np.ones((self.x.shape[1], 1)) / 10.0  # 提取x的第2维度的大小, 生成一个n X 1 的0.1矩阵        self.b = 0.0  # 偏置项        self.a = 1  # 改变此处可以得到不同的平面    def train(self):        length = self.x.shape[0]        while True:            count = 0  # 记录误分类点的数目            for i in range(length):                y = np.dot(self.x[i], self.w) + self.b                # 如果是误分类点, 0 恰好在平面上                if y * self.y[i] <= 0:                    self.w = self.w + (self.a * self.y[i] * self.x[i]).reshape(self.w.shape)                    self.b = self.b + self.a * self.y[i]                    count += 1            if count == 0:                return self.w, self.bclass ShowPicture:    def __init__(self, x, y, w, b):        self.b = b        self.w = w        plt.figure(1)        plt.title('what the fuck', size=14)        plt.xlabel('x-axis', size=14)        plt.ylabel('y-axis', size=14)        xData = np.linspace(0, 5, 100)  # 创建等差数组        yData = self.expression(xData)        plt.plot(xData, yData, color='r', label='y1 data')        # 绘制散点图        for i in range(x.shape[0]):            if y[i] < 0:                plt.scatter(x[i][0], x[i][1], marker='x', s=50)            else:                plt.scatter(x[i][0], x[i][1], s=50)        plt.savefig('2d.png', dpi=75)    def expression(self, x):        y = (-self.b - self.w[0] * x) / self.w[1]        return y    def show(self):        plt.show()xArray = np.array([[3, 3], [4, 3], [1, 1]])yArray = np.array([1, 1, -1])# [[3 3]# [4 3]# [1 1]]p = Perceptron(x=xArray, y=yArray)w, b = p.train()s = ShowPicture(x=xArray, y=yArray, w=w, b=b)s.show()

改变学习率得到不同的结果

原创粉丝点击