分类算法(3) ---- 感知机(PLA)
来源:互联网 发布:android权威指南 源码 编辑:程序博客网 时间:2024/05/17 03:53
感知机是二分类模型,输入实例的特征向量,输出实例的±类别。
梯度下降法:首先,任意选定w0、b0,然后用梯度下降法不断极小化目标函数,极小化的过程不是一次性把M中的所有误分类点梯度下降,而是一次选取一个误分类点使其梯度下降。
PLA算法流程
设置迭代次数,每次迭代,从第一个点开始,每遇到一个误分类点,就更新w
得到最后的w,与测试文本进行矩阵相乘进行预测,大于0为1,小于等于0为0.
优化:进行MAX-MIN归一化;选取较合适的步长和迭代次数。
数据集格式为:第一行为58个属性名称,最后一个是分享值(0或1),接下来n行是训练数据集,再继续来m行是预测数据集,其share值为空。需要求出预测数据集的share预测值。
<span style="font-size:14px;">#include<bits/stdc++.h>#define rate 0.5 #define line_train 27751#define line_test 11893#define item_num 58#define inf 1000000using namespace std;int line = line_train + line_test;struct TEXT{ //用于保存所有文本 vector <double>feature; int shares;}text[40000];void input(){ //文本读取string str;double data;ifstream fin("Datac_all.csv");getline(fin, str);for(int i=0;i<line;i++){ getline(fin, str);istringstream ss(str);vector<double> tmp;while (!ss.eof()){ss >> data; //忽略逗号 ss.ignore(str.size(), ',');text[i].feature.push_back(data);}if(i<line_train){if(data==0) data=-1; //将0变为-1 text[i].shares = data;}}fin.close();cout << "Read done.\n";}void normalization(){ //MAX-MIN归一化double MAX, MIN;for(int i = 0; i < item_num; i++){MAX = 0; MIN = inf;//求出Max-Minfor(int j=0;j<line;j++){if(text[j].feature[i] > MAX)MAX = text[j].feature[i];if(text[j].feature[i] < MIN)MIN = text[j].feature[i];} //归一化 for(int j = 0; j < line; j++){if((MAX - MIN)!= 0)text[j].feature[i] = (text[j].feature[i] - MIN) / (MAX - MIN);elsetext[j].feature[i] = 0;}} }double XW(vector<double> w, int i){ //矩阵相乘 double A = w[0]*1.0;for(int j = 1; j <= item_num; j++)A += (double)w[j]*text[i].feature[j-1];return A;}vector<double> new_w(vector<double> w, int index){//更新w w[0] += rate*text[index].shares;for(int i = 1; i <= item_num; i++){w[i] += text[index].shares * text[index].feature[i-1];}return w;}void print(vector<double> w){ //打印w for(int i = 0; i <= item_num; i++)cout << w[i] << " ";cout << endl << endl; }void PLA(){int DDL =100;vector<double> w(item_num+1, 1);while(DDL--){ //迭代次数 for(int i = 0; i < line_train; i++){double y = XW(w,i);if(y*text[i].shares <= 0)w = new_w(w, i); //更新误分类点 }}print(w);ofstream predict("PLA.txt");for(int i = line_train; i < line; i++){double temp = XW(w, i); //进行预测 predict << (temp>0) << endl;}}int main(){input();normalization();PLA();return 0;}</span>
PLA分类过程图:
.........
0 0
- 分类算法(3) ---- 感知机(PLA)
- 感知机PLA(perceptron)
- 感知机算法原理(PLA原理)及 Python 实现
- PLA-感知机学习算法
- 统计学习方法笔记二-----感知机算法(PLA)代码实现
- 机器学习算法(分类算法)—Rosenblatt感知机
- 分类算法之感知器学习算法PLA 和口袋算法Pocket Algorithm
- 机器学习总结2_感知机算法(PLA)
- 分类系列之感知器学习算法PLA 和 口袋算法Pocket Algorithm
- 分类系列之感知器学习算法PLA 和 口袋算法Pocket Algorithm
- 分类系列之感知器学习算法PLA 和 口袋算法Pocket Algorithm
- 分类系列之感知器学习算法PLA 和 口袋算法Pocket Algorithm
- 感知器算法(二分类问题)
- PLA分类器学习(转载)
- 机器学习算法(分类算法)—Rosenblatt感知机的对偶解法
- 线性分类模型--感知机(perceptron)
- 基于简单感知器分类算法(matlab实现)
- Python机器学习(1)-- 自己设计一个感知机(Perceptron)分类算法
- QML实现分页显示
- Extjs+SpringMvc 上传文件加进度条
- Qt Creator 设置黑色风格
- 利用xtrabackup和binlog增量恢复时提示表记录不存在案例
- Java像素级的操作
- 分类算法(3) ---- 感知机(PLA)
- 从Excel文件中导入数据到SQL Server 2012
- Mantis 问题管理系统
- CDOJ 1217 The Battle of Chibi【树状数组+dp】
- SpringMVC学习系列-后记 解决GET请求时中文乱码的问题
- 教你学会使用Git和远程代码库
- 数组中有一个数字出现的次数超过数组长度的一半,请找出这个数字。例如输入一个长度为9的数组{1,2,3,2,2,2,5,4,2}。由于数字2在数组中出现了5次,超过数组长度的一半,因此输出2。如果不存在
- oc----巧用storyboard/xib的小技巧,Preview~预览,提高效率
- Windows 7下安装使用Sublime Text 3