用图片数据集训练神经网络 tensorflow
来源:互联网 发布:php cgi.exe 编辑:程序博客网 时间:2024/05/17 23:50
为了学tensorflow,网上教程看了不少,大部分是利用mnist数据集,但是大部分都是利用已经处理好的非图片形式进行训练的;而且很多人都说不要自己造轮子,直接跑别人的代码,然后修改和学习,不过我觉得其实从数据集,到训练,到测试写下来,也有不少收获的。
本篇博客主要讲的是如何将自己的图片数据集进行处理,然后搭建神经网络结构,训练数据,保存和加载模型,测试等过程,其中利用mnist图片数据集(mnist),代码(Github)。下面我分步讲解一下,跟大家一起学习。参考博客附后。
如前文所述,现在的入门教程基本是跑mnist代码,但是数据集是处理后的,那么如何处理自己的图像数据集?首先要将图片数据集制作成tfrecords格式的数据。简单的说,tfrecords是一种tensorflow方便快速读取数据的一种二进制文件,适用于大量数据的处理。下面的代码将说明如何将图像制作成tfrecords文件,该mnist图片数据集大小为(28,28,3)。
# current work dircwd = os.getcwd()# data to int64Listdef _int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))# data to floatlistdef _float_feature(value): return tf.train.Feature(float_list = tf.train.FloatList(value=[value]))# data to byteslistdef _bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))# convert image data to tfrecordsdef generate_tfrecords(data_dir,filepath): # gen a tfrecords object write write = tf.python_io.TFRecordWriter(filepath) #print cwd for index,name in enumerate(num_classes):#print class_pathfile_dir = data_dir + name + '/'for img_name in os.listdir(file_dir): img_path = file_dir + img_name print img_path img = cv2.imread(img_path) img_raw = img.tobytes() example = tf.train.Example(features=tf.train.Features(feature={'label':_int64_feature(index),'img_raw':_bytes_feature(img_raw)})) # convert example to binary string write.write(example.SerializeToString()) write.close()
tf.python_io.TFRecordWriter返回一个writer对象用于将data_dir制作后的数据存入filepath中保存,该文件就是tfrecords文件。另外,tf.train.Example将数据处理成key-value(在这里就是标签-图像)的格式返回一个example对象。最后writer将数据写到filepath中,关闭writer,就完成了图片数据到二进制文件的制作过程。
制作完成之后,在神经网络中如何读取和解析呢?如下代码
# read and decode tfrecorddef read_and_decode_tfrecord(filename): # produce file deque filename_deque = tf.train.string_input_producer([filename]) # generate reader object reader = tf.TFRecordReader() # read data from filename_deque _, serialized_example = reader.read(filename_deque) # decode into fixed form features = tf.parse_single_example(serialized_example,features={'label':tf.FixedLenFeature([],tf.int64),'img_raw':tf.FixedLenFeature([],tf.string)}) label = tf.cast(features['label'],tf.int32) img = tf.decode_raw(features['img_raw'],tf.uint8) img = tf.reshape(img,[28,28,3]) img = tf.cast(img,tf.float32)/255.-0.5 return label,img
其中,tf.train.string_input_producer([filename])是将filename的文件内容制作成一个队列,然后tf.parse_single_example按照固定的格式将内容解析出来,稍加处理即可得到label和img,当然[filename]中可以有很多file,因为当图片数据太大时可能会将数据分成好几个部分分别制作tfrecords进行存储和读取。
然后搭建神经网络,这里就搭建一个简单点的,如下
# create network class network(object): # define parameters w and b def __init__(self):with tf.variable_scope("Weight"): self.weights={'conv1':tf.get_variable('conv1',[5,5,3,32],initializer=tf.contrib.layers.xavier_initializer_conv2d()),'conv2':tf.get_variable('conv2',[5,5,32,64],initializer=tf.contrib.layers.xavier_initializer_conv2d()),'fc1' :tf.get_variable('fc1', [7*7*64,1024],initializer=tf.contrib.layers.xavier_initializer()),'fc2' :tf.get_variable('fc2', [1024,10], initializer=tf.contrib.layers.xavier_initializer()),}with tf.variable_scope("biases"): self.biases={'conv1':tf.get_variable('conv1',[32,],initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),'conv2':tf.get_variable('conv2',[64,],initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),'fc1' :tf.get_variable('fc1', [1024,],initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),'fc2' :tf.get_variable('fc2', [10,] ,initializer=tf.constant_initializer(value=0.0,dtype=tf.float32)),} # define model def model(self,img):conv1 = tf.nn.bias_add(tf.nn.conv2d(img,self.weights['conv1'],strides=[1,1,1,1],padding='SAME'),self.biases['conv1'])relu1 = tf.nn.relu(conv1)pool1 = tf.nn.max_pool(relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')conv2 = tf.nn.bias_add(tf.nn.conv2d(pool1,self.weights['conv2'],strides=[1,1,1,1],padding='SAME'),self.biases['conv2'])relu2 = tf.nn.relu(conv2)pool2 = tf.nn.max_pool(relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')flatten = tf.reshape(pool2,[-1,self.weights['fc1'].get_shape().as_list()[0]])drop1 = tf.nn.dropout(flatten,0.8)fc1 = tf.matmul(drop1,self.weights['fc1']) + self.biases['fc1']fc_relu1 = tf.nn.relu(fc1)fc2 = tf.matmul(fc_relu1,self.weights['fc2'])+self.biases['fc2']return fc2 # define model test def test(self,img):img = tf.reshape(img,shape=[-1,28,28,3])conv1 = tf.nn.bias_add(tf.nn.conv2d(img,self.weights['conv1'],strides=[1,1,1,1],padding='SAME'),self.biases['conv1'])relu1 = tf.nn.relu(conv1)pool1 = tf.nn.max_pool(relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')conv2 = tf.nn.bias_add(tf.nn.conv2d(pool1,self.weights['conv2'],strides=[1,1,1,1],padding='SAME'),self.biases['conv2'])relu2 = tf.nn.relu(conv2)pool2 = tf.nn.max_pool(relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')flatten = tf.reshape(pool2,[-1,self.weights['fc1'].get_shape().as_list()[0]])drop1 = tf.nn.dropout(flatten,1)fc1 = tf.matmul(drop1,self.weights['fc1']) + self.biases['fc1']fc_relu1 = tf.nn.relu(fc1)fc2 = tf.matmul(fc_relu1,self.weights['fc2'])+self.biases['fc2']return fc2 #loss def softmax_loss(self,predicts,labels):predicts = tf.nn.softmax(predicts)labels = tf.one_hot(labels,len(num_classes))loss = -tf.reduce_mean(labels*tf.log(predicts))self.cost = lossreturn self.cost # optimizer def optimizer(self,loss,lr=0.001):train_optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)return train_optimizer
tf.contrib.layers.xavier_initializer_conv2d()是对参数进行初始化的,据某位童鞋的博客说,当激活函数是sigmoid或tanh时,这个初始化方法比较好,但是当激活函数是relu时,使用tf.contrib.layers.variance_scaling_initializer比较好,具体我没有尝试过,大家可以试一试,另外tf.contrib.layers.xavier_initializer()也是一种权值初始化方式。而在神经网络中,权值的初始化非常重要,可以按照某种特定的分布来初始化,以后可以尝试使用其他初始化方式从而加快收敛速度和准确率。
dropout层是一种解决过拟合的方法,它是在训练的过程中对网络中的某些神经单元按照一定概率屏蔽,此时的概率选择为0.8,但是在测试时可以不用dropout层,或者将概率设置为1,即使用所有的神经单元,这样应该能提高正确率。
模型搭建完毕就可以开始训练了,代码如下
def train(): label, img = read_and_decode_tfrecord(train_tfrecords_dir) img_batch,label_batch = tf.train.shuffle_batch([img,label],num_threads=16,batch_size=batch_size,capacity=50000,min_after_dequeue=49000) net = network() predicts = net.model(img_batch) loss = net.softmax_loss(predicts,label_batch) opti = net.optimizer(loss) # add trace tf.summary.scalar('cost fuction',loss) merged_summary_op = tf.summary.merge_all() train_correct = tf.equal(tf.cast(tf.argmax(predicts,1),tf.int32),label_batch) train_accuracy = tf.reduce_mean(tf.cast(train_correct,tf.float32)) #evaluate test_label,test_img = read_and_decode_tfrecord(test_tfrecords_dir) test_img_batch,test_label_batch = tf.train.shuffle_batch([test_img,test_label],num_threads=16,batch_size=batch_size,capacity=50000,min_after_dequeue=40000) test_out = net.test(test_img_batch) test_correct = tf.equal(tf.cast(tf.argmax(test_out,1),tf.int32),test_label_batch) test_accuracy = tf.reduce_mean(tf.cast(test_correct,tf.float32)) # init varibels init = tf.global_variables_initializer() with tf.Session() as sess:sess.run(init)# manerge different threadscoord = tf.train.Coordinator() summary_writer = tf.summary.FileWriter('log',sess.graph) # run deque threads = tf.train.start_queue_runners(sess=sess,coord=coord)path =cwd+'/'+'model/model.ckpt'#if os.path.exists(path):try: print "try to reload model ......" tf.train.Saver(max_to_keep=None).restore(sess,path) print 'reload successful ......'except: print 'reload model failed ......'finally: print 'training .......' for i in range(1,epoch+1): #val,l= sess.run([img_batch,label_batch]) if i%50 ==0: loss_np,_,label_np,img_np,predict_np = sess.run([loss,opti,label_batch,img_batch,predicts])tr_accuracy_np = sess.run([train_accuracy])print i,' epoch loss :',loss_np,' train accuracy: ', tr_accuracy_np if i%200==0:summary_str,_l,_o = sess.run([merged_summary_op,loss,opti])summary_writer.add_summary(summary_str,i)te_accuracy = sess.run([test_accuracy])print 'test accuracy: ', te_accuracy if i%1000==0: tf.train.Saver(max_to_keep=None).save(sess,os.path.join('model','model.ckpt'))# somethind happend that the thread should stopcoord.request_stop()# wait for all threads should stop and then stopcoord.join(threads)
tf.train.shuffle_batch是将队列里的数据打乱顺序使用n_threads个线程,batch_size大小的形式读取出来,capacity是整个队列的容量,min_after_deque代表参与顺序打乱的程度,参数越大代表数据越混乱。在本代码中,由于各个类别已经分好,大概都是5000张,而在制作tfrecords的时候是按顺序存储的,所以使用tf.train.shuffle_batch来打乱顺序,但是如果batch_size设置太小,那很大概率上每个batch_size的图像数据的类别都是一样的,造成过拟合,所以本次将batch_size设置成2000,这样效果比较明显,设置成1000也可以,或者在处理数据的时候提前将数据打乱,或者有其他方法欢迎下方讨论。
tf.summary.scalar('cost fuction',loss) merged_summary_op = tf.summary.merge_all()summary_writer = tf.summary.FileWriter('log',sess.graph)summary_str,_l,_o = sess.run([merged_summary_op,loss,opti])summary_writer.add_summary(summary_str,i)
tensorboard --logdir=log
python 目录/tensorboard.py --logdir=log
关于evaluate部分,在训练过程中可以使用一部分数据集来验证模型的准确率,本程序将验证集合测试集视为相同。
coord = tf.train.Coordinator()#创建一个协调器,用于管理线程,发生错误时及时关闭线程
threads = tf.train.start_queue_runners(sess=sess,coord=coord)#各个线程开始读取数据,这一句如果没有,整个网络将被挂起
coord.request_stop()#某个线程数据读取完或发生错误请求停止
coord.join(threads)#所有线程都请求停止后关闭线程
以上几行代码的搭配是线程的开启和关闭过程,后面两句如果不存在,当读取过程出现某些错误(如Outofrange)时,程序将不会正常关闭等,详细情况大家可以查阅一下其他资料。
tf.train.Saver(max_to_keep=None).save(sess,os.path.join('model','model.ckpt'))
tf.train.Saver(max_to_keep=None).restore(sess,path)
以上两句是模型的保存和恢复,max_to_keep=None这个参数是保存最新的或者加载最新的模型。
accuracy_np = sess.run([accuracy])
以上是关于输出,如果希望得到某个输出A,那么只要使用A_out = sess.run([A])即可
参考博客:1. http://blog.csdn.net/hjimce/article/details/51899683
- 用图片数据集训练神经网络 tensorflow
- Tensorflow学习笔记:用minst数据集训练卷积神经网络并用训练后的模型测试自己的BMP图片
- tensorflow-mnist数据集训练
- 基于Tensorflow, OpenCV. 使用MNIST数据集训练卷积神经网络模型,用于手写数字识别
- TensorFlow下用自己的数据集训练Faster RCNN
- 使用IRIS数据集训练第一个深度神经网络
- MNIST数据集训练
- cifar数据集训练
- 机器学习笔记6:TensorFlow入门之MNIST数据集训练
- Caffe mnist数据集训练
- Faster-RCNN+VGG用自己的数据集训练模型
- 用ImageNet的数据集训练Faster R-CNN
- 用ImageNet的数据集训练Faster R-CNN
- Faster-RCNN+VGG用自己的数据集训练模型
- YOLO(v1)用自己的数据集训练模型
- 01-Keras之用MNIST数据集训练一个DNN
- 03-Keras之用MNIST数据集训练一个CNN
- ChainerCV下用自己的数据集训练Faster RCNN
- USART RX 不上拉的后果
- Linux环境PHP7.0安装
- Kafka消费组(consumer group)
- windows 环境安装 minicpan (perl 的本地库,适合无互联网环境安装新的perl模块)
- 关于cookie与storage的一些理解
- 用图片数据集训练神经网络 tensorflow
- vue组件之间的通信
- Could not publish server configuration for Tomcat v7.0 Server at localhost. Multiple Contexts have a
- ccf 交通规划
- 机器牛耕地
- Spring集成RabbitMQ
- cocoscreator 去掉默认loading 默认logo progress
- 校园广播流程(硬件接口、采集卡设置、)
- 快速理解JAVA虚拟机一些要点