tensorflow 训练mnist数据
来源:互联网 发布:淘宝开店要保证金吗 编辑:程序博客网 时间:2024/05/16 08:48
tensorflow 训练mnist数据
"""Train a small two-layer CNN on MNIST with TensorFlow 1.x.

The script reads the gzip-compressed IDX files from the working directory,
builds the graph (conv -> pool -> conv -> pool -> fc -> dropout -> fc),
restores a previously saved checkpoint and continues training from it.
The "train from scratch" variant that produces the checkpoint is kept
below as commented-out code.
"""
import tensorflow as tf
import numpy as np
import os
import struct
import gzip
from tensorflow.examples.tutorials.mnist import input_data
import test8


def readata(label, image):
    """Load one MNIST label/image file pair.

    Args:
        label: path to a gzip-compressed IDX1 label file.
        image: path to a gzip-compressed IDX3 image file.

    Returns:
        A tuple ``(labels, images)`` where ``labels`` is a float array of
        shape ``(num, 10)`` holding one-hot rows and ``images`` is a uint8
        array of shape ``(num, rows, cols, 1)``.
    """
    with gzip.open(label, 'rb') as flbl:
        # 8-byte big-endian header: magic number and item count.
        magic, num = struct.unpack('>II', flbl.read(8))
        # frombuffer replaces the deprecated binary mode of np.fromstring.
        lab = np.frombuffer(flbl.read(), dtype=np.int8)
    one_hot = np.zeros((lab.shape[0], 10))
    # Set column `lab[i]` of row `i` to 1.0 for every row at once.
    one_hot[np.arange(lab.shape[0]), lab] = 1.0

    with gzip.open(image, 'rb') as fimg:
        # 16-byte big-endian header: magic, item count, rows, cols.
        magic, num, rows, cols = struct.unpack('>IIII', fimg.read(16))
        img = np.frombuffer(fimg.read(), dtype=np.uint8)
    img = img.reshape(len(one_hot), rows, cols)
    # Append a trailing channel axis to match the image placeholder.
    images = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
    return one_hot, images


with tf.name_scope('input'):
    image = tf.placeholder(tf.float32, [None, 28, 28, 1])
    label = tf.placeholder(tf.float32, [None, 10])

with tf.name_scope('conv1'):
    W_con1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
    b_con1 = tf.Variable(tf.constant(0.1, shape=[32]))
    c_con1 = tf.nn.conv2d(image, W_con1, strides=[1, 1, 1, 1],
                          padding='SAME') + b_con1
    h_con1 = tf.nn.relu(c_con1)
    m_pool1 = tf.nn.max_pool(h_con1, ksize=[1, 2, 2, 1],
                             strides=[1, 2, 2, 1], padding='SAME')

with tf.name_scope('conv2'):
    W_con2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
    b_con2 = tf.Variable(tf.constant(0.1, shape=[64]))
    c_con2 = tf.nn.conv2d(m_pool1, W_con2, strides=[1, 1, 1, 1],
                          padding='SAME') + b_con2
    h_con2 = tf.nn.relu(c_con2)
    m_pool2 = tf.nn.max_pool(h_con2, ksize=[1, 2, 2, 1],
                             strides=[1, 2, 2, 1], padding='SAME')

with tf.name_scope('fc1'):
    W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
    b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
    # Two stride-2 poolings reduce 28x28 to 7x7 with 64 channels.
    m_pool2_flat = tf.reshape(m_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(m_pool2_flat, W_fc1) + b_fc1)

with tf.name_scope('drop'):
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob)

with tf.name_scope('fc2'):
    W_fc2 = tf.Variable(tf.truncated_normal(shape=[1024, 10], stddev=0.1))
    b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
    y_con = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

with tf.name_scope('cross_entry'):
    cross_entry = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=y_con))

train = tf.train.GradientDescentOptimizer(0.001).minimize(cross_entry)
# correct_prediction = tf.equal(tf.argmax(y_con, 1), tf.argmax(label, 1))
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

train_label, train_image = readata('train-labels-idx1-ubyte.gz',
                                   'train-images-idx3-ubyte.gz')
test_label, test_image = readata('t10k-labels-idx1-ubyte.gz',
                                 't10k-images-idx3-ubyte.gz')
# NOTE(review): images are fed as raw uint8 pixel values without scaling;
# confirm whether normalisation to [0, 1] was intended.
train_image1 = test8.Dataset(train_image, train_label)
saver = tf.train.Saver()

# Train from scratch (run this variant once to create the checkpoint that
# the restore-and-train session below expects).
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     train_writer = tf.summary.FileWriter('./' + '/train', sess.graph)
#     for i in range(20000):
#         batchimage, batchlabel = train_image1.next_batch(60)
#         _, loss = sess.run([train, cross_entry],
#                            feed_dict={image: batchimage, label: batchlabel,
#                                       keep_prob: 1.0})
#         if i % 100 == 0:
#             # tm1 = test_image[(i % 1000) * 10:((i + 1) % 1000) * 10, :, :, :]
#             # tn1 = test_label[(i % 1000) * 10:((i + 1) % 1000) * 10, :]
#             # train_accuracy = accuracy.eval(feed_dict={
#             #     image: tm1, label: tn1, keep_prob: 1.0})
#             # print('sss %s' % train_accuracy)
#             print('loss %s' % loss)
#     train_writer.close()
#     save_path = saver.save(sess, '/home/dms/model.ckpt')
#     print(save_path)

# Load the saved model and continue training.
with tf.Session() as sess:
    saver.restore(sess, '/home/dms/model.ckpt')
    # train_writer = tf.summary.FileWriter('./' + '/train', sess.graph)
    for i in range(10000):
        batchimage, batchlabel = train_image1.next_batch(60)
        # keep_prob is fed as 1.0, so dropout is effectively disabled here.
        _, loss = sess.run([train, cross_entry],
                           feed_dict={image: batchimage, label: batchlabel,
                                      keep_prob: 1.0})
        if i % 100 == 0:
            # tm1 = test_image[(i % 1000) * 10:((i + 1) % 1000) * 10, :, :, :]
            # tn1 = test_label[(i % 1000) * 10:((i + 1) % 1000) * 10, :]
            # train_accuracy = accuracy.eval(feed_dict={
            #     image: tm1, label: tn1, keep_prob: 1.0})
            # print('sss %s' % train_accuracy)
            # Single-argument form prints the same text on Python 2 and 3.
            print('loss %s' % loss)
    # train_writer.close()
    # save_path = saver.save(sess, '/home/dms/model.ckpt')
    # print(save_path)
# Batch iterator: the module imported above as `test8` (provides `Dataset`).
import numpy as np


class Dataset(object):
    """Sequential mini-batch iterator over paired images and labels.

    Batches are served in storage order (no shuffling). When a batch would
    run past the end of the data, the remaining tail is joined with samples
    from the beginning so that the batch still has ``batch_size`` items.
    """

    def __init__(self, images, labels):
        """Store the data and start reading at position 0.

        Args:
            images: sliceable sequence of samples (e.g. a NumPy array).
            labels: sliceable sequence aligned with ``images``.
        """
        self._images = images
        self._labels = labels
        self._num_examples = len(images)
        self._index_in_epoch = 0

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples as ``(images, labels)``.

        Assumes ``batch_size`` does not exceed the number of examples.
        A batch that wraps around the end of the data is returned as
        freshly concatenated NumPy arrays; otherwise plain slices of the
        stored sequences are returned.
        """
        start = self._index_in_epoch
        if start + batch_size > self._num_examples:
            # Tail of the current pass. The end bound must be the dataset
            # length; using the remaining *count* as the bound (as before)
            # produced an empty or truncated tail and a short batch.
            image_rest_part = self._images[start:self._num_examples]
            label_rest_part = self._labels[start:self._num_examples]
            rest_num_examples = self._num_examples - start

            # Top the batch up from the start of the data.
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            image_new_part = self._images[0:end]
            label_new_part = self._labels[0:end]
            return (
                np.concatenate((image_rest_part, image_new_part), axis=0),
                np.concatenate((label_rest_part, label_new_part), axis=0),
            )

        self._index_in_epoch += batch_size
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
0 0
- tensorflow 训练mnist数据
- TensorFlow 训练 MNIST 数据
- TensorFlow 入门之训练 MNIST 数据
- TensorFlow 训练 MNIST 数据(二)
- TensorFlow入门之训练mnist数据集
- tensorflow实现AlexNet训练mnist数据
- Tensorflow中mnist数据使用CNN训练
- TensorFlow个人学习(训练 MNIST 数据 )
- 使用Google tensorflow 训练MNIST数据集
- Tensorflow训练mnist数据(完整版)
- TensorFlow训练mnist数据集(卷积神经网络lenet5)
- 用tensorflow实现VGG网络,训练mnist数据集
- tensorflow MNIST数据集
- tensorflow mnist训练集简单训练代码
- Caffe 训练mnist数据
- Tensorflow深度学习入门——优化训练MNIST数据和调用训练模型识别图片
- tensorflow mnist训练集input.py代码
- Tensorflow实战1:利用AlexNet训练MNIST
- Docker学习文档之三 其他相关-安全性
- mysql (mysqldump) 数据库迁移
- (LeetCode) 129. Sum Root to Leaf Numbers
- drawArc()参数问题
- 单页应用(SPA)前端javascript如何阻止按下返回键页面回退
- tensorflow 训练mnist数据
- c/c++代码 No.3 位或
- Docker学习文档之三 其他相关-Docker常用命令
- vue2 与 vue1都支持的组件tree互动api。
- mysql存储过程语法及实例
- 比较全的pom.xml maven依赖
- jq获取元素位置
- xml中Integer判空
- c/c++代码 No.4 位异或