tensorflow学习day2简单监督学习模型及用tf.train.Saver实现检查点恢复
来源:互联网 发布:李世宏dota2 知乎 编辑:程序博客网 时间:2024/05/19 23:54
监督模型的训练数据是带有结果标签的,例如针对电信流失预警,我们已经事先知道每个人的所有消费属性特征和个人信息,以及他是否流失,通过建立流失与否和特征之间的模型,我们可以根据既有特征得到他是否流失的预测,再通过最小化预测和真实值之间的差距,不断优化模型参数,最终得到一个可用来预测的模型。
总结一下,监督模型的框架应该分为以下几步:
(1)首先对模型参数初始化,一般采用随机数赋值,对于较简单的模型,可以将参数初始值设为零
(2)读取训练数据,包括样本的特征和标签,一般在模型读入数据前会将数据打乱
(3)在训练数据上执行预测模型,每个样本得到一个预测的标签
(4)计算损失,即预测值和实际标签值的差距
(5)调整模型参数。这一步是实际学习过程,给定损失函数,通过大量训练改善模型参数的值,将损失最小化。常见的优化方法是随机梯度下降。
(6)最后在测试集上对模型的预测能力进行评估
1. 对于有监督学习问题,通用的模型训练和评估代码框架可以遵从如下:
# 有监督学习框架import tensorflow as tfdef inference(x): # 计算模型在x上的输出,返回结果def loss(x, y): # 根据x对应的实际y值和模型给出的y值计算损失def inputs(): # 读取训练数据x和ydef train(total_loss): # 依据计算的总损失训练或调整模型参数def evaluate(sess, x, y): # 对训练得到的模型进行评估with tf.Session() as sess: tf.global_variables_initializer().run() x,y = inputs() total_loss = loss(x,y) train_op = train(total_loss) coord = tf.train.Coordinator() # 可以在发生错误的情况下正确地关闭这些线程 threads = tf.train.start_queue_runners(sess=sess,coord=coord) # 函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本 training_steps = 1000 for step in range(training_steps): sess.run([train_op]) if step % 10 == 0: print("loss:",sess.run([total_loss])) evaluate(sess,x,y) coord.request_stop() coord.join(threads)
关于读数据,更多参考:http://blog.sina.com.cn/s/blog_e22771170102wcfv.html
2. 因为模型训练是通过多个训练周期对变量进行迭代的,所以在训练中创建checkpoint周期性保存变量,有利于我们在训练异常中断后,从最近的检查点恢复并继续之前的训练。这里简要介绍下tf.train.Saver()的用法
# checkpoint# 创建一个Saver对象saver = tf.train.Saver()# 启动Session,在训练过程中阶段性创建checkpoint,保存变量值with tf.Session() as sess: # ...... for step in range(training_steps): sess.run([train_op]) if step % 1000 == 0: saver.save(sess, 'model', global_step=step) # ... saver.save(sess, 'model', global_step=training_steps)
tf.train.get_checkpoint_state可以用来检查是否有保存的checkpoint, tf.train.Saver.restore 负责恢复变量的值
# 启动Session,在训练过程中阶段性创建checkpoint,保存变量值with tf.Session() as sess: # ...... start_step = 0 # 检查是否有checkpoint checkpoint = tf.train.get_checkpoint_state(os.path.dirname(__file__)) if checkpoint and checkpoint.model_checkpoint_path: saver.restore(sess, checkpoint.model_checkpoint_path) start_step = int(checkpoint.model_checkpoint_path.rsplit('-',1))[1] for step in range(training_steps): sess.run([train_op]) if step % 1000 == 0: saver.save(sess, 'model', global_step=step) # ... saver.save(sess, 'model', global_step=training_steps)
3. 线性回归
# 线性回归w = tf.Variable(tf.zeros([2,1]), name='weights')b = tf.Variable(0., name='bias')def inference(x): return tf.matmul(x,w)+bdef loss(x, y): y_hat = inference(x) return tf.reduce_sum(tf.squared_difference(y,y_hat))def inputs(): x = tf.random_normal([50,2], mean=0.0, stddev=1.0) w = tf.constant([[0.3], [7]]) y = tf.matmul(x, w) + 2 return x,ydef train(total_loss): learning_rate = 0.0001 return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)with tf.Session() as sess: tf.global_variables_initializer().run() x,y = inputs() total_loss = loss(x, y) train_op = train(total_loss) training_steps = 1000 for step in range(training_steps): sess.run([train_op]) if step % 10 == 0: print('loss:', sess.run(total_loss)) print('wb:', sess.run([w, b]))
logistic模型是一个二元分类模型,经常用来回答yes or no问题,例如流失与否,欺诈与否。其核心是对概率建模,将事件的对数发生比用一个多元线性函数拟合。
logit(x)=log(p/(1-p))= a+bx1+cx2 P是yes的概率
p=f(x)=1/(1+exp(-x))
f(x)被称为sigmoid函数
对于二元分类模型,可以用交叉熵作为损失函数。
loss=sum(yi*log(yi_hat)+(1-yi)*log(1-yi_hat)) 如此当实际y=1,而预测y=0(vise versa)时,loss会无穷大
# logistic 回归模型w = tf.Variable(tf.zeros([4,1]), name="weights")b = tf.Variable(0., name='bias')def inference(x): y = tf.matmul(x, w)+b return tf.sigmoid(y)def loss(x,y): return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(inference(x), y))
注意sigmoid函数输出的是概率。
- tensorflow学习day2简单监督学习模型及用tf.train.Saver实现检查点恢复
- tensorflow学习——tf.train.Supervisor()与tf.train.saver()
- tensorflow关于tf.train.Saver()
- Tensorflow的模型保存和读取tf.train.Saver
- TensorFlow入门(九)使用 tf.train.Saver()保存模型
- TensorFlow入门(九)使用 tf.train.Saver()保存模型
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
- 【TensorFlow】模型持久化tf.train.Saver—上(八)
- 【TensorFlow】模型持久化tf.train.Saver—下(九)
- tf.train.Saver
- tf.train.Saver
- class tf.train.Saver
- Tensorflow学习(6)模型的保存与恢复(saver)
- tensorflow 1.0之tf.train.Saver 文档翻译
- TensorFlow学习--Saver
- tensorflow学习——tf.floor与tf.train.batch
- tensorflow 模型的保存与恢复(Saver)
- tensorflow学习笔记(三十四):Saver(保存与加载模型)
- mate 标签中属性 以及内核选择
- 网络协议与端口
- 国内开源镜像站点汇总2017年10月版
- bzoj 4429: [Nwerc2015] Elementary Math小学数学 网络流
- go里面select-case和time.Ticker的使用注意事项
- tensorflow学习day2简单监督学习模型及用tf.train.Saver实现检查点恢复
- 云计算建立标准才能给创业者公平机会
- linux系统下安装chrome
- 经验:《王者荣耀》技术总监分享背后技术
- VS2013搭建OpenGL环境
- axios 中文文档 翻译
- MergeSort分治实现代码(java)
- java下对excel文件的上传
- Python学习[01]