利用tensorflow训练自己的图片数据(4)——神经网络训练
来源:互联网 发布:数据分析报告模板 编辑:程序博客网 时间:2024/04/29 04:29
一 . 说明
在上一篇博客——利用tensorflow训练自己的图片数据(3)中,我们建立好了本次训练的模型,接下来就是开始网络训练,并保存训练后的网络参数,以便测试时使用。
二 . 编程实现
#======================================================================#导入文件import osimport numpy as npimport tensorflow as tfimport input_dataimport model#变量声明N_CLASSES = 4 #husky,jiwawa,poodle,qiutianIMG_W = 64 # resize图像,太大的话训练时间久IMG_H = 64BATCH_SIZE =20CAPACITY = 200MAX_STEP = 200 # 一般大于10Klearning_rate = 0.0001 # 一般小于0.0001#获取批次batchtrain_dir = 'E:/Re_train/image_data/inputdata' #训练样本的读入路径logs_train_dir = 'E:/Re_train/image_data/inputdata' #logs存储路径#logs_test_dir = 'E:/Re_train/image_data/test' #logs存储路径#train, train_label = input_data.get_files(train_dir)train, train_label, val, val_label = input_data.get_files(train_dir, 0.3)#训练数据及标签train_batch,train_label_batch = input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)#测试数据及标签val_batch, val_label_batch = input_data.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY) #训练操作定义train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)train_loss = model.losses(train_logits, train_label_batch) train_op = model.trainning(train_loss, learning_rate)train_acc = model.evaluation(train_logits, train_label_batch)#测试操作定义test_logits = model.inference(val_batch, BATCH_SIZE, N_CLASSES)test_loss = model.losses(test_logits, val_label_batch) test_acc = model.evaluation(test_logits, val_label_batch)#这个是log汇总记录summary_op = tf.summary.merge_all() #产生一个会话sess = tf.Session() #产生一个writer来写log文件train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph) #val_writer = tf.summary.FileWriter(logs_test_dir, sess.graph) #产生一个saver来存储训练好的模型saver = tf.train.Saver()#所有节点初始化sess.run(tf.global_variables_initializer()) #队列监控coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)#进行batch的训练try: #执行MAX_STEP步的训练,一步一个batch for step in np.arange(MAX_STEP): if coord.should_stop(): break #启动以下操作节点,有个疑问,为什么train_logits在这里没有开启? _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc]) #每隔50步打印一次当前的loss以及acc,同时记录log,写入writer if step % 10 == 0: print('Step %d, train loss = %.2f, train accuracy = %.2f%%' %(step, tra_loss, tra_acc*100.0)) summary_str = sess.run(summary_op) train_writer.add_summary(summary_str, step) #每隔100步,保存一次训练好的模型 if (step + 1) == MAX_STEP: checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt') saver.save(sess, checkpoint_path, global_step=step) except tf.errors.OutOfRangeError: print('Done training -- epoch limit reached')finally: coord.request_stop() #========================================================================
本次训练300次,ratio设为0.3,学习率设为0.001,批处理量为20;每10次大厅一下训练结果,并将最后训练完成时,将训练数据保存到logs_train_dir的命名为model.ckpt的文件中。
阅读全文
0 0
- 利用tensorflow训练自己的图片数据(4)——神经网络训练
- 利用tensorflow训练自己的图片数据(5)——测试训练网络
- 利用tensorflow训练自己的图片数据(2)——输入图片处理
- 利用tensorflow训练自己的图片数据(2)——输入图片处理
- 利用tensorflow训练自己的图片数据(1)——预处理
- 利用tensorflow训练自己的图片数据(3)——建立网络模型
- 利用tensorflow训练自己的图片数据(1)——预处理
- TensorFlow——训练自己的数据(一)数据处理
- TensorFlow——训练自己的数据(三)模型训练
- 利用TensorFlow Object Detection API 训练自己的数据集
- Tensorflow学习笔记:用minst数据集训练卷积神经网络并用训练后的模型测试自己的BMP图片
- TensorFlow——训练自己的数据——CIFAR10(一)数据准备
- MatConvNet卷积神经网络(四)——用自己的数据训练
- caffe利用caffenet训练自己的图片数据
- TensorFlow——训练自己的数据(二)模型设计
- TensorFlow——训练自己的数据(四)模型测试
- TensorFlow——训练自己的数据(五)模型评估
- 使用Tensorflow训练自己的分割数据
- java
- UIApplication的方法
- es6-箭头函数中的this使用
- 习题3-9 子序列(All in All, UVa 10340)
- SpringMvc的文件上传使用的时CommonsMultipartFile
- 利用tensorflow训练自己的图片数据(4)——神经网络训练
- CentOS下安装tomcat
- 检测表单是否填写,使按钮可点击
- postGresql修改主键自增脚本
- Tensorflow中使用matplotlib报错
- mac 运行npm install -g cnpm --registry=https://registry.npm.taobao.org时报错解决办法
- python
- The Triangle (dp)
- T009