使用逻辑回归和神经网络进行手写数字识别
来源:互联网 发布:曾康霖 知乎 编辑:程序博客网 时间:2024/05/18 00:27
斯坦福大学机器学习课程第三次编程作业,使用多分类的逻辑回归模型和神经网络模型进行手写数字识别。
关于数据集,手写数字识别是神经网络学习的入门级经典问题MNIST上有大量开源的手写数字数据集,可以用来训练模型。
1.数据可视化
将数据集中的一部分数据用图像的形式显示出来,便于理解和应用。
displayData函数的代码为:
function [h, display_array] = displayData(X, example_width)if ~exist('example_width', 'var') || isempty(example_width) example_width = round(sqrt(size(X, 2)));endcolormap(gray);[m n] = size(X);example_height = (n / example_width);display_rows = floor(sqrt(m));display_cols = ceil(m / display_rows);pad = 1;display_array = - ones(pad + display_rows * (example_height + pad), ... pad + display_cols * (example_width + pad));curr_ex = 1;for j = 1:display_rowsfor i = 1:display_colsif curr_ex > m, break; endmax_val = max(abs(X(curr_ex, :)));display_array(pad + (j - 1) * (example_height + pad) + (1:example_height), ... pad + (i - 1) * (example_width + pad) + (1:example_width)) = ...reshape(X(curr_ex, :), example_height, example_width) / max_val;curr_ex = curr_ex + 1;endif curr_ex > m, break; endendh = imagesc(display_array, [-1 1]);axis image offdrawnow;end在ex3.m中载入数据并调用函数,显示图像结果为:
逻辑回归模型的函数和代码如下:
function [J, grad] = lrCostFunction(theta, X, y, lambda)m = length(y); % number of training examplesJ = 0;grad = zeros(size(theta));temp=[0;theta(2:end)];J= -1 * sum(y .* log(sigmoid(X*theta)) + (1-y) .*log((1-sigmoid(X*theta))))/m + lambda/(2*m)*temp'*temp;grad = (X' * (sigmoid(X*theta)-y))/m + lambda/m*temp;grad = grad(:);endfunction [all_theta] = oneVsAll(X, y, num_labels, lambda)m = size(X, 1);n = size(X, 2);all_theta = zeros(num_labels, n + 1);X = [ones(m, 1) X];options = optimset('GradObj','on','MaxIter',50);initial_theta = zeros(n+1,1);for c =1:num_labels all_theta(c,:)=fmincg(@(t)(lrCostFunction(t,X,(y==c),lambda)),... initial_theta,options);endendfunction p = predictOneVsAll(all_theta, X)m = size(X, 1);num_labels = size(all_theta, 1);p = zeros(size(X, 1), 1);X = [ones(m, 1) X];[a,p]=max(sigmoid(X*all_theta'),[],2);end此模型对训练数据的准确度达到94.82%
2.神经网络
ex3_nn.m脚本代码:
clear ; close all; clcinput_layer_size = 400; % 20x20 Input Images of Digitshidden_layer_size = 25; % 25 hidden unitsnum_labels = 10; % 10 labels, from 1 to 10 % (note that we have mapped "0" to label 10)%% =========== Part 1: Loading and Visualizing Data =============fprintf('Loading and Visualizing Data ...\n')load('ex3data1.mat');m = size(X, 1);sel = randperm(size(X, 1));sel = sel(1:100);displayData(X(sel, :));fprintf('Program paused. Press enter to continue.\n');pause;%% ================ Part 2: Loading Pameters ================fprintf('\nLoading Saved Neural Network Parameters ...\n')load('ex3weights.mat');%% ================= Part 3: Implement Predict =================pred = predict(Theta1, Theta2, X);fprintf('\nTraining Set Accuracy: %f\n', mean(double(pred == y)) * 100);fprintf('Program paused. Press enter to continue.\n');pause;rp = randperm(m);for i = 1:m fprintf('\nDisplaying Example Image\n'); displayData(X(rp(i), :)); pred = predict(Theta1, Theta2, X(rp(i),:)); fprintf('\nNeural Network Prediction: %d (digit %d)\n', pred, mod(pred, 10)); s = input('Paused - press enter to continue, q to exit:','s'); if s == 'q' break endend
其中核心函数predict代码:
function p = predict(Theta1, Theta2, X)m = size(X, 1);num_labels = size(Theta2, 1);p = zeros(size(X, 1), 1);X = [ones(m,1) X];a2 = sigmoid(X*Theta1');a2 = [ones(m,1) a2];a3 = sigmoid(a2 * Theta2');[aa,p] =max(a3,[],2);end这样便能对手写数字进行识别,其对训练数据的预测准确率达到97.52%,例如对如下图,能识别出图片上数字为6.
阅读全文
1 0
- 使用逻辑回归和神经网络进行手写数字识别
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- 逻辑回归softmax神经网络实现手写数字识别(cs)
- 使用神经网络识别手写数字
- [DL]2.使用Softmax回归进行手写数字识别
- 使用神经网络进行逻辑回归
- 使用神经网络识别手写数字--原理部分
- Tensorflow - Tutorial (2) : 利用softmax回归进行手写数字识别
- C++使用matlab卷积神经网络库MatConvNet来进行手写数字识别
- keras入门 利用卷积神经网络进行手写数字识别
- 利用神经网络识别手写数字
- 初识神经网络--识别手写数字
- 各种机器学习方法(线性回归、支持向量机、决策树、朴素贝叶斯、KNN算法、逻辑回归)实现手写数字识别并用准确率、召回率、F1进行评估
- 神经网络:简单手写数字识别神经网络
- 第1章使用神经网络识别手写数字
- 深度学习四:tensorflow-使用卷积神经网络识别手写数字
- 使用tensorflow卷积神经网络实现mnist手写数字识别
- 使用TensorFlow重构神经网络的识别手写数字
- MySQL数据类型和运算符
- Nginx负载均衡
- 二进制求和
- sum-root-to-leaf-numbers
- 定制Debian系统支持Mac或win机型SOP
- 使用逻辑回归和神经网络进行手写数字识别
- Java之身份证号验证
- H5调用相机,裁剪,压缩照片
- 凸包算法合集
- 1123: 最佳校友
- 关于图像质量评测的分析
- 函数对象
- 水波纹效果
- input输入框是只能输入数字