Training on Your Own Image Data with TensorFlow (4): Training the Neural Network

Source: Internet  Editor: 程序博客网  Date: 2024/04/29 04:29

I. Overview

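In the previous post, Training on Your Own Image Data with TensorFlow (3), we built the model used for this training run. The next step is to train the network and save the trained parameters so that they can be loaded later for testing.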

II. Implementation

#======================================================================
# Imports
import os
import numpy as np
import tensorflow as tf
import input_data
import model

# Settings
N_CLASSES = 4           # husky, jiwawa (Chihuahua), poodle, qiutian (Akita)
IMG_W = 64              # images are resized; larger images take longer to train
IMG_H = 64
BATCH_SIZE = 20
CAPACITY = 200
MAX_STEP = 200          # typically more than 10K
learning_rate = 0.0001  # typically 0.0001 or smaller

# Build the batches
train_dir = 'E:/Re_train/image_data/inputdata'       # where the training samples are read from
logs_train_dir = 'E:/Re_train/image_data/inputdata'  # where logs and checkpoints are written
#logs_test_dir = 'E:/Re_train/image_data/test'       # log directory for validation (unused)

#train, train_label = input_data.get_files(train_dir)
train, train_label, val, val_label = input_data.get_files(train_dir, 0.3)
# Training data and labels
train_batch, train_label_batch = input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
# Validation data and labels
val_batch, val_label_batch = input_data.get_batch(val, val_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

# Training ops
train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
train_loss = model.losses(train_logits, train_label_batch)
train_op = model.trainning(train_loss, learning_rate)
train_acc = model.evaluation(train_logits, train_label_batch)

# Validation ops
test_logits = model.inference(val_batch, BATCH_SIZE, N_CLASSES)
test_loss = model.losses(test_logits, val_label_batch)
test_acc = model.evaluation(test_logits, val_label_batch)

# Merge all summaries into a single op for logging
summary_op = tf.summary.merge_all()

# Create a session
sess = tf.Session()
# Create a writer for the log files
train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
#val_writer = tf.summary.FileWriter(logs_test_dir, sess.graph)
# Create a saver to store the trained model
saver = tf.train.Saver()
# Initialize all variables
sess.run(tf.global_variables_initializer())
# Start the input queue threads under a coordinator
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Train batch by batch
try:
    # Run MAX_STEP training steps, one batch per step
    for step in np.arange(MAX_STEP):
        if coord.should_stop():
            break
        # Run one training step. train_logits is not fetched explicitly:
        # train_op, train_loss and train_acc all depend on it, so the
        # session computes it automatically.
        _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])

        # Every 10 steps, print the current loss and accuracy and write a summary to the log
        if step % 10 == 0:
            print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
            summary_str = sess.run(summary_op)
            train_writer.add_summary(summary_str, step)
        # On the final step, save the trained model
        if (step + 1) == MAX_STEP:
            checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # Ask the queue threads to stop, wait for them, then release the session
    coord.request_stop()
    coord.join(threads)
    sess.close()
#========================================================================
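The training loop raises the question of why train_logits is not passed to sess.run. In TensorFlow 1.x graph mode, a session evaluates every operation that the requested fetches depend on, so train_logits is computed as an input to train_loss, train_acc and train_op even though it is not listed. The toy evaluator below only illustrates that idea in plain Python. The GRAPH dict, the run function and the numbers are hypothetical, and this is not TensorFlow's API or implementation.

```python
# Toy dependency graph: each node maps to (function, list of input node names).
# Node names mirror the training script; the values are made up for illustration.
GRAPH = {
    "train_batch": (lambda: 2.0, []),
    "train_label_batch": (lambda: 5.0, []),
    "train_logits": (lambda x: 3.0 * x, ["train_batch"]),
    "train_loss": (lambda logits, y: (logits - y) ** 2,
                   ["train_logits", "train_label_batch"]),
}


def run(fetches, graph=GRAPH):
    """Evaluate the fetched nodes, computing their dependencies on demand."""
    cache = {}      # each node is evaluated at most once per run() call
    evaluated = []  # order in which nodes were actually computed

    def evaluate(name):
        if name not in cache:
            fn, inputs = graph[name]
            args = [evaluate(dep) for dep in inputs]
            cache[name] = fn(*args)
            evaluated.append(name)
        return cache[name]

    return [evaluate(name) for name in fetches], evaluated


# Only train_loss is requested, yet train_logits is computed along the way.
values, order = run(["train_loss"])
print(values)
print(order)
```

Fetching train_logits explicitly would only be needed if its values were to be inspected, for example to look at the raw class scores of a batch.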

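As configured in the code above, this run trains for 200 steps (MAX_STEP) with ratio set to 0.3, a learning rate of 0.0001 and a batch size of 20. The training results are printed every 10 steps, and once training finishes the trained model parameters are saved in logs_train_dir as a checkpoint named model.ckpt, with the step number appended (model.ckpt-199).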
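To make the logging and saving schedule concrete, the plain-Python sketch below lists which steps print a log line and which checkpoint prefix the final save produces. It assumes the values from the code listing (MAX_STEP = 200, logging when step % 10 == 0) and relies on tf.train.Saver.save appending "-<global_step>" to the path it is given. The helper names logged_steps and checkpoint_prefix are hypothetical and exist only for this illustration.

```python
import os

MAX_STEP = 200   # value from the code listing
LOG_EVERY = 10   # the loop logs when step % 10 == 0
logs_train_dir = 'E:/Re_train/image_data/inputdata'


def logged_steps(max_step, every):
    """Steps at which the training loop prints and writes a summary."""
    return [step for step in range(max_step) if step % every == 0]


def checkpoint_prefix(logs_dir, global_step, name='model.ckpt'):
    """Prefix that saver.save(sess, path, global_step=...) writes under."""
    return '{}-{}'.format(os.path.join(logs_dir, name), global_step)


steps = logged_steps(MAX_STEP, LOG_EVERY)
final_step = MAX_STEP - 1  # the only step where (step + 1) == MAX_STEP

print(len(steps), steps[0], steps[-1])
print(checkpoint_prefix(logs_train_dir, final_step))
```

At test time this prefix, or the value returned by tf.train.latest_checkpoint(logs_train_dir), is what would be passed to saver.restore.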



