TensorFlow 训练 MNIST 数据

来源:互联网 发布:淘宝开店要保证金吗 编辑:程序博客网 时间:2024/05/16 08:48

TensorFlow 训练 MNIST 数据

import tensorflow as tfimport numpy as npimport osimport structimport gzipfrom tensorflow.examples.tutorials.mnist import input_dataimport test8def readata(label,image):    with gzip.open(label) as flbl:        magic,num=struct.unpack('>II',flbl.read(8))        lab=np.fromstring(flbl.read(),dtype=np.int8)        label=np.zeros((lab.shape[0],10))        for i in range(len(lab)):            label[i,lab[i]]=1.0    with gzip.open(image,'rb') as fimg:        magic,num,rows,cols=struct.unpack(">IIII",fimg.read(16))        img=np.fromstring(fimg.read(),dtype=np.uint8).reshape(len(label),rows,cols)        image=img.reshape(img.shape[0],img.shape[1],img.shape[2],1)    return label, imagewith tf.name_scope('input'):    image=tf.placeholder(tf.float32, [None, 28, 28, 1])    label=tf.placeholder(tf.float32,[None,10])with tf.name_scope('conv1'):    W_con1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))    b_con1=tf.Variable(tf.constant(0.1,shape=[32]))    c_con1=tf.nn.conv2d(image,W_con1,strides=[1,1,1,1],padding='SAME')+b_con1    h_con1=tf.nn.relu(c_con1)    m_pool2=tf.nn.max_pool(h_con1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')with tf.name_scope('conv2'):    W_con2=tf.Variable(tf.truncated_normal([5,5,32,64], stddev=0.1))    b_con2=tf.Variable(tf.constant(0.1,shape=[64]))    c_con2=tf.nn.conv2d(m_pool2,W_con2,strides=[1,1,1,1],padding='SAME')+b_con2    h_con2=tf.nn.relu(c_con2)    m_pool2=tf.nn.max_pool(h_con2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')with tf.name_scope('fc1'):    W_fc1=tf.Variable(tf.truncated_normal([7*7*64,1024], stddev=0.1))    b_fc1=tf.Variable(tf.constant(0.1,shape=[1024]))    m_pool2_flat=tf.reshape(m_pool2,[-1,7*7*64])    h_fc1=tf.nn.relu(tf.matmul(m_pool2_flat,W_fc1)+b_fc1)with tf.name_scope('drop'):    keep_prob=tf.placeholder(tf.float32)    h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob=keep_prob)with tf.name_scope('fc2'):    W_fc2=tf.Variable(tf.truncated_normal(shape=[1024,10],stddev=0.1))    
b_fc2=tf.Variable(tf.constant(0.1,shape=[10]))    y_con=tf.matmul(h_fc1_drop,W_fc2)+b_fc2with tf.name_scope('cross_entry'):    cross_entry=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label,logits=y_con))train=tf.train.GradientDescentOptimizer(0.001).minimize(cross_entry)# correct_prediction = tf.equal(tf.argmax(y_con, 1), tf.argmax(label, 1))# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))train_label,train_image=readata('train-labels-idx1-ubyte.gz','train-images-idx3-ubyte.gz')test_label,test_image=readata('t10k-labels-idx1-ubyte.gz','t10k-images-idx3-ubyte.gz')train_image1=test8.Dataset(train_image,train_label)saver=tf.train.Saver()//从头开始训练# with tf.Session() as sess:#     sess.run(tf.global_variables_initializer())#     train_writer = tf.summary.FileWriter('./' + '/train', sess.graph)#     for i in range(20000):#         batchimage,batchlabel=train_image1.next_batch(60)#         _,loss=sess.run([train,cross_entry],feed_dict={image:batchimage,label:batchlabel,keep_prob:1.0})#         if i % 100 == 0:#             # tm1=test_image[(i % 1000) * 10:((i + 1) % 1000) * 10, :, :, :]#             # tn1=test_label[(i % 1000) * 10:((i + 1) % 1000) * 10, :]#             # train_accuracy = accuracy.eval(feed_dict={#             #     image:tm1, label:tn1, keep_prob: 1.0})#             # print 'sss',train_accuracy#             print 'loss',loss#     train_writer.close()#     save_path=saver.save(sess,'/home/dms/model.ckpt')#     print save_path//加载模型训练with tf.Session() as sess:    saver.restore(sess,'/home/dms/model.ckpt')    # train_writer = tf.summary.FileWriter('./' + '/train', sess.graph)    for i in range(10000):        batchimage,batchlabel=train_image1.next_batch(60)        _,loss=sess.run([train,cross_entry],feed_dict={image:batchimage,label:batchlabel,keep_prob:1.0})        if i % 100 == 0:            # tm1=test_image[(i % 1000) * 10:((i + 1) % 1000) * 10, :, :, :]            # tn1=test_label[(i % 1000) * 10:((i + 1) % 1000) * 
10, :]            # train_accuracy = accuracy.eval(feed_dict={            #     image:tm1, label:tn1, keep_prob: 1.0})            # print 'sss',train_accuracy            print 'loss',loss    # train_writer.close()    # save_path=saver.save(sess,'/home/dms/model.ckpt')    # print save_path

迭代器(保存为 test8.py;上面的训练脚本通过 `import test8` 使用其中的 `Dataset` 类)

import numpy as np


class Dataset(object):
    """Sequential mini-batch iterator over paired image/label arrays.

    Batches are taken in order without shuffling. When a batch would run past
    the end of the data, it is completed with examples from the beginning.
    Every call therefore returns exactly ``batch_size`` examples.
    """

    def __init__(self, images, labels):
        """Store the data and start reading at the first example.

        Args:
            images: examples, indexed along axis 0.
            labels: labels aligned with ``images`` along axis 0.
        """
        self._images = images
        self._labels = labels
        self._num_examples = len(images)
        # Index of the next example to hand out.
        self._index_in_epoch = 0

    def next_batch(self, batch_size):
        """Return the next ``(images, labels)`` batch of ``batch_size`` examples.

        Args:
            batch_size: number of examples to return; must not exceed the
                number of examples in the dataset.

        Returns:
            A tuple ``(images, labels)`` each with ``batch_size`` rows.

        Raises:
            ValueError: if ``batch_size`` is larger than the dataset.
        """
        if batch_size > self._num_examples:
            raise ValueError('batch_size %d exceeds dataset size %d'
                             % (batch_size, self._num_examples))
        start = self._index_in_epoch
        if start + batch_size > self._num_examples:
            # Take everything left in the current pass over the data...
            rest_num_examples = self._num_examples - start
            image_rest_part = self._images[start:self._num_examples]
            label_rest_part = self._labels[start:self._num_examples]
            # ...and top the batch up from the beginning.
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            image_new_part = self._images[0:end]
            label_new_part = self._labels[0:end]
            return (np.concatenate((image_rest_part, image_new_part), axis=0),
                    np.concatenate((label_rest_part, label_new_part), axis=0))
        self._index_in_epoch += batch_size
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
0 0
原创粉丝点击