A Plain-Language Introduction to Machine Learning: Solving Multi-Class Classification with Softmax

Source: Internet. Editor: 程序博客网. Time: 2024/04/30 11:17

Problem

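A 0-1 classification problem decides whether a sample is assigned to A or to B, so there are only two classes in total. A multi-class classification problem involves more than two classes. The MNIST dataset [1, 2] contains 60,000 training images of handwritten digits and 10,000 test images. Each image is 28×28 pixels and contains one handwritten digit, as shown here:
[Figure: sample handwritten digits from MNIST]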

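We want to build the following classifier: given an image of a handwritten digit, it outputs which class the digit belongs to (ten classes in total, 0 to 9).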

Model

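We are given a dataset $D = \{(x^{(i)}, y^{(i)})\}$, $i \in [1, n]$, where $y^{(i)}$ is an integer from 1 to $k$ giving the class the sample belongs to, and $k$ is the number of classes.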

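The softmax function:

$$\delta^{(i)}_j = p(y^{(i)} = j \mid x^{(i)}; w) = \frac{e^{w_j^T x^{(i)}}}{\sum_{l=1}^{k} e^{w_l^T x^{(i)}}}$$

where $w_j$ is the parameter vector of class $j$.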

Here $x^{(i)}$ is again written as an augmented vector: $[1, x_1, x_2, x_3, x_4, \ldots, x_d]$, where $d$ is the number of features.

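Given a sample $x^{(i)}$, this formula computes the probability $\delta^{(i)}_j$ that the sample belongs to class $j$.
For each sample we compute the $k$ probabilities, one per class, and take the class with the largest probability as the final classification.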
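
The two steps above (compute the k class probabilities, then pick the largest) could be sketched in MATLAB roughly as follows. The weight matrix `w` and the sample `x` are made-up toy values, not trained parameters; each column of `w` is assumed to hold the parameters of one class.

```matlab
% Toy example: 2 features, k = 3 classes. All numbers are illustrative only.
x = [1; 0.5; -1.2];            % augmented sample: [1; x1; x2]
w = [ 0.1  0.0 -0.1;           % row 1: bias terms, one per class
      1.0 -0.5  0.2;           % row 2: weights for x1
     -0.3  0.8  0.4];          % row 3: weights for x2

z = w' * x;                    % k x 1 vector of scores w_j' * x
delta = exp(z) / sum(exp(z));  % softmax: k probabilities that sum to 1
[pmax, j] = max(delta);        % j is the predicted class (1-based)

fprintf('probabilities: %s\n', mat2str(delta', 4));
fprintf('predicted class: %d\n', j);
```

With these toy numbers the first class has the largest score, so `j` is 1.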

Let $I\{\cdot\}$ be the indicator function, that is,

$$I\{\text{true}\} = 1, \quad I\{\text{false}\} = 0$$

$$q^{(i)}_j = I\{y^{(i)} = j\}$$

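For example, suppose sample $i$ belongs to class 3 and there are 5 classes. Then

$$q^{(i)} = [0, 0, 1, 0, 0]$$

We would like $\delta^{(i)}$ to look like $[0.001, 0.002, 0.994, 0.002, 0.001]$, with the difference between the two vectors as small as possible.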

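Applying maximum log-likelihood estimation gives the loss function:

$$J(w) = -\frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{k} q^{(i)}_j \log \delta^{(i)}_j$$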
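
As a small worked example of this loss, the snippet below evaluates the term for a single sample, using the five-class vectors from the example above. The second, poor prediction `delta_bad` is a made-up comparison and does not come from the original text.

```matlab
q     = [0 0 1 0 0];                       % indicator vector: the sample is in class 3
delta = [0.001 0.002 0.994 0.002 0.001];   % predicted probabilities from the example

loss_good = -sum(q .* log(delta));         % equals -log(0.994), about 0.006

% A hypothetical poor prediction for comparison (made-up numbers)
delta_bad = [0.3 0.3 0.1 0.2 0.1];
loss_bad  = -sum(q .* log(delta_bad));     % equals -log(0.1), about 2.3

fprintf('good prediction: %.4f, poor prediction: %.4f\n', loss_good, loss_bad);
```

Only the probability assigned to the true class enters the sum, so the loss shrinks toward 0 as that probability approaches 1.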

Minimizing the error

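As before, we use gradient descent to minimize the error function $J(w)$, so we take its partial derivative with respect to $w$. The result is:

$$\frac{\partial J(w)}{\partial w_j} = \frac{1}{n} \sum_{i=1}^{n} \left[ x^{(i)} \left( \delta^{(i)}_j - q^{(i)}_j \right) \right]$$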
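
The original states this result without the intermediate steps. One way to arrive at it, sketched here under the definitions used in this article (the last step relies on the indicator entries of a sample summing to 1), is:

```latex
% Derivation sketch; in the first two lines the superscript (i) is dropped for brevity.
\begin{aligned}
\log \delta_l &= w_l^T x - \log \sum_{t=1}^{k} e^{w_t^T x} \\
\frac{\partial \log \delta_l}{\partial w_j} &= \left( I\{l = j\} - \delta_j \right) x \\
\frac{\partial J(w)}{\partial w_j}
  &= -\frac{1}{n} \sum_{i=1}^{n} \sum_{l=1}^{k} q^{(i)}_l \left( I\{l = j\} - \delta^{(i)}_j \right) x^{(i)} \\
  &= -\frac{1}{n} \sum_{i=1}^{n} \left( q^{(i)}_j - \delta^{(i)}_j \sum_{l=1}^{k} q^{(i)}_l \right) x^{(i)} \\
  &= \frac{1}{n} \sum_{i=1}^{n} x^{(i)} \left( \delta^{(i)}_j - q^{(i)}_j \right)
\end{aligned}
```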

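Surprisingly, it still has the form

$$\text{gradient} = \text{input} \times (\text{predicted value} - \text{true value})$$

Here the true value is a vector of the form [ 0 0 1 0 0 ].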
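
A quick way to gain confidence in this gradient is to compare it with finite differences of the loss. The following is only a sketch on made-up toy data; the names `prob`, `J`, `g_analytic` and `g_numeric` are chosen for this illustration and do not appear in the original code.

```matlab
% Toy data: 2 features, n = 4 samples, k = 3 classes (all values made up)
x = [0.2 0.9 0.4 0.7;
     0.8 0.1 0.5 0.3];
y = [1 2 3 1];                       % class labels in 1..k
[dfeat, n] = size(x);
k = 3;
X = [ones(1, n); x];                 % augmented samples, (dfeat+1) x n
q = zeros(k, n);
q(sub2ind([k n], y, 1:n)) = 1;       % indicator matrix, one column per sample

w = 0.1 * reshape(1:(dfeat+1)*k, dfeat+1, k);   % arbitrary non-zero weights

% Helper handles: class probabilities and loss J(w)
prob = @(w) exp(w' * X) ./ repmat(sum(exp(w' * X), 1), k, 1);
J    = @(w) -1/n * sum(sum(q .* log(prob(w))));

% Analytic gradient: column j is (1/n) * sum_i x(i) * (delta_j(i) - q_j(i))
g_analytic = 1/n * X * (prob(w) - q)';

% Central finite differences, one entry of w at a time
g_numeric = zeros(size(w));
h = 1e-6;
for idx = 1:numel(w)
    e = zeros(size(w));
    e(idx) = h;
    g_numeric(idx) = (J(w + e) - J(w - e)) / (2 * h);
end

max_diff = max(abs(g_analytic(:) - g_numeric(:)));
fprintf('largest difference between the two gradients: %g\n', max_diff);
```

If the formula is implemented correctly, the two gradients should agree up to the accuracy of the finite-difference approximation.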

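The iteration formula is

$$w_j = w_j - \alpha \cdot \frac{\partial J(w)}{\partial w_j}$$

where $\alpha$ is the learning rate.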
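
To illustrate the update rule on its own, here is a minimal sketch that applies it to a tiny made-up dataset. The data, the step count of 200 and the learning rate `alpha = 0.5` are assumptions for this example only.

```matlab
% Toy data: 2 features, 6 samples, 3 classes (values made up for illustration)
x = [0.1 0.2 0.9 0.8 0.5 0.4;
     0.2 0.1 0.8 0.9 0.1 0.9];
y = [1 1 2 2 3 3];                   % labels in 1..k
n = size(x, 2);
k = 3;
X = [ones(1, n); x];                 % augmented samples
q = zeros(k, n);
q(sub2ind([k n], y, 1:n)) = 1;       % indicator matrix

alpha = 0.5;                         % learning rate (assumed value)
w = zeros(size(X, 1), k);            % one column w_j per class
cost = zeros(1, 200);
for t = 1:200
    s = exp(w' * X);
    delta = s ./ repmat(sum(s, 1), k, 1);      % k x n class probabilities
    cost(t) = -1/n * sum(sum(q .* log(delta)));
    grad = 1/n * X * (delta - q)';             % column j is dJ/dw_j
    w = w - alpha * grad;                      % the update rule, all classes at once
end
fprintf('cost: %.4f at step 1, %.4f at step 200\n', cost(1), cost(200));
```

Because `w` starts at zero, every class initially gets probability 1/3, so the first recorded cost is log(3); it then decreases as the weights are updated.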

MATLAB implementation

Downloading the dataset

First obtain the MNIST dataset. Download address:
http://yann.lecun.com/exdb/mnist/
Unpacking the downloads yields four files.

Reading the dataset

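Format of the train-labels-idx1-ubyte file:
First four bytes: the fixed number 0x0801 (2049), stored big-endian
Next four bytes: the number of samples
Each following byte is the label (0-9) of one sample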

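Format of the train-images-idx3-ubyte file:
int: fixed number (2051, i.e. 0x0803)
int: number of samples (i.e. number of images)
int: number of image rows
int: number of image columns
The pixels of each image are then stored one image after another; each image takes rows × columns bytes.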
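
To make the layout concrete, the sketch below writes a tiny fake file in the image-file format and reads the header back field by field. The temporary file name, the image size of 2×3 and the pixel values are made up for this illustration; only the order of the fields and the big-endian integers follow the description above.

```matlab
% Write a tiny fake file in the image-file layout: 2 images of 2 x 3 pixels.
fname = [tempname '.idx3-ubyte'];
fp = fopen(fname, 'wb');
fwrite(fp, [2051 2 2 3], 'int32', 0, 'ieee-be');   % magic, count, rows, cols
fwrite(fp, uint8(1:12), 'uint8');                  % 2 * 2 * 3 pixel bytes
fclose(fp);

% Read it back: four big-endian 32-bit integers, then one byte per pixel.
fp = fopen(fname, 'rb');
header = fread(fp, 4, 'int32', 0, 'ieee-be');
pixels = fread(fp, inf, 'uint8');
fclose(fp);
delete(fname);

fprintf('magic = %d, images = %d, rows = %d, cols = %d\n', header);
fprintf('pixel bytes read: %d\n', numel(pixels));
```

The number of pixel bytes equals images × rows × columns, which is a useful sanity check when reading the real files.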

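The test set files have a similar format. We do not need to write the reader ourselves, however; there is a ready-made MATLAB function we can use.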

function images = loadMNISTImages(filename)
%loadMNISTImages returns a 28x28x[number of MNIST images] matrix containing
%the raw MNIST images

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);

magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);

numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');

images = fread(fp, inf, 'unsigned char');
images = reshape(images, numCols, numRows, numImages);
images = permute(images,[2 1 3]);

fclose(fp);

% Reshape to #pixels x #examples
images = reshape(images, size(images, 1) * size(images, 2), size(images, 3));
% Convert to double and rescale to [0,1]
images = double(images) / 255;

end

For the training set, this function returns a 784 × 60000 matrix, one column per image.
The function for reading the labels:

function labels = loadMNISTLabels(filename)
    %loadMNISTLabels returns a [number of MNIST images]x1 matrix containing
    %the labels for the MNIST images

    fp = fopen(filename, 'rb');
    assert(fp ~= -1, ['Could not open ', filename, '']);

    magic = fread(fp, 1, 'int32', 0, 'ieee-be');
    assert(magic == 2049, ['Bad magic number in ', filename, '']);

    numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');

    labels = fread(fp, inf, 'unsigned char');

    assert(size(labels,1) == numLabels, 'Mismatch in label count');

    fclose(fp);
end

This function reads the label file and returns a 60000×1 array. Its values are 0-9, the class index of the corresponding image.

Classifier

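% written by yangtf
% m classes in total
% Note: in this code k is the number of features, m the number of classes
% and n the number of samples.
function [w,cost_a] = softmax_train(x,y,m)
    [k,n] = size(x); % each column is one sample
    x = [ones(1,n); x]; % augment: (k+1) x n
    w = zeros(k+1,m);  % each column is one set of parameters; one set per class, m sets in total
    cost_a = [];
    % build the indicator matrix (labels are 0-based, rows are 1-based)
    q = zeros(m,n);
    for i = 1:n
       q(y(i)+1,i) = 1;
    end
    rate = 0.5;      % learning rate
    oldcost = 1;
    eps = 0.0001;    % stopping threshold on the change in cost (shadows the built-in eps)
    lambda = 0.001;  % weight-decay coefficient
    i = 1 ; % iteration counter
    while true
       s = exp( w'*x ); % s: m x (k+1) * (k+1) x n -> m x n, classes by samples
       ss = sum(s) ; % 1 x n
       d = s ./ repmat(ss,m,1); % m x n: probability of each class for each sample
       % gradient, (k+1) x m; the lambda*w weight-decay term is not part of
       % the formula in the text, and the cost below does not include it
       td = 1/n * x *( d - q )' + lambda*w  ;
       w =w - rate*td;
       % compute the error (using d from before this update)
       cost = - 1/n * sum(sum(q.*log(d)));
       cost_a = [cost_a cost];
       [mm,idx] = max(d);
       acc =  sum(abs(idx==(y'+1)))/n;
       % cost_cha is the absolute change in cost since the previous iteration
       fprintf('iter %d ,cost = %f cost_cha = %d ,acc = %f \n',i,cost,abs(cost-oldcost),acc);
       if abs(cost - oldcost) < eps
           % report the accuracy
           fprintf('accury is %f\n',acc);
           break;
       end
       i= i+1;
       oldcost =cost;
    end
end

labels = loadMNISTLabels('mnist/train-labels.idx1-ubyte');
images = loadMNISTImages('mnist/train-images.idx3-ubyte');
w = softmax_train(images,labels,10);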
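
Once `w` has been trained, classifying new images needs only the forward pass. The sketch below is not part of the original code: the variable name `scores` and the toy inputs are assumptions, and it subtracts the column maximum before calling `exp`, a common trick that leaves the softmax output unchanged while avoiding overflow for large scores.

```matlab
% Toy stand-ins for a trained w and a batch of images (made-up values).
% With MNIST, w would be 785 x 10 and x would be 784 x (number of images).
w = [0  0 0.6;                       % row 1: bias terms, one per class
     2 -1 0;                         % row 2: weights of feature 1
    -1  2 0];                        % row 3: weights of feature 2
x = [0.9 0.1 0.1;                    % features x samples
     0.2 0.8 0.1];

n = size(x, 2);
X = [ones(1, n); x];                 % augment with the constant 1
scores = w' * X;                     % classes x samples
scores = bsxfun(@minus, scores, max(scores, [], 1));   % largest score per column becomes 0
s = exp(scores);
d = bsxfun(@rdivide, s, sum(s, 1));  % class probabilities per sample
[pmax, idx] = max(d, [], 1);         % idx is 1-based
labels = idx - 1;                    % back to 0-based labels such as MNIST's 0-9

disp(labels);
```

For real data, `x` could be the matrix returned by `loadMNISTImages` for the test images, and `labels` could then be compared against the output of `loadMNISTLabels` to estimate test accuracy.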

Experimental results

iter 1 ,cost = 2.302585 cost_cha = 1.302585e+00 ,acc = 0.098717
iter 2 ,cost = 1.831772 cost_cha = 4.708134e-01 ,acc = 0.672383
iter 3 ,cost = 1.514248 cost_cha = 3.175232e-01 ,acc = 0.777133
iter 4 ,cost = 1.302185 cost_cha = 2.120639e-01 ,acc = 0.775317
iter 5 ,cost = 1.156026 cost_cha = 1.461581e-01 ,acc = 0.806450
iter 6 ,cost = 1.049749 cost_cha = 1.062773e-01 ,acc = 0.807550
iter 7 ,cost = 0.970520 cost_cha = 7.922921e-02 ,acc = 0.821450
iter 8 ,cost = 0.908223 cost_cha = 6.229697e-02 ,acc = 0.824550
iter 9 ,cost = 0.858680 cost_cha = 4.954277e-02 ,acc = 0.831750
iter 10 ,cost = 0.817220 cost_cha = 4.145980e-02 ,acc = 0.834583
iter 11 ,cost = 0.782800 cost_cha = 3.442006e-02 ,acc = 0.839283
iter 12 ,cost = 0.753116 cost_cha = 2.968448e-02 ,acc = 0.843000
iter 13 ,cost = 0.727696 cost_cha = 2.542003e-02 ,acc = 0.845750
iter 14 ,cost = 0.705440 cost_cha = 2.225566e-02 ,acc = 0.849350
iter 15 ,cost = 0.685910 cost_cha = 1.953011e-02 ,acc = 0.851267
iter 16 ,cost = 0.668557 cost_cha = 1.735299e-02 ,acc = 0.853517
iter 17 ,cost = 0.653036 cost_cha = 1.552123e-02 ,acc = 0.855267
iter 18 ,cost = 0.639041 cost_cha = 1.399530e-02 ,acc = 0.857067
iter 19 ,cost = 0.626342 cost_cha = 1.269905e-02 ,acc = 0.858433
iter 20 ,cost = 0.614752 cost_cha = 1.158978e-02 ,acc = 0.860050
iter 21 ,cost = 0.604122 cost_cha = 1.062996e-02 ,acc = 0.861217
iter 22 ,cost = 0.594329 cost_cha = 9.792945e-03 ,acc = 0.862333
iter 23 ,cost = 0.585271 cost_cha = 9.057506e-03 ,acc = 0.863550
iter 24 ,cost = 0.576864 cost_cha = 8.407358e-03 ,acc = 0.864617
iter 25 ,cost = 0.569035 cost_cha = 7.829365e-03 ,acc = 0.865717
iter 26 ,cost = 0.561722 cost_cha = 7.312973e-03 ,acc = 0.866367
iter 27 ,cost = 0.554872 cost_cha = 6.849501e-03 ,acc = 0.867400
iter 28 ,cost = 0.548440 cost_cha = 6.431784e-03 ,acc = 0.868283
iter 29 ,cost = 0.542387 cost_cha = 6.053840e-03 ,acc = 0.869000
iter 30 ,cost = 0.536676 cost_cha = 5.710650e-03 ,acc = 0.869833
iter 31 ,cost = 0.531278 cost_cha = 5.397969e-03 ,acc = 0.870483
iter 32 ,cost = 0.526166 cost_cha = 5.112191e-03 ,acc = 0.871300
iter 33 ,cost = 0.521316 cost_cha = 4.850233e-03 ,acc = 0.871933
iter 34 ,cost = 0.516706 cost_cha = 4.609447e-03 ,acc = 0.872883
iter 35 ,cost = 0.512319 cost_cha = 4.387548e-03 ,acc = 0.873617
iter 36 ,cost = 0.508136 cost_cha = 4.182556e-03 ,acc = 0.874367
iter 37 ,cost = 0.504143 cost_cha = 3.992745e-03 ,acc = 0.874933
iter 38 ,cost = 0.500327 cost_cha = 3.816609e-03 ,acc = 0.875500
iter 39 ,cost = 0.496674 cost_cha = 3.652826e-03 ,acc = 0.875900
iter 40 ,cost = 0.493174 cost_cha = 3.500232e-03 ,acc = 0.876450
iter 41 ,cost = 0.489816 cost_cha = 3.357800e-03 ,acc = 0.877083
iter 42 ,cost = 0.486591 cost_cha = 3.224619e-03 ,acc = 0.877433
iter 43 ,cost = 0.483491 cost_cha = 3.099879e-03 ,acc = 0.878117
iter 44 ,cost = 0.480508 cost_cha = 2.982859e-03 ,acc = 0.878683
iter 45 ,cost = 0.477635 cost_cha = 2.872914e-03 ,acc = 0.879067
iter 46 ,cost = 0.474866 cost_cha = 2.769467e-03 ,acc = 0.879550
iter 47 ,cost = 0.472194 cost_cha = 2.671998e-03 ,acc = 0.880083
iter 48 ,cost = 0.469614 cost_cha = 2.580041e-03 ,acc = 0.880533
iter 49 ,cost = 0.467121 cost_cha = 2.493172e-03 ,acc = 0.880967
iter 50 ,cost = 0.464710 cost_cha = 2.411011e-03 ,acc = 0.881300
iter 51 ,cost = 0.462377 cost_cha = 2.333212e-03 ,acc = 0.881600
iter 52 ,cost = 0.460117 cost_cha = 2.259460e-03 ,acc = 0.882067
iter 53 ,cost = 0.457928 cost_cha = 2.189470e-03 ,acc = 0.882567
iter 54 ,cost = 0.455805 cost_cha = 2.122982e-03 ,acc = 0.882933
iter 55 ,cost = 0.453745 cost_cha = 2.059756e-03 ,acc = 0.883400
iter 56 ,cost = 0.451745 cost_cha = 1.999576e-03 ,acc = 0.883633
iter 57 ,cost = 0.449803 cost_cha = 1.942241e-03 ,acc = 0.883983
iter 58 ,cost = 0.447916 cost_cha = 1.887569e-03 ,acc = 0.884383
iter 59 ,cost = 0.446080 cost_cha = 1.835391e-03 ,acc = 0.884617
iter 60 ,cost = 0.444295 cost_cha = 1.785551e-03 ,acc = 0.884900
iter 61 ,cost = 0.442557 cost_cha = 1.737908e-03 ,acc = 0.885317
iter 62 ,cost = 0.440864 cost_cha = 1.692329e-03 ,acc = 0.885617
iter 63 ,cost = 0.439216 cost_cha = 1.648692e-03 ,acc = 0.885883
iter 64 ,cost = 0.437609 cost_cha = 1.606884e-03 ,acc = 0.886183
iter 65 ,cost = 0.436042 cost_cha = 1.566800e-03 ,acc = 0.886500
iter 66 ,cost = 0.434514 cost_cha = 1.528343e-03 ,acc = 0.886717
iter 67 ,cost = 0.433022 cost_cha = 1.491422e-03 ,acc = 0.887267
iter 68 ,cost = 0.431566 cost_cha = 1.455954e-03 ,acc = 0.887467
iter 69 ,cost = 0.430144 cost_cha = 1.421861e-03 ,acc = 0.887733
iter 70 ,cost = 0.428755 cost_cha = 1.389070e-03 ,acc = 0.888067
iter 71 ,cost = 0.427398 cost_cha = 1.357511e-03 ,acc = 0.888283
iter 72 ,cost = 0.426071 cost_cha = 1.327123e-03 ,acc = 0.888500
iter 73 ,cost = 0.424773 cost_cha = 1.297845e-03 ,acc = 0.888833
iter 74 ,cost = 0.423503 cost_cha = 1.269622e-03 ,acc = 0.888950
iter 75 ,cost = 0.422261 cost_cha = 1.242401e-03 ,acc = 0.889200
iter 76 ,cost = 0.421045 cost_cha = 1.216134e-03 ,acc = 0.889417
iter 77 ,cost = 0.419854 cost_cha = 1.190775e-03 ,acc = 0.889683
iter 78 ,cost = 0.418688 cost_cha = 1.166281e-03 ,acc = 0.889867
iter 79 ,cost = 0.417545 cost_cha = 1.142611e-03 ,acc = 0.889967
iter 80 ,cost = 0.416425 cost_cha = 1.119727e-03 ,acc = 0.890167
iter 81 ,cost = 0.415328 cost_cha = 1.097594e-03 ,acc = 0.890433
iter 82 ,cost = 0.414252 cost_cha = 1.076177e-03 ,acc = 0.890633
iter 83 ,cost = 0.413196 cost_cha = 1.055444e-03 ,acc = 0.890850
iter 84 ,cost = 0.412161 cost_cha = 1.035366e-03 ,acc = 0.891000
iter 85 ,cost = 0.411145 cost_cha = 1.015913e-03 ,acc = 0.891267
iter 86 ,cost = 0.410148 cost_cha = 9.970603e-04 ,acc = 0.891450
iter 87 ,cost = 0.409169 cost_cha = 9.787807e-04 ,acc = 0.891617
iter 88 ,cost = 0.408208 cost_cha = 9.610508e-04 ,acc = 0.891817
iter 89 ,cost = 0.407264 cost_cha = 9.438476e-04 ,acc = 0.892033
iter 90 ,cost = 0.406337 cost_cha = 9.271498e-04 ,acc = 0.892167
iter 91 ,cost = 0.405426 cost_cha = 9.109367e-04 ,acc = 0.892350
iter 92 ,cost = 0.404531 cost_cha = 8.951890e-04 ,acc = 0.892617
iter 93 ,cost = 0.403651 cost_cha = 8.798882e-04 ,acc = 0.892883
iter 94 ,cost = 0.402786 cost_cha = 8.650169e-04 ,acc = 0.893000
iter 95 ,cost = 0.401935 cost_cha = 8.505583e-04 ,acc = 0.893150
iter 96 ,cost = 0.401099 cost_cha = 8.364967e-04 ,acc = 0.893417
iter 97 ,cost = 0.400276 cost_cha = 8.228170e-04 ,acc = 0.893567
iter 98 ,cost = 0.399466 cost_cha = 8.095049e-04 ,acc = 0.893817
iter 99 ,cost = 0.398670 cost_cha = 7.965466e-04 ,acc = 0.894017
iter 100 ,cost = 0.397886 cost_cha = 7.839293e-04 ,acc = 0.894167
iter 101 ,cost = 0.397114 cost_cha = 7.716405e-04 ,acc = 0.894233
iter 102 ,cost = 0.396355 cost_cha = 7.596684e-04 ,acc = 0.894450
iter 103 ,cost = 0.395607 cost_cha = 7.480017e-04 ,acc = 0.894550
iter 104 ,cost = 0.394870 cost_cha = 7.366295e-04 ,acc = 0.894717
iter 105 ,cost = 0.394145 cost_cha = 7.255417e-04 ,acc = 0.894867
iter 106 ,cost = 0.393430 cost_cha = 7.147284e-04 ,acc = 0.894950
iter 107 ,cost = 0.392726 cost_cha = 7.041801e-04 ,acc = 0.895067
iter 108 ,cost = 0.392032 cost_cha = 6.938878e-04 ,acc = 0.895167
iter 109 ,cost = 0.391348 cost_cha = 6.838430e-04 ,acc = 0.895400
iter 110 ,cost = 0.390674 cost_cha = 6.740374e-04 ,acc = 0.895633
iter 111 ,cost = 0.390009 cost_cha = 6.644631e-04 ,acc = 0.895767
iter 112 ,cost = 0.389354 cost_cha = 6.551126e-04 ,acc = 0.895917
iter 113 ,cost = 0.388708 cost_cha = 6.459785e-04 ,acc = 0.896083
iter 114 ,cost = 0.388071 cost_cha = 6.370540e-04 ,acc = 0.896250
iter 115 ,cost = 0.387443 cost_cha = 6.283324e-04 ,acc = 0.896450
iter 116 ,cost = 0.386823 cost_cha = 6.198072e-04 ,acc = 0.896567
iter 117 ,cost = 0.386212 cost_cha = 6.114724e-04 ,acc = 0.896883
iter 118 ,cost = 0.385608 cost_cha = 6.033220e-04 ,acc = 0.896933
iter 119 ,cost = 0.385013 cost_cha = 5.953503e-04 ,acc = 0.897000
iter 120 ,cost = 0.384425 cost_cha = 5.875520e-04 ,acc = 0.897017
iter 121 ,cost = 0.383845 cost_cha = 5.799218e-04 ,acc = 0.897117
iter 122 ,cost = 0.383273 cost_cha = 5.724546e-04 ,acc = 0.897233
iter 123 ,cost = 0.382708 cost_cha = 5.651456e-04 ,acc = 0.897333
iter 124 ,cost = 0.382150 cost_cha = 5.579902e-04 ,acc = 0.897467
iter 125 ,cost = 0.381599 cost_cha = 5.509839e-04 ,acc = 0.897600
iter 126 ,cost = 0.381055 cost_cha = 5.441223e-04 ,acc = 0.897717
iter 127 ,cost = 0.380517 cost_cha = 5.374013e-04 ,acc = 0.897733
iter 128 ,cost = 0.379987 cost_cha = 5.308170e-04 ,acc = 0.897867
iter 129 ,cost = 0.379462 cost_cha = 5.243653e-04 ,acc = 0.898100
iter 130 ,cost = 0.378944 cost_cha = 5.180427e-04 ,acc = 0.898250
iter 131 ,cost = 0.378432 cost_cha = 5.118454e-04 ,acc = 0.898333
iter 132 ,cost = 0.377927 cost_cha = 5.057702e-04 ,acc = 0.898500
iter 133 ,cost = 0.377427 cost_cha = 4.998135e-04 ,acc = 0.898600
iter 134 ,cost = 0.376933 cost_cha = 4.939723e-04 ,acc = 0.898733
iter 135 ,cost = 0.376445 cost_cha = 4.882433e-04 ,acc = 0.898750
iter 136 ,cost = 0.375962 cost_cha = 4.826236e-04 ,acc = 0.898783
iter 137 ,cost = 0.375485 cost_cha = 4.771103e-04 ,acc = 0.898917
iter 138 ,cost = 0.375013 cost_cha = 4.717006e-04 ,acc = 0.899017
iter 139 ,cost = 0.374547 cost_cha = 4.663917e-04 ,acc = 0.899183
iter 140 ,cost = 0.374086 cost_cha = 4.611811e-04 ,acc = 0.899350
iter 141 ,cost = 0.373629 cost_cha = 4.560661e-04 ,acc = 0.899467
iter 142 ,cost = 0.373178 cost_cha = 4.510445e-04 ,acc = 0.899517
iter 143 ,cost = 0.372732 cost_cha = 4.461137e-04 ,acc = 0.899633
iter 144 ,cost = 0.372291 cost_cha = 4.412715e-04 ,acc = 0.899817
iter 145 ,cost = 0.371855 cost_cha = 4.365158e-04 ,acc = 0.899917
iter 146 ,cost = 0.371423 cost_cha = 4.318442e-04 ,acc = 0.899950
iter 147 ,cost = 0.370995 cost_cha = 4.272548e-04 ,acc = 0.899950
iter 148 ,cost = 0.370573 cost_cha = 4.227455e-04 ,acc = 0.900033
iter 149 ,cost = 0.370154 cost_cha = 4.183144e-04 ,acc = 0.900050
iter 150 ,cost = 0.369740 cost_cha = 4.139597e-04 ,acc = 0.900200
iter 151 ,cost = 0.369331 cost_cha = 4.096793e-04 ,acc = 0.900200
iter 152 ,cost = 0.368925 cost_cha = 4.054717e-04 ,acc = 0.900350
iter 153 ,cost = 0.368524 cost_cha = 4.013350e-04 ,acc = 0.900483
iter 154 ,cost = 0.368127 cost_cha = 3.972676e-04 ,acc = 0.900533
iter 155 ,cost = 0.367733 cost_cha = 3.932679e-04 ,acc = 0.900700
iter 156 ,cost = 0.367344 cost_cha = 3.893344e-04 ,acc = 0.900783
iter 157 ,cost = 0.366959 cost_cha = 3.854653e-04 ,acc = 0.900867
iter 158 ,cost = 0.366577 cost_cha = 3.816594e-04 ,acc = 0.900983
iter 159 ,cost = 0.366199 cost_cha = 3.779152e-04 ,acc = 0.901050
iter 160 ,cost = 0.365825 cost_cha = 3.742312e-04 ,acc = 0.901100
iter 161 ,cost = 0.365454 cost_cha = 3.706061e-04 ,acc = 0.901217
iter 162 ,cost = 0.365087 cost_cha = 3.670386e-04 ,acc = 0.901333
iter 163 ,cost = 0.364724 cost_cha = 3.635275e-04 ,acc = 0.901483
iter 164 ,cost = 0.364364 cost_cha = 3.600714e-04 ,acc = 0.901550
iter 165 ,cost = 0.364007 cost_cha = 3.566692e-04 ,acc = 0.901617
iter 166 ,cost = 0.363654 cost_cha = 3.533197e-04 ,acc = 0.901700
iter 167 ,cost = 0.363304 cost_cha = 3.500218e-04 ,acc = 0.901767
iter 168 ,cost = 0.362957 cost_cha = 3.467743e-04 ,acc = 0.901850
iter 169 ,cost = 0.362613 cost_cha = 3.435762e-04 ,acc = 0.901867
iter 170 ,cost = 0.362273 cost_cha = 3.404264e-04 ,acc = 0.901900
iter 171 ,cost = 0.361935 cost_cha = 3.373239e-04 ,acc = 0.901983
iter 172 ,cost = 0.361601 cost_cha = 3.342677e-04 ,acc = 0.902117
iter 173 ,cost = 0.361270 cost_cha = 3.312569e-04 ,acc = 0.902217
iter 174 ,cost = 0.360942 cost_cha = 3.282905e-04 ,acc = 0.902250
iter 175 ,cost = 0.360616 cost_cha = 3.253676e-04 ,acc = 0.902317
iter 176 ,cost = 0.360294 cost_cha = 3.224872e-04 ,acc = 0.902350
iter 177 ,cost = 0.359974 cost_cha = 3.196486e-04 ,acc = 0.902400
iter 178 ,cost = 0.359657 cost_cha = 3.168509e-04 ,acc = 0.902400
iter 179 ,cost = 0.359343 cost_cha = 3.140932e-04 ,acc = 0.902450
iter 180 ,cost = 0.359032 cost_cha = 3.113748e-04 ,acc = 0.902500
iter 181 ,cost = 0.358723 cost_cha = 3.086949e-04 ,acc = 0.902617
iter 182 ,cost = 0.358417 cost_cha = 3.060527e-04 ,acc = 0.902667
iter 183 ,cost = 0.358114 cost_cha = 3.034475e-04 ,acc = 0.902717
iter 184 ,cost = 0.357813 cost_cha = 3.008785e-04 ,acc = 0.902783
iter 185 ,cost = 0.357514 cost_cha = 2.983451e-04 ,acc = 0.902850
iter 186 ,cost = 0.357219 cost_cha = 2.958465e-04 ,acc = 0.902900
iter 187 ,cost = 0.356925 cost_cha = 2.933822e-04 ,acc = 0.902900
iter 188 ,cost = 0.356634 cost_cha = 2.909513e-04 ,acc = 0.902917
iter 189 ,cost = 0.356346 cost_cha = 2.885534e-04 ,acc = 0.903000
iter 190 ,cost = 0.356059 cost_cha = 2.861877e-04 ,acc = 0.903067
iter 191 ,cost = 0.355776 cost_cha = 2.838537e-04 ,acc = 0.903183
iter 192 ,cost = 0.355494 cost_cha = 2.815508e-04 ,acc = 0.903167
iter 193 ,cost = 0.355215 cost_cha = 2.792783e-04 ,acc = 0.903217
iter 194 ,cost = 0.354938 cost_cha = 2.770358e-04 ,acc = 0.903300
iter 195 ,cost = 0.354663 cost_cha = 2.748227e-04 ,acc = 0.903350
iter 196 ,cost = 0.354390 cost_cha = 2.726384e-04 ,acc = 0.903367
iter 197 ,cost = 0.354120 cost_cha = 2.704824e-04 ,acc = 0.903400
iter 198 ,cost = 0.353851 cost_cha = 2.683542e-04 ,acc = 0.903450
iter 199 ,cost = 0.353585 cost_cha = 2.662533e-04 ,acc = 0.903500
iter 200 ,cost = 0.353321 cost_cha = 2.641793e-04 ,acc = 0.903567
iter 201 ,cost = 0.353059 cost_cha = 2.621315e-04 ,acc = 0.903650
iter 202 ,cost = 0.352799 cost_cha = 2.601097e-04 ,acc = 0.903733
iter 203 ,cost = 0.352541 cost_cha = 2.581132e-04 ,acc = 0.903833
iter 204 ,cost = 0.352284 cost_cha = 2.561417e-04 ,acc = 0.903933
iter 205 ,cost = 0.352030 cost_cha = 2.541947e-04 ,acc = 0.903983
iter 206 ,cost = 0.351778 cost_cha = 2.522718e-04 ,acc = 0.904033
iter 207 ,cost = 0.351528 cost_cha = 2.503726e-04 ,acc = 0.904167
iter 208 ,cost = 0.351279 cost_cha = 2.484967e-04 ,acc = 0.904217
iter 209 ,cost = 0.351033 cost_cha = 2.466436e-04 ,acc = 0.904233
iter 210 ,cost = 0.350788 cost_cha = 2.448130e-04 ,acc = 0.904217
iter 211 ,cost = 0.350545 cost_cha = 2.430045e-04 ,acc = 0.904333
iter 212 ,cost = 0.350303 cost_cha = 2.412177e-04 ,acc = 0.904333
iter 213 ,cost = 0.350064 cost_cha = 2.394523e-04 ,acc = 0.904383
iter 214 ,cost = 0.349826 cost_cha = 2.377080e-04 ,acc = 0.904450
iter 215 ,cost = 0.349590 cost_cha = 2.359842e-04 ,acc = 0.904533
iter 216 ,cost = 0.349356 cost_cha = 2.342808e-04 ,acc = 0.904650
iter 217 ,cost = 0.349123 cost_cha = 2.325973e-04 ,acc = 0.904700
iter 218 ,cost = 0.348893 cost_cha = 2.309335e-04 ,acc = 0.904767
iter 219 ,cost = 0.348663 cost_cha = 2.292891e-04 ,acc = 0.904817
iter 220 ,cost = 0.348436 cost_cha = 2.276636e-04 ,acc = 0.904867
iter 221 ,cost = 0.348210 cost_cha = 2.260568e-04 ,acc = 0.904950
iter 222 ,cost = 0.347985 cost_cha = 2.244685e-04 ,acc = 0.904983
iter 223 ,cost = 0.347762 cost_cha = 2.228983e-04 ,acc = 0.905000
iter 224 ,cost = 0.347541 cost_cha = 2.213458e-04 ,acc = 0.905083
iter 225 ,cost = 0.347321 cost_cha = 2.198110e-04 ,acc = 0.905083
iter 226 ,cost = 0.347103 cost_cha = 2.182934e-04 ,acc = 0.905167
iter 227 ,cost = 0.346886 cost_cha = 2.167927e-04 ,acc = 0.905217
iter 228 ,cost = 0.346671 cost_cha = 2.153088e-04 ,acc = 0.905267
iter 229 ,cost = 0.346457 cost_cha = 2.138414e-04 ,acc = 0.905350
iter 230 ,cost = 0.346244 cost_cha = 2.123902e-04 ,acc = 0.905400
iter 231 ,cost = 0.346033 cost_cha = 2.109549e-04 ,acc = 0.905383
iter 232 ,cost = 0.345824 cost_cha = 2.095353e-04 ,acc = 0.905533
iter 233 ,cost = 0.345616 cost_cha = 2.081312e-04 ,acc = 0.905550
iter 234 ,cost = 0.345409 cost_cha = 2.067424e-04 ,acc = 0.905550
iter 235 ,cost = 0.345204 cost_cha = 2.053685e-04 ,acc = 0.905517
iter 236 ,cost = 0.345000 cost_cha = 2.040094e-04 ,acc = 0.905533
iter 237 ,cost = 0.344797 cost_cha = 2.026649e-04 ,acc = 0.905650
iter 238 ,cost = 0.344596 cost_cha = 2.013347e-04 ,acc = 0.905700
iter 239 ,cost = 0.344396 cost_cha = 2.000186e-04 ,acc = 0.905733
iter 240 ,cost = 0.344197 cost_cha = 1.987165e-04 ,acc = 0.905833
iter 241 ,cost = 0.343999 cost_cha = 1.974280e-04 ,acc = 0.905900
iter 242 ,cost = 0.343803 cost_cha = 1.961531e-04 ,acc = 0.905950
iter 243 ,cost = 0.343608 cost_cha = 1.948915e-04 ,acc = 0.905933
iter 244 ,cost = 0.343415 cost_cha = 1.936429e-04 ,acc = 0.905967
iter 245 ,cost = 0.343222 cost_cha = 1.924073e-04 ,acc = 0.906033
iter 246 ,cost = 0.343031 cost_cha = 1.911845e-04 ,acc = 0.906067
iter 247 ,cost = 0.342841 cost_cha = 1.899742e-04 ,acc = 0.906167
iter 248 ,cost = 0.342652 cost_cha = 1.887762e-04 ,acc = 0.906167
iter 249 ,cost = 0.342465 cost_cha = 1.875905e-04 ,acc = 0.906283
iter 250 ,cost = 0.342278 cost_cha = 1.864167e-04 ,acc = 0.906333
iter 251 ,cost = 0.342093 cost_cha = 1.852549e-04 ,acc = 0.906400
iter 252 ,cost = 0.341909 cost_cha = 1.841047e-04 ,acc = 0.906433
iter 253 ,cost = 0.341726 cost_cha = 1.829660e-04 ,acc = 0.906467
iter 254 ,cost = 0.341544 cost_cha = 1.818387e-04 ,acc = 0.906517
iter 255 ,cost = 0.341364 cost_cha = 1.807225e-04 ,acc = 0.906567
iter 256 ,cost = 0.341184 cost_cha = 1.796175e-04 ,acc = 0.906600
iter 257 ,cost = 0.341005 cost_cha = 1.785233e-04 ,acc = 0.906650
iter 258 ,cost = 0.340828 cost_cha = 1.774398e-04 ,acc = 0.906700
iter 259 ,cost = 0.340652 cost_cha = 1.763670e-04 ,acc = 0.906733
iter 260 ,cost = 0.340476 cost_cha = 1.753046e-04 ,acc = 0.906767
iter 261 ,cost = 0.340302 cost_cha = 1.742525e-04 ,acc = 0.906867
iter 262 ,cost = 0.340129 cost_cha = 1.732105e-04 ,acc = 0.906933
iter 263 ,cost = 0.339957 cost_cha = 1.721786e-04 ,acc = 0.906967
iter 264 ,cost = 0.339785 cost_cha = 1.711566e-04 ,acc = 0.906983
iter 265 ,cost = 0.339615 cost_cha = 1.701444e-04 ,acc = 0.906983
iter 266 ,cost = 0.339446 cost_cha = 1.691417e-04 ,acc = 0.907000
iter 267 ,cost = 0.339278 cost_cha = 1.681486e-04 ,acc = 0.907033
iter 268 ,cost = 0.339111 cost_cha = 1.671649e-04 ,acc = 0.907200
iter 269 ,cost = 0.338945 cost_cha = 1.661904e-04 ,acc = 0.907217
iter 270 ,cost = 0.338779 cost_cha = 1.652250e-04 ,acc = 0.907233
iter 271 ,cost = 0.338615 cost_cha = 1.642687e-04 ,acc = 0.907333
iter 272 ,cost = 0.338452 cost_cha = 1.633212e-04 ,acc = 0.907383
iter 273 ,cost = 0.338289 cost_cha = 1.623826e-04 ,acc = 0.907483
iter 274 ,cost = 0.338128 cost_cha = 1.614526e-04 ,acc = 0.907533
iter 275 ,cost = 0.337968 cost_cha = 1.605311e-04 ,acc = 0.907600
iter 276 ,cost = 0.337808 cost_cha = 1.596181e-04 ,acc = 0.907617
iter 277 ,cost = 0.337649 cost_cha = 1.587134e-04 ,acc = 0.907683
iter 278 ,cost = 0.337491 cost_cha = 1.578170e-04 ,acc = 0.907750
iter 279 ,cost = 0.337334 cost_cha = 1.569286e-04 ,acc = 0.907767
iter 280 ,cost = 0.337178 cost_cha = 1.560483e-04 ,acc = 0.907800
iter 281 ,cost = 0.337023 cost_cha = 1.551760e-04 ,acc = 0.907883
iter 282 ,cost = 0.336869 cost_cha = 1.543114e-04 ,acc = 0.907933
iter 283 ,cost = 0.336715 cost_cha = 1.534545e-04 ,acc = 0.907983
iter 284 ,cost = 0.336563 cost_cha = 1.526053e-04 ,acc = 0.908000
iter 285 ,cost = 0.336411 cost_cha = 1.517637e-04 ,acc = 0.908067
iter 286 ,cost = 0.336260 cost_cha = 1.509294e-04 ,acc = 0.908150
iter 287 ,cost = 0.336110 cost_cha = 1.501025e-04 ,acc = 0.908167
iter 288 ,cost = 0.335961 cost_cha = 1.492829e-04 ,acc = 0.908217
iter 289 ,cost = 0.335812 cost_cha = 1.484704e-04 ,acc = 0.908283
iter 290 ,cost = 0.335665 cost_cha = 1.476650e-04 ,acc = 0.908283
iter 291 ,cost = 0.335518 cost_cha = 1.468666e-04 ,acc = 0.908317
iter 292 ,cost = 0.335372 cost_cha = 1.460751e-04 ,acc = 0.908350
iter 293 ,cost = 0.335226 cost_cha = 1.452904e-04 ,acc = 0.908383
iter 294 ,cost = 0.335082 cost_cha = 1.445125e-04 ,acc = 0.908350
iter 295 ,cost = 0.334938 cost_cha = 1.437412e-04 ,acc = 0.908333
iter 296 ,cost = 0.334795 cost_cha = 1.429765e-04 ,acc = 0.908433
iter 297 ,cost = 0.334653 cost_cha = 1.422184e-04 ,acc = 0.908467
iter 298 ,cost = 0.334511 cost_cha = 1.414666e-04 ,acc = 0.908500
iter 299 ,cost = 0.334371 cost_cha = 1.407212e-04 ,acc = 0.908483
iter 300 ,cost = 0.334231 cost_cha = 1.399821e-04 ,acc = 0.908550
iter 301 ,cost = 0.334092 cost_cha = 1.392492e-04 ,acc = 0.908567
iter 302 ,cost = 0.333953 cost_cha = 1.385224e-04 ,acc = 0.908567
iter 303 ,cost = 0.333815 cost_cha = 1.378016e-04 ,acc = 0.908583
iter 304 ,cost = 0.333678 cost_cha = 1.370869e-04 ,acc = 0.908567
iter 305 ,cost = 0.333542 cost_cha = 1.363780e-04 ,acc = 0.908650
iter 306 ,cost = 0.333406 cost_cha = 1.356750e-04 ,acc = 0.908717
iter 307 ,cost = 0.333271 cost_cha = 1.349778e-04 ,acc = 0.908750
iter 308 ,cost = 0.333137 cost_cha = 1.342864e-04 ,acc = 0.908800
iter 309 ,cost = 0.333003 cost_cha = 1.336005e-04 ,acc = 0.908867
iter 310 ,cost = 0.332870 cost_cha = 1.329203e-04 ,acc = 0.908933
iter 311 ,cost = 0.332738 cost_cha = 1.322455e-04 ,acc = 0.908967
iter 312 ,cost = 0.332606 cost_cha = 1.315763e-04 ,acc = 0.908983
iter 313 ,cost = 0.332476 cost_cha = 1.309124e-04 ,acc = 0.908983
iter 314 ,cost = 0.332345 cost_cha = 1.302539e-04 ,acc = 0.909067
iter 315 ,cost = 0.332216 cost_cha = 1.296006e-04 ,acc = 0.909083
iter 316 ,cost = 0.332087 cost_cha = 1.289526e-04 ,acc = 0.909100
iter 317 ,cost = 0.331958 cost_cha = 1.283097e-04 ,acc = 0.909117
iter 318 ,cost = 0.331831 cost_cha = 1.276719e-04 ,acc = 0.909167
iter 319 ,cost = 0.331704 cost_cha = 1.270392e-04 ,acc = 0.909183
iter 320 ,cost = 0.331577 cost_cha = 1.264114e-04 ,acc = 0.909217
iter 321 ,cost = 0.331452 cost_cha = 1.257886e-04 ,acc = 0.909250
iter 322 ,cost = 0.331326 cost_cha = 1.251707e-04 ,acc = 0.909350
iter 323 ,cost = 0.331202 cost_cha = 1.245576e-04 ,acc = 0.909383
iter 324 ,cost = 0.331078 cost_cha = 1.239493e-04 ,acc = 0.909400
iter 325 ,cost = 0.330954 cost_cha = 1.233457e-04 ,acc = 0.909450
iter 326 ,cost = 0.330832 cost_cha = 1.227468e-04 ,acc = 0.909483
iter 327 ,cost = 0.330710 cost_cha = 1.221525e-04 ,acc = 0.909517
iter 328 ,cost = 0.330588 cost_cha = 1.215628e-04 ,acc = 0.909550
iter 329 ,cost = 0.330467 cost_cha = 1.209776e-04 ,acc = 0.909617
iter 330 ,cost = 0.330347 cost_cha = 1.203968e-04 ,acc = 0.909650
iter 331 ,cost = 0.330227 cost_cha = 1.198205e-04 ,acc = 0.909650
iter 332 ,cost = 0.330108 cost_cha = 1.192486e-04 ,acc = 0.909667
iter 333 ,cost = 0.329989 cost_cha = 1.186810e-04 ,acc = 0.909683
iter 334 ,cost = 0.329871 cost_cha = 1.181177e-04 ,acc = 0.909733
iter 335 ,cost = 0.329753 cost_cha = 1.175587e-04 ,acc = 0.909850
iter 336 ,cost = 0.329636 cost_cha = 1.170038e-04 ,acc = 0.909867
iter 337 ,cost = 0.329520 cost_cha = 1.164531e-04 ,acc = 0.909867
iter 338 ,cost = 0.329404 cost_cha = 1.159066e-04 ,acc = 0.909933
iter 339 ,cost = 0.329288 cost_cha = 1.153640e-04 ,acc = 0.909983
iter 340 ,cost = 0.329174 cost_cha = 1.148255e-04 ,acc = 0.910017
iter 341 ,cost = 0.329059 cost_cha = 1.142910e-04 ,acc = 0.910050
iter 342 ,cost = 0.328946 cost_cha = 1.137605e-04 ,acc = 0.910033
iter 343 ,cost = 0.328832 cost_cha = 1.132338e-04 ,acc = 0.910033
iter 344 ,cost = 0.328720 cost_cha = 1.127110e-04 ,acc = 0.910083
iter 345 ,cost = 0.328607 cost_cha = 1.121920e-04 ,acc = 0.910083
iter 346 ,cost = 0.328496 cost_cha = 1.116768e-04 ,acc = 0.910083
iter 347 ,cost = 0.328385 cost_cha = 1.111654e-04 ,acc = 0.910133
iter 348 ,cost = 0.328274 cost_cha = 1.106576e-04 ,acc = 0.910133
iter 349 ,cost = 0.328164 cost_cha = 1.101536e-04 ,acc = 0.910167
iter 350 ,cost = 0.328054 cost_cha = 1.096532e-04 ,acc = 0.910167
iter 351 ,cost = 0.327945 cost_cha = 1.091563e-04 ,acc = 0.910200
iter 352 ,cost = 0.327836 cost_cha = 1.086630e-04 ,acc = 0.910217
iter 353 ,cost = 0.327728 cost_cha = 1.081733e-04 ,acc = 0.910217
iter 354 ,cost = 0.327620 cost_cha = 1.076870e-04 ,acc = 0.910217
iter 355 ,cost = 0.327513 cost_cha = 1.072042e-04 ,acc = 0.910283
iter 356 ,cost = 0.327407 cost_cha = 1.067249e-04 ,acc = 0.910367
iter 357 ,cost = 0.327300 cost_cha = 1.062489e-04 ,acc = 0.910433
iter 358 ,cost = 0.327195 cost_cha = 1.057762e-04 ,acc = 0.910517
iter 359 ,cost = 0.327089 cost_cha = 1.053069e-04 ,acc = 0.910517
iter 360 ,cost = 0.326984 cost_cha = 1.048409e-04 ,acc = 0.910567
iter 361 ,cost = 0.326880 cost_cha = 1.043781e-04 ,acc = 0.910650
iter 362 ,cost = 0.326776 cost_cha = 1.039186e-04 ,acc = 0.910750
iter 363 ,cost = 0.326673 cost_cha = 1.034622e-04 ,acc = 0.910733
iter 364 ,cost = 0.326570 cost_cha = 1.030090e-04 ,acc = 0.910783
iter 365 ,cost = 0.326467 cost_cha = 1.025590e-04 ,acc = 0.910817
iter 366 ,cost = 0.326365 cost_cha = 1.021120e-04 ,acc = 0.910850
iter 367 ,cost = 0.326263 cost_cha = 1.016682e-04 ,acc = 0.910867
iter 368 ,cost = 0.326162 cost_cha = 1.012273e-04 ,acc = 0.910917
iter 369 ,cost = 0.326061 cost_cha = 1.007895e-04 ,acc = 0.910900
iter 370 ,cost = 0.325961 cost_cha = 1.003547e-04 ,acc = 0.910900
iter 371 ,cost = 0.325861 cost_cha = 9.992280e-05 ,acc = 0.910917
accury is 0.910917

Data visualization

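The algorithm above yields w as a 785×10 matrix. The first row is the bias b, and the remaining 784 rows are the weights of the individual image pixels. Drawing the weights of each class as an image gives a striking result: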

function r =  draw_w(w)
    r = [];
    for i = 1:10
       wa = w(2:785,i);
       wb = reshape(wa,28,28);
       wb = wb /max(max(wb));
       r = [r wb];
    end
    m2b=imresize(r,size(r)*8,'nearest');
    imshow(m2b);
end

[Figure: the weights of the 10 classes drawn side by side as 28×28 images]

Drawing the negative values of w as well gives the following image:

[Figure: weight images with negative values included]

function r =  draw_w(w)
    r = [];
    for i = 1:10
       wa = w(2:785,i);
       wb = reshape(wa,28,28);
       wb = wb + abs(min(min(wb))); % here the absolute value of the minimum is added
       wb = wb /max(max(wb));
       r = [r wb];
    end
    %imshow(r);
    m2b=imresize(r,size(r)*8,'nearest');
    imshow(m2b)
end

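Using red for positive values and green for negative values, we can build a color weight map.
[Figure: color weight map, red for positive weights and green for negative weights]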

    r_red = r;
    r_red(r<0) = 0;
    r_green = r;
    r_green(r>0) = 0;
    r_green = -r_green;
    r_green =r_green / max(max(r_green));
    r_blue = zeros(size(r));
    r_rgb(:,:,1) = r_red;
    r_rgb(:,:,2) =r_green;
    r_rgb(:,:,3) = r_blue;
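
The fragment above does not show where `r` comes from or how the result is displayed. A self-contained sketch of the same idea follows; it assumes `r` is a tiled weight matrix such as the one returned by the first `draw_w`, replaces it with a small made-up matrix, and scales both channels to [0, 1], which is a choice made for this example.

```matlab
% Toy stand-in for the tiled weight matrix r (made-up values).
r = [ 0.5 -1.0  0.0;
     -0.2  1.0  0.3];

r_red = r;
r_red(r < 0) = 0;                         % keep the positive weights
r_green = -r;
r_green(r > 0) = 0;                       % keep the magnitudes of the negative weights
r_red   = r_red   / max(r_red(:));        % scale each channel to [0, 1]
r_green = r_green / max(r_green(:));
r_rgb = cat(3, r_red, r_green, zeros(size(r)));   % rows x cols x 3 truecolor array

disp(size(r_rgb));
```

The resulting array can be shown with `image(r_rgb); axis image off;`. Note that the scaling divides by the channel maximum, so it assumes `r` contains at least one positive and one negative value.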

References

[1] MNIST dataset. http://yann.lecun.com/exdb/mnist/
[2] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE, 86(11):2278-2324, November 1998.
[3] http://deeplearning.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
