TensorFlow 0.12 Estimators Models Layers学习笔记
来源:互联网 发布:淘宝宝贝排名查询软件 编辑:程序博客网 时间:2024/05/20 21:43
TensorFlow在tensorflow.contrib包中有很多封装好的工具,最近在学习中用到了一些模块,在这里做一些笔记。
Estimators
位于tensorflow.contrib.learn
引用 深度解析TensorFlow组件Estimator:构建自定义Estimator 里面的观点,其作用是:
- 是TensorFlow训练和评估模块的抽象和基类。它利用graph_actions.py的隐藏逻辑,提供像fit(),partial_fit(),evaluate()和predict()的基本功能;
- 为monitors,checkpointing等初始化设置,并提供了构建和评估自定义模块的大部分逻辑;
基本上可以认为类似于我们一般应用开发中的Application类,定义了App的一些基本规范和逻辑,创建一个TensorFlow的训练程序就是从Estimator开始(当然,并不是一定要用这个才行,它只是相当于某种框架,定义了开发的一种标准方式)
定义Estimator
init(model_fn=None, model_dir=None, config=None, params=None, feature_engineering_fn=None)
- model_fn: 模型定义,定义了train, eval, predict的实现
- model_dir: log文件和训练参数的保存目录
- config: Configuration object
- params: dict of hyper parameters that will be passed into model_fn. Keys are names of parameters, values are basic python types.
- feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of input_fn and returns features and labels which will be fed into model_fn. Please check model_fn for a definition of features and labels.
一般情况下model_fn, model_dir两个参数是必须的.
训练模型
fit(self, x, y, input_fn=None, batch_size=128, steps=None, max_steps=None, monitors=None)
从文档看,google推荐用input_fn的方式来提供训练数据, 当然老的传x, y的方式也能用,会提供一个内置的input_fn来包装一下。
评估模型
evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
steps=None, metrics=None, name=None)
老实说,不太明白这个怎么用。
预测
predict(x=None, input_fn=None, batch_size=None, outputs=None, as_iterable=True)
这个方法的输出格式取决于model_fn返回的prediction部分的定义,而且要通过迭代方式来访问。
还要其他方法,暂时用不到。
Estimator部分最重要的就是model_fn的定义,这个方法的申明是:
model_fn(features, labels, mode, params)
- features: 样本数据的x
- labels: 样本数据的y
- mode: 模式 有3种TRAIN/EVAL/INFER,根据这个参数,model_fn可以做特定的处理
- params: mode_fn需要的其他参数,dict数据结构
Estimator里面还有一个叫SKCompat的类,如果使用x,y而不是input_fn来传参数的形式,需要用这个类包装一下:
est = SKCompat(Estimator(…))
TensorFlow还封装了一些针对特定训练算法的Estimator,也就是说自己年model_fn都不用写了,具体参见Estimators目录下的文件。
(待续)
- TensorFlow 0.12 Estimators Models Layers学习笔记
- TensorFlow学习笔记12----Creating Estimators in tf.contrib.learn
- tensorflow-estimators
- 深度学习笔记——深度学习框架TensorFlow(十)[Creating Estimators in tf.contrib.learn]
- TensorFlow学习笔记7----Large-scale Linear Models with TensorFlow
- tensorflow 学习笔记: recurrent models of visual attention
- 【TensorFlow】理解 Estimators 和 Datasets
- 学习笔记:Creating Estimators in tf.contrib.learn
- caffe学习笔记tutorial:Layers
- models数据模型学习笔记
- Tensorflow Models
- Structuring Your TensorFlow Models-翻译与学习
- Introduction to TensorFlow Datasets and Estimators
- tensorflow学习——tf.layers.batch_normalization/tf.nn.batch_normalization/tf.contrib.layers.batch_norm
- tensorflow学习:tf.nn.conv2d 和 tf.layers.conv2d
- tensorflow.layers.batch_normalization使用方法
- tensorflow slim layers
- tensorflow编程: Layers (contrib)
- 线性求子序列最大平均值
- centos 安装Visual Studio Code
- 整理wmic使用,不重启变环境变量 .
- 关于java编写画图板的思考
- yii2框架-yii2自身的自动加载(三)
- TensorFlow 0.12 Estimators Models Layers学习笔记
- mysql 5.7 ERROR 1045 (28000): Access denied for user 'root'@'localhost'
- 屏幕适配基本概念
- 随机读写文件内容之RandomAccessFile类相关
- Haar Adaboost 检测自定义目标(视频车辆检测算法代码)
- PAT---B1041. 考试座位号(15)
- 用JSP输出九九乘法表
- 【ife】任务二十一:基础JavaScript练习(四)
- 文章标题