Tensorflow 实现MINIST数据集多分类问题
来源:互联网 发布:蛤文化知乎 编辑:程序博客网 时间:2024/06/11 23:48
tensorflow 入门程序 MINIST数据集
tensorflow是采用计算图的方式,先把所有的计算都用计算图描述出来,然后将定义的所有计算放到外面计算,大大提高了效率
下载minists数据集
import tensorflow.examples.tutorials.mnist.input_datamnist = read_data_sets("MNIST_data/", one_hot=True)from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpyfrom six.moves import urllibfrom six.moves import xrange # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
#读入minist数据集,读取二进制文件,这里为读入一张图片查看一下
import structimport numpy as npimport matplotlib.pyplot as pltdef minist_read_plot1(filename): binfile = open(filename , 'rb') buf = binfile.read() index = 0 magic, numImages , numRows , numColumns = struct.unpack_from('>IIII' , buf , index) index += struct.calcsize('>IIII') #'>IIII'是说使用大端法读取4个unsinged int32 #然后读取一个图片测试是否读取成功 im = struct.unpack_from('>784B' ,buf, index) index += struct.calcsize('>784B') im = np.array(im) im = im.reshape(28,28) ##显示图片 fig = plt.figure() plotwindow = fig.add_subplot(111) plt.imshow(im , cmap='gray') plt.show() #'>784B'的意思就是用大端法读取784个unsigned byte;显示结果为5
filename = 'train-images.idx3-ubyte'minist_read_plot1(filename)
def load_mnist(path, kind='train'): """Load MNIST data from `path`""" labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind) images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind) with open(labels_path, 'rb') as lbpath: magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8) with open(images_path, 'rb') as imgpath: magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16)) images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) return images, labels
filepath='your path'images, labels = load_mnist(filepath, kind='train')test_images, test_labels = load_mnist(filepath, kind='t10k')
#使用tensorflow训练之前,首先对数据进行处理
# 1.将图片像素值 0-255 之间的值转化为 [0,1]
# 2.对标签数据进行one-hot编码,比如1-->[1,0,0,0,0,0,0,0,0,0],不过,one-hot对于softmax没有什么多大意义,不是必要选择
#另外,由于图像数据是[60000,748]的矩阵,放弃了图像的结构信息,不利于分类,
#但是这里,只是学习tensorflow的用法,不涉及算法的优化
from sklearn import preprocessingenc = preprocessing.OneHotEncoder() a = np.array([1,3,4,6,5,2,9,0,8,7])enc.fit(a) #one-hot编码首先要学习该变量整数编码label_trans = enc.transform(labels.reshape(-1,1)).toarray()test_labels = enc.transform(test_labels.reshape(-1,1)).toarray()images_trans = images/255 #转化为[0,1]之间,防止指数溢出test_images = test_images/255
##使用tensorflow训练模型,采用算法 softmax regression
import tensorflow as tfx = tf.placeholder(tf.float32, [None, 784])#搭建模型,variable是可变的张量W = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,W) + b)#损失函数占位符,交叉熵y_ = tf.placeholder("float", [None,10])cross_entropy = -tf.reduce_sum(y_*tf.log(y))#计算,梯度下降求解train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#初始化变量init = tf.initialize_all_variables()#现在我们可以在一个Session里面启动我们的模型,并且初始化变量:sess = tf.Session()sess.run(init)
#该循环的每个步骤中,我们都会随机抓取训练数据中的100个批处理数据点,
#然后我们用这些数据点作为参数替换之前的占位符来运行train_step
##定义取消批量数据的函数,每个样本只提取一次,从头开始直到结束,按一行作为一个样本
#要求为100的整数倍def next_batch(data_set,batch_size,from_size): if batch_size%100==0: return data_set[from_size:from_size+batch_size,:]tmp = 0for i in range(600): if tmp<=len(images_trans): batch_xs = next_batch(images_trans, 100, tmp) batch_ys = next_batch(label_trans,100,tmp) tmp += 100 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#统计训练集准确率
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
#测试集效果
print(sess.run(accuracy, feed_dict={x: test_images, y_: test_labels}))结果为:0.9031
阅读全文
0 0
- Tensorflow 实现MINIST数据集多分类问题
- tensorflow minist数据集分类笔记
- python tensorflow 使用minist数据集实现手写数字识别
- tensorflow 加载minist数据
- Tensorflow实现简单Minist
- tensorflow 学习(1) minist数据初入
- 使用tensorflow实现简单的多分类问题
- Tensorflow学习:MINIST手写体
- 深度纸质学习与实验(四)-将TensorFlow加入kubernetes完成与minist数据集初试
- 利用Python可视化MINIST数据集
- KNN(k-NearestNeighbor)识别minist数据集
- 用caffe训练minist数据集
- Tensorflow minist单层感知机
- tensorflow 分类问题
- Tensorflow LSTM分类问题
- mscaffe 训练minist数据
- Python处理MINIST数据
- tensorflow实现文本分类
- sql优化个人总结
- vue实现一个tab切换
- 关于数据库脏读、不可重复读、幻读
- python文件为什么要关闭
- python matplotlib 入门系列二:figure
- Tensorflow 实现MINIST数据集多分类问题
- PHP 函数漏洞总结
- HDU
- 是学python还是java?一张图告诉你!
- JavaWeb中关于请求乱码的讲解
- SQL语句---创建表同时添加约束
- Linux Vi 删除全部内容,删除某行到结尾,删除某段内容 的方法
- Python2.7+PyQt4 QtDesigner学习笔记系列——1:环境搭建
- 1.5 x86带宽计算