A Gentle Introduction to Machine Learning: Solving Binary Classification with Gradient Descent
Source: Internet  Editor: 程序博客网  Time: 2024/05/17 08:42
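Review
In the previous article we used gradient descent to solve a regression problem. How can the same method be applied to a binary classification problem?
Problem
Suppose we have the following dataset.
Dataset
We generate a dataset with MATLAB.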
    % create test data
    x1 = [normrnd(3,1,40,1) normrnd(3,2,40,1)];
    x2 = [normrnd(7,1,40,1) normrnd(6,2,40,1)];
    hold off;
    plot(x1(:,1),x1(:,2),'or');
    hold on;
    plot(x2(:,1),x2(:,2),'ob');
    y1 = zeros(40,1);
    y2 = ones(40,1);
    x = [x1;x2];
    y = [y1;y2];
    hold off;
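For readers without MATLAB, the same kind of dataset can be sketched in Python. This is only an illustrative equivalent, not the author's code: the function name make_dataset and the seed value are assumptions, and the random numbers will differ from the values listed in this article.

```python
import random

def make_dataset(n_per_class=40, seed=0):
    """Two 2-D Gaussian clusters labelled 0 and 1, mirroring the MATLAB snippet."""
    rng = random.Random(seed)  # seed is an arbitrary choice for repeatability
    # Class 0: feature means (3, 3), standard deviations (1, 2)
    class0 = [[rng.gauss(3, 1), rng.gauss(3, 2)] for _ in range(n_per_class)]
    # Class 1: feature means (7, 6), standard deviations (1, 2)
    class1 = [[rng.gauss(7, 1), rng.gauss(6, 2)] for _ in range(n_per_class)]
    x = class0 + class1
    y = [0] * n_per_class + [1] * n_per_class
    return x, y

x, y = make_dataset()
print(len(x), len(y), sum(y))  # 80 samples, 80 labels, 40 of them equal to 1
```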
The data values are (columns: first feature, second feature, label y):
D =
    1.5764    1.8386         0
    1.1573    4.6626         0
    3.1160    3.1710         0
    3.3631    2.4825         0
    3.1122    3.9442         0
    2.3633    1.1272         0
    2.4668    4.7144         0
    1.2313    2.2046         0
    4.0816    3.4433         0
    3.6820    6.8999         0
    3.3439    2.3597         0
    2.8488    5.1085         0
    4.0656    3.4616         0
    2.2035    7.1164         0
    3.8557    4.4835         0
    4.3226    2.2885         0
    6.0615    0.7585         0
    2.4069    4.0333         0
    2.9405    0.0246         0
    3.6190   -0.3831         0
    3.4319    5.0499         0
    4.7547    0.5282         0
    1.0382    2.2332         0
    3.7447    4.8796         0
    3.9460   -0.3395         0
    4.8710   -0.9450         0
    4.3029    3.2495         0
    2.0460    3.8779         0
    2.2227   -0.1015         0
    2.0860    2.4868         0
    3.1852    5.5069         0
    2.5845    1.6710         0
    2.5240    1.0334         0
    0.8195    4.4851         0
    3.5216    3.3754         0
    4.1698    2.6556         0
    4.9099    1.5497         0
    3.7109    6.1300         0
    3.7440    0.5363         0
    2.5802   -0.9823         0
    6.6795    7.8466    1.0000
    6.9358    5.0137    1.0000
    6.2642    8.0141    1.0000
    8.8116    4.0538    1.0000
    7.4373    7.0817    1.0000
    6.0725    4.4535    1.0000
    5.1592    2.6908    1.0000
    6.8279    5.9248    1.0000
    8.1433    5.4880    1.0000
    6.8186    5.3189    1.0000
    8.0123    6.7615    1.0000
    6.9431    6.6514    1.0000
    6.4406    4.6153    1.0000
    6.6724    6.3087    1.0000
    6.9785    3.4098    1.0000
    6.5140    7.9431    1.0000
    7.6444    2.4133    1.0000
    8.2008    3.7157    1.0000
    7.4970    3.8444    1.0000
    5.9602    5.5373    1.0000
    7.4493    6.2110    1.0000
    6.6591    6.8626    1.0000
    6.7736    8.0535    1.0000
    5.1214    6.4982    1.0000
    7.6140    2.6426    1.0000
    7.2352    5.1972    1.0000
    7.9788    7.0529    1.0000
    7.4953   10.8854    1.0000
    6.0268    5.5639    1.0000
    7.3062    6.3531    1.0000
    7.9649    5.8976    1.0000
    6.5187    8.1833    1.0000
    6.7009    7.4581    1.0000
    4.7819    6.3288    1.0000
    5.7821    5.8203    1.0000
    7.6720    7.9396    1.0000
    9.4654    1.2340    1.0000
    6.1818    9.1185    1.0000
    7.9709    8.0687    1.0000
    8.7071    7.0849    1.0000
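To show the data visually, we plot the points in different colors according to their y values.
Plotting the two features together with the label y as a 3D figure gives the following: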
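Model
In the previous article we used a linear function as the regression model. Here we use the following function instead:
$$f(z) = \frac{1}{1 + e^{-z}}$$
Plotting this function with MATLAB:
    >> x = -10:0.1:10;
    >> y = 1 ./ (1 + exp(-x));
    >> plot(x,y)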
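This function is called the logistic function, or sigmoid function. As x tends to negative infinity it approaches 0, and as x tends to positive infinity it approaches 1. Its values lie in the open interval (0, 1), the same range as a probability. Near x = 0 the curve is at its steepest, so the output is most sensitive to changes in x there.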
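The limiting behaviour described above can be checked numerically. The following is a minimal Python sketch; the two-branch form is an implementation choice made here to avoid floating-point overflow for very negative inputs, not something taken from the article.

```python
import math

def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z)), written to avoid overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)           # z < 0, so exp(z) cannot overflow
    return ez / (1.0 + ez)

# The values move from near 0, through 0.5 at z = 0, to near 1.
for z in (-10, -1, 0, 1, 10):
    print(z, sigmoid(z))
```

The function is also symmetric about the point (0, 0.5): sigmoid(z) + sigmoid(-z) equals 1 for every z.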
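In the formula above we set $z = w^T x$, where w is the parameter column vector and x is the sample column vector, so the model becomes $f(x) = \frac{1}{1 + e^{-w^T x}}$.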
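As a small illustration of how the model turns a sample into a probability, here is a Python sketch. It assumes, as the MATLAB code later in the article does, that each sample is augmented with a constant 1 so that the first weight acts as a bias. The function name predict_proba and the weight values are hypothetical, chosen only for this example.

```python
import math

def predict_proba(w, x):
    """Return f(w^T [1, x]) for one sample x given as a list of features."""
    x_aug = [1.0] + list(x)    # prepend a constant 1 for the bias weight w[0]
    z = sum(wi * xi for wi, xi in zip(w, x_aug))
    if z >= 0:                 # two branches keep exp() from overflowing
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

w = [-10.0, 1.5, 0.5]                  # hypothetical weights, for illustration only
print(predict_proba(w, [3.0, 3.0]))    # a point at the class-0 cluster centre
print(predict_proba(w, [7.0, 6.0]))    # a point at the class-1 cluster centre
```

With these weights the first point gets a probability below 0.5 and the second a probability above 0.5.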
Some readers may ask: [...]
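Parameter estimation
Given the dataset above, we estimate the parameter $w$ by maximum likelihood.
Maximum likelihood estimation means: given a set of samples, find the parameter value under which the probability of observing those samples is largest.
We use $f(x)$ to denote the probability that a sample belongs to class 1:
$$P(y = 1 \mid x; w) = f(x), \qquad P(y = 0 \mid x; w) = 1 - f(x)$$
Combining the two cases, we let
$$P(y \mid x; w) = f(x)^{y} \, (1 - f(x))^{1 - y}$$
The likelihood function over the $n$ samples is
$$L(w) = \prod_{i=1}^{n} f(x_i)^{y_i} \, (1 - f(x_i))^{1 - y_i}$$
The larger this probability the better. Taking the logarithm (which turns the product into a sum without changing where the maximum lies) and negating it gives the loss function:
$$J(w) = -\ln L(w) = -\sum_{i=1}^{n} \left[ y_i \ln f(x_i) + (1 - y_i) \ln (1 - f(x_i)) \right]$$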
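To make the loss concrete, the sketch below evaluates the negative log-likelihood in Python on two hand-picked points. The function names, the clamping constant 1e-12 and the weight values are assumptions made for this example. With all weights at zero every prediction is 0.5, so the loss is ln 2 per sample.

```python
import math

def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def neg_log_likelihood(w, xs, ys):
    """Sum over samples of -[y*ln(f) + (1-y)*ln(1-f)], with the bias stored in w[0]."""
    total = 0.0
    for x, y in zip(xs, ys):
        z = w[0] + sum(wi * xi for wi, xi in zip(w[1:], x))
        f = sigmoid(z)
        f = min(max(f, 1e-12), 1.0 - 1e-12)   # keep log() away from 0
        total -= y * math.log(f) + (1 - y) * math.log(1.0 - f)
    return total

xs = [[3.0, 3.0], [7.0, 6.0]]   # one point per class, taken from the cluster centres
ys = [0, 1]
loss_zero = neg_log_likelihood([0.0, 0.0, 0.0], xs, ys)
loss_good = neg_log_likelihood([-10.0, 1.5, 0.5], xs, ys)   # hypothetical weights
print(loss_zero, loss_good)
```

Weights that separate the two points give a much smaller loss than the all-zero weights, which is what minimizing the loss aims for.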
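To find the w that minimizes the loss function, we again use gradient descent. Concretely, we first take the derivative of the loss function with respect to $w$.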
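Derivative
To differentiate the loss, we first differentiate $f(z)$ and $z(x)$:
$$f'(z) = \frac{e^{-z}}{(1 + e^{-z})^2} = f(z)\,(1 - f(z)), \qquad \frac{\partial z}{\partial w} = x$$
As can be seen, the derivative of $f$ can be written in terms of $f$ itself.
Applying the chain rule to the negative log-likelihood loss $J(w)$ gives
$$\frac{\partial J}{\partial w} = -\sum_{i=1}^{n} \left[ \frac{y_i}{f(x_i)} - \frac{1 - y_i}{1 - f(x_i)} \right] f(x_i)\,(1 - f(x_i))\, x_i = \sum_{i=1}^{n} (f(x_i) - y_i)\, x_i$$
This formula looks very familiar: it closely resembles the one for linear regression. In linear regression $f(x) = w^T x$, whereas here $f(x) = \frac{1}{1 + e^{-w^T x}}$. In both cases the gradient is
(estimate - true value) * input, summed over all samples.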
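One way to gain confidence in the gradient formula is to compare it with a finite-difference approximation of the loss. The following Python sketch does this on four toy points. The data, the weight vector and the step size h are arbitrary choices for illustration.

```python
import math

def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def score(w, x):
    """z = w[0] + w[1]*x[0] + w[2]*x[1] + ... (bias stored in w[0])."""
    return w[0] + sum(wi * xi for wi, xi in zip(w[1:], x))

def loss(w, xs, ys):
    """Negative log-likelihood of the labels under the logistic model."""
    return -sum(y * math.log(sigmoid(score(w, x)))
                + (1 - y) * math.log(1.0 - sigmoid(score(w, x)))
                for x, y in zip(xs, ys))

def gradient(w, xs, ys):
    """Analytic gradient: sum over samples of (f - y) * [1, x]."""
    g = [0.0] * len(w)
    for x, y in zip(xs, ys):
        err = sigmoid(score(w, x)) - y
        for j, xj in enumerate([1.0] + list(x)):
            g[j] += err * xj
    return g

def numeric_gradient(w, xs, ys, h=1e-6):
    """Central-difference approximation of the gradient of loss()."""
    g = []
    for j in range(len(w)):
        up = list(w)
        up[j] += h
        down = list(w)
        down[j] -= h
        g.append((loss(up, xs, ys) - loss(down, xs, ys)) / (2 * h))
    return g

xs = [[3.0, 3.0], [4.0, 2.0], [7.0, 6.0], [6.0, 7.0]]   # toy data, two points per class
ys = [0, 0, 1, 1]
w = [0.1, -0.2, 0.3]                                     # arbitrary test point
analytic = gradient(w, xs, ys)
numeric = numeric_gradient(w, xs, ys)
print(analytic)
print(numeric)
```

The two printed vectors should agree to several decimal places, which supports the (estimate - true value) * input form.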
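The iteration formula is:
$$w \leftarrow w - \alpha \sum_{i=1}^{n} (f(x_i) - y_i)\, x_i$$
where $\alpha$ is the learning rate (the variable rate in the MATLAB code). Here f - y is the difference between the estimate and the true value: the larger this difference, the more w needs to be adjusted, and the adjustment is proportional to it. Likewise, the larger x is, the larger the adjustment.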
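The update rule can be sketched as a short training loop in Python. This is a simplified illustration, not a port of the author's function: it starts from zero weights instead of random ones, runs a fixed number of iterations instead of stopping when the error change falls below a threshold, and uses four toy points instead of the 80-sample dataset.

```python
import math

def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def predict(w, x):
    """Probability of class 1 for sample x, with the bias stored in w[0]."""
    return sigmoid(w[0] + sum(wi * xi for wi, xi in zip(w[1:], x)))

def loss(w, xs, ys):
    """Negative log-likelihood, clamped so log() never sees 0."""
    total = 0.0
    for x, y in zip(xs, ys):
        f = min(max(predict(w, x), 1e-12), 1.0 - 1e-12)
        total -= y * math.log(f) + (1 - y) * math.log(1.0 - f)
    return total

def train(xs, ys, rate=0.01, iterations=10000):
    """Batch gradient descent: w <- w - rate * sum((f - y) * [1, x])."""
    w = [0.0] * (len(xs[0]) + 1)
    for _ in range(iterations):
        grad = [0.0] * len(w)
        for x, y in zip(xs, ys):
            err = predict(w, x) - y
            for j, xj in enumerate([1.0] + list(x)):
                grad[j] += err * xj
        w = [wj - rate * gj for wj, gj in zip(w, grad)]
    return w

xs = [[3.0, 3.0], [4.0, 2.0], [7.0, 6.0], [6.0, 7.0]]   # toy data, two points per class
ys = [0, 0, 1, 1]
w = train(xs, ys)
labels = [1 if predict(w, x) > 0.5 else 0 for x in xs]
print(w, labels, loss(w, xs, ys))
```

On this small, linearly separable example the loop ends with all four points on the correct side of 0.5. Because the gradient is a sum over samples rather than an average, the effective step grows with the dataset size, so a rate that works on four points may be too large for a bigger dataset.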
Python implementation
http://blog.csdn.net/taiji1985/article/details/51250860
MATLAB implementation
    function [w,f,c,accury] = lr_predict(x,y)
        [n,k] = size(x);            % n: number of samples, k: number of features
        % augment x with a column of ones (bias term)
        x = [ones(n,1) x];
        % random initial value for w
        w = rand(1,k+1);            % row vector, convenient for computing w*x'
        olde = 0;
        e = 1;
        eps = 0.0001;
        rate = 0.01;
        i = 0;                      % iteration counter
        while true
            z = w*x';
            f = 1./(1+exp(-z));
            e = sum(abs(f-y'))/n;   % mean absolute error
            w = w - rate*(f-y')*x;  % update the weights
            d = abs(olde - e);      % change in error between two iterations
            fprintf('%d iter e = %f , d = %f \n',i,e,d);
            if d < eps
                break;
            end
            olde = e;
            i = i+1;
        end
        c = f>0.5;
        accury = (n - sum(abs(c-y')))/n;   % accuracy
        fprintf('accury is %f ',accury);
    % create test data
    seed = 333;
    rand('seed',seed)               % fix the random seed (legacy syntax)
    x1 = [normrnd(3,1,40,1) normrnd(3,2,40,1)];
    x2 = [normrnd(7,1,40,1) normrnd(6,2,40,1)];
    fig_on = 1;
    if fig_on
        hold off;
        plot(x1(:,1),x1(:,2),'or');
        hold on;
        plot(x2(:,1),x2(:,2),'ob');
    end
    y1 = zeros(40,1);
    y2 = ones(40,1);
    x = [x1;x2];
    y = [y1;y2];
    if fig_on
        hold off;
        figure(2);
        plot3(x1(:,1),x1(:,2),y1,'or');
        hold on;
        plot3(x2(:,1),x2(:,2),y2,'ob');
    end
    %surf(x(:,1),x(:,2),y);
    % run the classifier
    [w,f,c,a] = lr_predict(x,y)
    if fig_on
        hold off;
        figure(3);
        plot(x1(:,1),x1(:,2),'or');
        hold on;
        plot(x2(:,1),x2(:,2),'ob');
        xe = x(find(c ~= y'),:)     % misclassified samples
        plot(xe(:,1),xe(:,2),'sm','MarkerSize',10,'LineWidth',2);
        %mm = min(x);
        %mx = max(x);
        %xx = mm(1):0.1:mx(1);
        %yy = -(w(1)+w(2)*xx)/w(3);   % decision boundary: w(1)+w(2)*x+w(3)*y = 0
        %plot(xx,yy);
    end
Experimental results
Misclassified points are marked in purple.
The MATLAB output is:
    0 iter e = 0.476728 , d = 0.476728
    1 iter e = 0.527496 , d = 0.050768
    2 iter e = 0.499542 , d = 0.027954
    3 iter e = 0.485917 , d = 0.013625
    4 iter e = 0.543169 , d = 0.057251
    5 iter e = 0.499576 , d = 0.043593
    6 iter e = 0.488516 , d = 0.011060
    7 iter e = 0.553108 , d = 0.064592
    8 iter e = 0.499505 , d = 0.053603
    9 iter e = 0.487229 , d = 0.012276
    10 iter e = 0.544575 , d = 0.057346
    11 iter e = 0.499370 , d = 0.045205
    12 iter e = 0.484179 , d = 0.015191
    13 iter e = 0.532477 , d = 0.048298
    14 iter e = 0.499217 , d = 0.033260
    15 iter e = 0.481059 , d = 0.018158
    16 iter e = 0.520977 , d = 0.039918
    17 iter e = 0.498982 , d = 0.021995
    18 iter e = 0.476170 , d = 0.022812
    19 iter e = 0.510895 , d = 0.034724
    20 iter e = 0.498714 , d = 0.012181
    21 iter e = 0.471163 , d = 0.027551
    22 iter e = 0.499601 , d = 0.028438
    23 iter e = 0.498290 , d = 0.001311
    24 iter e = 0.463343 , d = 0.034947
    25 iter e = 0.492636 , d = 0.029293
    26 iter e = 0.497855 , d = 0.005219
    27 iter e = 0.456635 , d = 0.041220
    28 iter e = 0.479775 , d = 0.023140
    29 iter e = 0.497030 , d = 0.017256
    30 iter e = 0.443815 , d = 0.053215
    31 iter e = 0.478942 , d = 0.035127
    32 iter e = 0.496471 , d = 0.017529
    33 iter e = 0.438143 , d = 0.058328
    34 iter e = 0.456484 , d = 0.018341
    35 iter e = 0.494499 , d = 0.038015
    36 iter e = 0.412978 , d = 0.081521
    37 iter e = 0.473886 , d = 0.060908
    38 iter e = 0.494617 , d = 0.020731
    39 iter e = 0.420671 , d = 0.073946
    40 iter e = 0.412639 , d = 0.008031
    41 iter e = 0.487323 , d = 0.074684
    42 iter e = 0.345426 , d = 0.141897
    43 iter e = 0.482066 , d = 0.136640
    44 iter e = 0.492669 , d = 0.010603
    45 iter e = 0.409059 , d = 0.083610
    46 iter e = 0.329026 , d = 0.080033
    47 iter e = 0.459020 , d = 0.129994
    48 iter e = 0.160198 , d = 0.298822
    49 iter e = 0.323109 , d = 0.162911
    50 iter e = 0.453148 , d = 0.130038
    51 iter e = 0.148804 , d = 0.304344
    52 iter e = 0.277777 , d = 0.128973
    53 iter e = 0.423044 , d = 0.145267
    54 iter e = 0.133613 , d = 0.289431
    55 iter e = 0.173894 , d = 0.040281
    56 iter e = 0.347993 , d = 0.174099
    57 iter e = 0.457374 , d = 0.109381
    58 iter e = 0.199874 , d = 0.257501
    59 iter e = 0.378111 , d = 0.178237
    60 iter e = 0.464773 , d = 0.086662
    61 iter e = 0.258071 , d = 0.206702
    62 iter e = 0.402756 , d = 0.144686
    63 iter e = 0.468620 , d = 0.065864
    64 iter e = 0.292597 , d = 0.176024
    65 iter e = 0.368138 , d = 0.075541
    66 iter e = 0.453801 , d = 0.085663
    67 iter e = 0.229217 , d = 0.224584
    68 iter e = 0.336790 , d = 0.107573
    69 iter e = 0.436250 , d = 0.099459
    70 iter e = 0.163286 , d = 0.272964
    71 iter e = 0.232334 , d = 0.069048
    72 iter e = 0.356480 , d = 0.124146
    73 iter e = 0.150146 , d = 0.206334
    74 iter e = 0.236759 , d = 0.086612
    75 iter e = 0.299827 , d = 0.063068
    76 iter e = 0.407203 , d = 0.107377
    77 iter e = 0.099765 , d = 0.307438
    78 iter e = 0.109024 , d = 0.009258
    79 iter e = 0.137674 , d = 0.028651
    80 iter e = 0.166169 , d = 0.028494
    81 iter e = 0.263474 , d = 0.097305
    82 iter e = 0.267471 , d = 0.003997
    83 iter e = 0.377094 , d = 0.109624
    84 iter e = 0.082407 , d = 0.294687
    85 iter e = 0.082223 , d = 0.000184
    86 iter e = 0.082569 , d = 0.000347
    87 iter e = 0.082813 , d = 0.000244
    88 iter e = 0.083618 , d = 0.000805
    89 iter e = 0.084850 , d = 0.001232
    90 iter e = 0.086432 , d = 0.001582
    91 iter e = 0.090219 , d = 0.003787
    92 iter e = 0.093099 , d = 0.002880
    93 iter e = 0.103860 , d = 0.010761
    94 iter e = 0.109086 , d = 0.005226
    95 iter e = 0.139819 , d = 0.030733
    96 iter e = 0.151821 , d = 0.012001
    97 iter e = 0.231801 , d = 0.079980
    98 iter e = 0.230153 , d = 0.001648
    99 iter e = 0.335810 , d = 0.105658
    100 iter e = 0.108298 , d = 0.227512
    101 iter e = 0.136344 , d = 0.028046
    102 iter e = 0.135687 , d = 0.000657
    103 iter e = 0.196323 , d = 0.060637
    104 iter e = 0.191395 , d = 0.004928
    105 iter e = 0.288650 , d = 0.097254
    106 iter e = 0.168363 , d = 0.120287
    107 iter e = 0.253856 , d = 0.085494
    108 iter e = 0.190190 , d = 0.063667
    109 iter e = 0.283906 , d = 0.093716
    110 iter e = 0.158495 , d = 0.125411
    111 iter e = 0.234220 , d = 0.075725
    112 iter e = 0.178957 , d = 0.055263
    113 iter e = 0.265306 , d = 0.086349
    114 iter e = 0.161863 , d = 0.103443
    115 iter e = 0.236991 , d = 0.075128
    116 iter e = 0.166209 , d = 0.070782
    117 iter e = 0.242725 , d = 0.076516
    118 iter e = 0.159634 , d = 0.083091
    119 iter e = 0.230343 , d = 0.070709
    120 iter e = 0.155406 , d = 0.074936
    121 iter e = 0.221664 , d = 0.066258
    122 iter e = 0.149961 , d = 0.071703
    123 iter e = 0.210649 , d = 0.060688
    124 iter e = 0.143460 , d = 0.067189
    125 iter e = 0.197445 , d = 0.053984
    126 iter e = 0.135411 , d = 0.062034
    127 iter e = 0.181018 , d = 0.045607
    128 iter e = 0.125077 , d = 0.055941
    129 iter e = 0.159931 , d = 0.034854
    130 iter e = 0.111675 , d = 0.048256
    131 iter e = 0.133207 , d = 0.021532
    132 iter e = 0.095523 , d = 0.037684
    133 iter e = 0.103705 , d = 0.008182
    134 iter e = 0.080177 , d = 0.023529
    135 iter e = 0.080751 , d = 0.000574
    136 iter e = 0.070453 , d = 0.010298
    137 iter e = 0.069785 , d = 0.000668
    138 iter e = 0.066274 , d = 0.003512
    139 iter e = 0.066042 , d = 0.000232
    140 iter e = 0.064806 , d = 0.001236
    141 iter e = 0.064808 , d = 0.000002
    accury is 0.975000
    w =
      -10.0562    1.5878    0.4737
    f =
      Columns 1 through 8
        0.0014    0.0028    0.0314    0.0332    0.0451    0.0035    0.0238    0.0009
      Columns 9 through 16
        0.1501    0.3391    0.0304    0.0520    0.1480    0.0497    0.1703    0.1283
      Columns 17 through 24
        0.5347    0.0155    0.0051    0.0124    0.1204    0.1101    0.0007    0.1727
      Columns 25 through 32
        0.0213    0.0672    0.1867    0.0081    0.0015    0.0043    0.1033    0.0065
      Columns 33 through 40
        0.0043    0.0015    0.0646    0.1212    0.2088    0.2685    0.0237    0.0017
      Columns 41 through 48
        0.9905    0.9750    0.9831    0.9980    0.9959    0.8795    0.4102    0.9809
      Columns 49 through 56
        0.9971    0.9740    0.9981    0.9888    0.9348    0.9797    0.9496    0.9882
      Columns 57 through 64
        0.9712    0.9937    0.9818    0.9124    0.9939    0.9842    0.9926    0.8119
      Columns 65 through 72
        0.9730    0.9858    0.9983    0.9994    0.9216    0.9928    0.9969    0.9896
      Columns 73 through 80
        0.9889    0.6962    0.8998    0.9982    0.9972    0.9887    0.9989    0.9995
    c =
      Columns 1 through 14
         0     0     0     0     0     0     0     0     0     0     0     0     0     0
      Columns 15 through 28
         0     0     1     0     0     0     0     0     0     0     0     0     0     0
      Columns 29 through 42
         0     0     0     0     0     0     0     0     0     0     0     0     1     1
      Columns 43 through 56
         1     1     1     1     0     1     1     1     1     1     1     1     1     1
      Columns 57 through 70
         1     1     1     1     1     1     1     1     1     1     1     1     1     1
      Columns 71 through 80
         1     1     1     1     1     1     1     1     1     1
    a =
        0.9750
    xe =
        6.0615    0.7585
        5.1592    2.6908
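The commented-out lines at the end of the MATLAB script suggest drawing the decision boundary, the line on which the model outputs exactly 0.5. A minimal Python sketch of that computation follows, using the weights printed above. The function name boundary_x2 is hypothetical, and the sketch assumes the last weight is non-zero.

```python
import math

def boundary_x2(w, x1):
    """Second feature on the line w[0] + w[1]*x1 + w[2]*x2 = 0, where f = 0.5."""
    return -(w[0] + w[1] * x1) / w[2]

w = [-10.0562, 1.5878, 0.4737]   # weights printed in the results above

for x1 in (5.0, 6.0, 7.0):
    x2 = boundary_x2(w, x1)
    z = w[0] + w[1] * x1 + w[2] * x2          # should be 0 up to rounding
    f = 1.0 / (1.0 + math.exp(-z))            # so the model output is 0.5
    print(x1, x2, f)
```

Points on one side of this line are assigned to class 1 and points on the other side to class 0.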