机器学习小试(7)使用TensorFlow跑通一个通用增量学习流程-根据配置文件创建全连接网络
来源:互联网 发布:现在在淘宝做动漫周边 编辑:程序博客网 时间:2024/06/07 05:37
上文中,我们设计了一个配置文件,用来定义一个全连接神经网络模型的规模、学习方法。
本文,将介绍如何通过配置文件动态产生网络结构并首次训练、存盘。
1. 根据配置文件定义神经网络
全连接神经网络的计算是一串矩阵运算,可以看下图:
对一个T层(后续代码中变量名total_layer_size,算入了输出层,不含输入层)的网络,主要有以下变量:
* 每层的节点(神经元)个数S,S[0]表示输入层的节点个数,S[1]是第一隐层,S[T-1]为最后一个隐层中神经元个数,S[T]是输出层的判决向量元素个数。
* 各层的节点(神经元)的取值都是一个向量,为 a[i], 向量的大小(元素数)为S[i],i取0~T
* 除了输入a[0]是现成的,其余各层的a均通过算式
* W[i]是一个矩阵,大小为 S[i] x S[i+1],i取0~T-1
* b[i]是一个向量,为偏置,大小为 S[i+1],i取0~T-1
* z[i]是是一个向量,即各层线性传递关系(权)计算结果:
通过配置文件读取参数,即可通过一个简单的for循环,定义网络结构:
"""Created on Sun Nov 26 15:24:50 2017gn_first_training.py@author: goldenhawking"""from __future__ import print_functionimport tensorflow as tfimport numpy as npimport configparserimport reimport matplotlib.pyplot as mpltrainning_task_file = 'train_task.cfg'trainning_input_file = 'train_input.txt'model_path = './saved_model/'#读取配置config = configparser.ConfigParser()config.read(trainning_task_file)n = int(config['network']['input_nodes']) # input vector sizeK = int(config['network']['output_nodes']) # output vector sizelam = float(config['network']['lambda'])#隐层规模 用逗号分开,类似 ”16,16,13“ hidden_layer_size = config['network']['hidden_layer_size'] #分离字符reobj = re.compile('[\s,\"]')ls_array = reobj.split(hidden_layer_size);ls_array = [item for item in filter(lambda x:x != '', ls_array)] #删空白#隐层个数hidden_layer_elems = len(ls_array);#转为整形,并计入输出层 ns_array = []for idx in range(0,hidden_layer_elems) : ns_array.append(int(ls_array[idx]))#Output is the last layer, append to lastns_array.append(K)#总层数(含有输出层)total_layer_size = len(ns_array)#--------------------------------------------------------------#create graphgraph = tf.Graph()with graph.as_default(): with tf.name_scope('network'): with tf.name_scope('input'): s = [n] a = [tf.placeholder(tf.float32,[None,s[0]],name="in")] W = [] b = [] z = [] punish = tf.constant(0.0) for idx in range(0,total_layer_size) : with tf.name_scope('layer'+str(idx+1)): s.append(int(ns_array[idx])) W.append(tf.Variable(tf.random_uniform([s[idx],s[idx+1]],0,1),name='W'+str(idx+1))) b.append(tf.Variable(tf.random_uniform([1],0,1),name='b'+str(idx+1))) z.append(tf.matmul(a[idx],W[idx]) + b[idx]*tf.ones([1,s[idx+1]],name='z'+str(idx+1))) a.append(tf.nn.tanh(z[idx],name='a'+str(idx+1))) with tf.name_scope('regular'): punish = punish + tf.reduce_sum(W[idx]**2) * lam #-------------------------------------------------------------- with tf.name_scope('loss'): y_ = tf.placeholder(tf.float32,[None,K],name="tr_out") loss = tf.reduce_mean(tf.square(a[total_layer_size]-y_),name="loss") + punish with tf.name_scope('trainning'): optimizer = tf.train.AdamOptimizer(name="opt") train = optimizer.minimize(loss,name="train") init = tf.global_variables_initializer() #save graph to Disk saver = tf.train.Saver()#--------------------------------------------------------------### create tensorflow structure end ###sess = tf.Session(graph=graph)sess.run(init) # Very important#后续紧邻下个代码段
程序中,还包括了正则化参数的引入、代价函数loss与持久化保存器saver。
2. 读取文本文件并训练
训练样本以文本文件提供,每行一个样本。
每行共N+K个元素,前N个为输入(特征),后K个为理论应该得到的输出。
这部分的代码如下:
#紧接着上个代码段file_deal_times = int(config['performance']['file_deal_times'])trunk = int(config['performance']['trunk'])train_step = int(config['performance']['train_step'])iterate_times = int(config['performance']['iterate_times'])#trainningx_data = np.zeros([trunk,n]).astype(np.float32)#read n features and K outputsy_data = np.zeros([trunk,K]).astype(np.float32)plot_x = []plot_y = []for rc in range(file_deal_times): with open(trainning_input_file, 'rt') as ftr: while 1: lines = ftr.readlines() if not lines: #reach end of file, run trainning for tail items if there is some. if (total_red>0): for step in range(iterate_times): sess.run(train,feed_dict={a[0]:x_data[0:min(total_red,trunk)+1],y_:y_data[0:min(total_red,trunk)+1]}) break line_count = len(lines) for lct in range(line_count): x_arr = reobj.split(lines[lct]); x_arr = [item for item in filter(lambda x:x != '', x_arr)] #remove null strings for idx in range(n) : x_data[total_red % trunk,idx] = float(x_arr[idx]) for idx in range(K) : y_data[total_red % trunk,idx] = float(x_arr[idx+n]) total_red = total_red + 1 #the trainning set run trainning if (total_red % train_step == 0): #trainning for step in range(iterate_times): sess.run(train,feed_dict={a[0]:x_data[0:min(total_red,trunk)+1],y_:y_data[0:min(total_red,trunk)+1]}) #print loss lss = sess.run(loss,feed_dict={a[0]:x_data[0:min(total_red,trunk)+1],y_:y_data[0:min(total_red,trunk)+1]}) print(rc,total_red,lss) plot_x.append(total_red) plot_y.append(lss) if (lss<0.0001): break;mpl.plot(plot_x,plot_y)#saving# 保存,这次就可以成功了saver.save(sess,model_path+'/model.ckpt')#文件结束
3 进行训练
运行程序
runfile('./gn_first_training.py', wdir='./')0 1024 0.1228850 2048 0.1036570 3072 0.1017910 4096 0.06332280 5120 0.02412280 6144 0.02069490 7168 0.02219840 8192 0.0193870 9216 0.01595680 10240 0.0162130 11264 0.01600020 12288 0.01305311 13312 0.006161811 14336 0.005293991 15360 0.003866981 16384 0.003411171 17408 0.003801841 18432 0.00388291 19456 0.003132161 20480 0.003075371 21504 0.00313111 22528 0.00320051 23552 0.002812871 24576 0.00197442
训练后,在文件夹下会保存训练结果,供后续识别、增量训练使用。
下一篇,我们略微修改代码,实现从磁盘读取训练结果并继续训练。
- 机器学习小试(7)使用TensorFlow跑通一个通用增量学习流程-根据配置文件创建全连接网络
- 机器学习小试(6)使用TensorFlow跑通一个通用增量学习流程-设计配置文件
- 机器学习小试(8)使用TensorFlow跑通一个通用增量学习流程-增量学习
- 机器学习小试(9)使用TensorFlow跑通一个通用增量学习流程-测试与应用
- 基于Tensorflow的机器学习(5) -- 全连接神经网络
- 【机器学习】动手写一个全连接神经网络(一)
- TensorFlow学习笔记(4)----完整的工程示例:全连接前馈网络识别MNIST
- 【机器学习】Tensorflow基本使用
- 【机器学习】Tensorflow基本使用
- Tensorflow学习(4)池化层和全连接层
- tensorflow学习:池化层(pooling)和全连接层(dense)
- 【机器学习】动手写一个全连接神经网络(二):线性回归
- 【机器学习】动手写一个全连接神经网络(三):分类
- 机器学习小试(1)TensorFlow的第一个程序
- TensorFlow 全连接网络实现
- Tensorflow v1.0.1中机器学习随机森林算法的一个小的改变
- 一步步学习SPD2010--第七章节--使用BCS业务连接服务(8)--创建配置文件页面
- 通用机器学习流程与问题解决架构模板
- 七周第五次课 2017.12.1 iptables规则备份和恢复、firewalld的9个zone、firewalld关于zone的操作、firewalld关于service的操作
- 为什么现在机器学习如此火爆
- Hibernate的学习之路九(主键的生成策略)
- SQLiteOpenHelper 文件路径
- Ball
- 机器学习小试(7)使用TensorFlow跑通一个通用增量学习流程-根据配置文件创建全连接网络
- sql server与eclipse及JAVA中sql语言的写法
- 自定义view进度条案例
- 【备忘】2017年最新springboot开发校园商铺平台视频教程
- 火狐访问所有HTTPS网站显示连接不安全
- pip安装mysql-python报mysql_config: not found错
- OpenStack-M版(Mitaka)搭建基于(Centos7.2)+++十、Openstack对象存储服务(swift)下
- 蓝桥杯练习题之十六进制转十进制
- 在Android中使用MD5