感知机介绍及实现

来源:互联网 发布:淘宝教育官网 编辑:程序博客网 时间:2024/06/06 12:29

http://blog.csdn.net/fengbingchun/article/details/50097723


感知机介绍及实现

 2104人阅读 评论(1) 收藏 举报
 分类:

感知机(perceptron)由Rosenblatt于1957年提出,是神经网络与支持向量机的基础。

感知机是最早被设计并被实现的人工神经网络。感知机是一种非常特殊的神经网络,它在人工神经网络的发展史上有着非常重要的地位,尽管它的能力非常有限,主要用于线性分类。

感知机还包括多层感知机,简单的线性感知机用于线性分类器,多层感知机(含有隐层的网络)可用于非线性分类器。本文中介绍的均是简单的线性感知机。



图 1

感知机工作方式

         (1)、学习阶段:修改权值和偏置,根据”已知的样本”对权值和偏置不断修改----有监督学习。当给定某个样本的输入/输出模式对时,感知机输出单元会产生一个实际输出向量,用期望输出(样本输出)与实际输出之差来修正网络连接权值和偏置。

         (2)、工作阶段:计算单元变化,由响应函数给出新输入下的输出。

         感知机学习策略

感知机学习的目标就是求得一个能够将训练数据集中正负实例完全分开的分类超平面,为了找到分类超平面,即确定感知机模型中的参数w和b,需要定义一个基于误分类的损失函数,并通过将损失函数最小化来求w和b。

         (1)、数据集线性可分性:在二维平面中,可以用一条直线将+1类和-1类完美分开,那么这个样本空间就是线性可分的。因此,感知机都基于一个前提,即问题空间线性可分;

         (2)、定义损失函数,找到参数w和b,使得损失函数最小。

         损失函数的选取

         (1)、损失函数的一个自然选择就是误分类点的总数,但是这样的点不是参数w,b的连续可导函数,不易优化;

         (2)、损失函数的另一个选择就是误分类点到划分超平面S(w*x+b=0)的总距离。



以上理论部分主要来自: http://staff.ustc.edu.cn/~qiliuql/files/DM2013/2013SVM.pdf

以下代码根据上面的描述实现:

perceptron.hpp:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #ifndef _PERCEPTRON_HPP_  
  2. #define _PERCEPTRON_HPP_  
  3.   
  4. #include <vector>  
  5.   
  6. namespace ANN {  
  7.   
  8. typedef std::vector<float> feature;  
  9. typedef int label;  
  10.   
  11. class Perceptron {  
  12. private:  
  13.     std::vector<feature> feature_set;  
  14.     std::vector<label> label_set;  
  15.     int iterates;  
  16.     float learn_rate;  
  17.     std::vector<float> weight;  
  18.     int size_weight;  
  19.     float bias;  
  20.   
  21.     void initWeight();  
  22.     float calDotProduct(const feature feature_, const std::vector<float> weight_);  
  23.     void updateWeight(const feature feature_, int label_);  
  24.   
  25. public:  
  26.     Perceptron(int iterates_, float learn_rate_, int size_weight_, float bias_);  
  27.     void getDataset(const std::vector<feature> feature_set_, const std::vector<label> label_set_);  
  28.     bool train();  
  29.     label predict(const feature feature_);  
  30. };  
  31.   
  32. }  
  33.   
  34.   
  35. #endif // _PERCEPTRON_HPP_  
perceptron.cpp:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #include "perceptron.hpp"  
  2. #include <assert.h>  
  3. #include <time.h>  
  4. #include <iostream>  
  5.   
  6. namespace ANN {  
  7.   
  8. void Perceptron::updateWeight(const feature feature_, int label_)  
  9. {  
  10.     for (int i = 0; i < size_weight; i++) {  
  11.         weight[i] += learn_rate * feature_[i] * label_; // formula 5  
  12.     }  
  13.   
  14.     bias += learn_rate * label_; // formula 5  
  15. }  
  16.   
  17. float Perceptron::calDotProduct(const feature feature_, const std::vector<float> weight_)  
  18. {  
  19.     assert(feature_.size() == weight_.size());  
  20.     float ret = 0.;  
  21.   
  22.     for (int i = 0; i < feature_.size(); i++) {  
  23.         ret += feature_[i] * weight_[i];  
  24.     }  
  25.   
  26.     return ret;  
  27. }  
  28.   
  29. void Perceptron::initWeight()  
  30. {  
  31.     srand(time(0));  
  32.     float range = 100.0;  
  33.     for (int i = 0; i < size_weight; i++) {  
  34.         float tmp = range * rand() / (RAND_MAX + 1.0);  
  35.         weight.push_back(tmp);  
  36.     }  
  37. }  
  38.   
  39. Perceptron::Perceptron(int iterates_, float learn_rate_, int size_weight_, float bias_)  
  40. {  
  41.     iterates = iterates_;  
  42.     learn_rate = learn_rate_;  
  43.     size_weight = size_weight_;  
  44.     bias = bias_;  
  45.     weight.resize(0);  
  46.     feature_set.resize(0);  
  47.     label_set.resize(0);  
  48. }  
  49.   
  50. void Perceptron::getDataset(const std::vector<feature> feature_set_, const std::vector<label> label_set_)  
  51. {  
  52.     assert(feature_set_.size() == label_set_.size());  
  53.   
  54.     feature_set.resize(0);  
  55.     label_set.resize(0);  
  56.   
  57.     for (int i = 0; i < feature_set_.size(); i++) {  
  58.         feature_set.push_back(feature_set_[i]);  
  59.         label_set.push_back(label_set_[i]);  
  60.     }  
  61. }  
  62.   
  63. bool Perceptron::train()  
  64. {  
  65.     initWeight();  
  66.   
  67.     for (int i = 0; i < iterates; i++) {  
  68.         bool flag = true;  
  69.   
  70.         for (int j = 0; j < feature_set.size(); j++) {  
  71.             float tmp = calDotProduct(feature_set[j], weight) + bias;  
  72.             if (tmp * label_set[j] <= 0) {  
  73.                 updateWeight(feature_set[j], label_set[j]);  
  74.                 flag = false;  
  75.             }  
  76.         }  
  77.   
  78.         if (flag) {  
  79.             std::cout << "iterate: " << i << std::endl;  
  80.             std::cout << "weight: ";  
  81.             for (int m = 0; m < size_weight; m++) {  
  82.                 std::cout << weight[m] << "    ";  
  83.             }  
  84.             std::cout << std::endl;  
  85.             std::cout << "bias: " << bias << std::endl;  
  86.   
  87.             return true;  
  88.         }  
  89.     }  
  90.   
  91.     return false;  
  92. }  
  93.   
  94. label Perceptron::predict(const feature feature_)  
  95. {  
  96.     assert(feature_.size() == size_weight);  
  97.   
  98.     return calDotProduct(feature_, weight) + bias >= 0 ? 1 : -1; //formula 2  
  99. }  
  100.   
  101. }  

test_NN.cpp:

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #include <iostream>  
  2. #include "perceptron.hpp"  
  3.   
  4. int test_perceptron();  
  5.   
  6. int main()  
  7. {  
  8.     test_perceptron();  
  9.     std::cout << "ok!" << std::endl;  
  10. }  
  11.   
  12. int test_perceptron()  
  13. {  
  14.     // prepare data  
  15.     const int len_data = 20;  
  16.     const int feature_dimension = 2;  
  17.     float data[len_data][feature_dimension] = { { 10.3, 10.7 }, { 20.1, 100.8 }, { 44.9, 8.0 }, { -2.2, 15.3 }, { -33.3, 77.7 },  
  18.     { -10.4, 111.1 }, { 99.3, -2.2 }, { 222.2, -5.5 }, { 10.1, 10.1 }, { 66.6, 30.2 },  
  19.     { 0.1, 0.2 }, { 1.2, 0.03 }, { 0.5, 4.6 }, { -22.3, -11.1 }, { -88.9, -12.3 },  
  20.     { -333.3, -444.4 }, { -111.2, 0.5 }, { -6.6, 2.9 }, { 3.3, -100.2 }, { 5.6, -88.8 } };  
  21.     int label_[len_data] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  
  22.         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 };  
  23.   
  24.     std::vector<ANN::feature> set_feature;  
  25.     std::vector<ANN::label> set_label;  
  26.   
  27.     for (int i = 0; i < len_data; i++) {  
  28.         ANN::feature feature_single;  
  29.         for (int j = 0; j < feature_dimension; j++) {  
  30.             feature_single.push_back(data[i][j]);  
  31.         }  
  32.   
  33.         set_feature.push_back(feature_single);  
  34.         set_label.push_back(label_[i]);  
  35.   
  36.         feature_single.resize(0);  
  37.     }  
  38.   
  39.     // train  
  40.     int iterates = 1000;  
  41.     float learn_rate = 0.5;  
  42.     int size_weight = feature_dimension;  
  43.     float bias = 2.5;  
  44.     ANN::Perceptron perceptron(iterates, learn_rate, size_weight, bias);  
  45.     perceptron.getDataset(set_feature, set_label);  
  46.     bool flag = perceptron.train();  
  47.     if (flag) {  
  48.         std::cout << "data set is linearly separable" << std::endl;  
  49.     }  
  50.     else {  
  51.         std::cout << "data set is linearly inseparable" << std::endl;  
  52.         return -1;  
  53.     }  
  54.   
  55.     // predict  
  56.     ANN::feature feature1;  
  57.     feature1.push_back(636.6);  
  58.     feature1.push_back(881.8);  
  59.     std::cout << "the correct result label is 1, " << "the real result label is: " << perceptron.predict(feature1) << std::endl;  
  60.   
  61.     ANN::feature feature2;  
  62.     feature2.push_back(-26.32);  
  63.     feature2.push_back(-255.95);  
  64.     std::cout << "the correct result label is -1, " << "the real result label is: " << perceptron.predict(feature2) << std::endl;  
  65.   
  66.     return 0;  
  67. }  

运行结果如下图:


GitHub:https://github.com/fengbingchun/NN

0 0