TensorFlow 0.12 Estimators Models Layers学习笔记

来源:互联网 发布:淘宝宝贝排名查询软件 编辑:程序博客网 时间:2024/05/20 21:43

TensorFlow在tensorflow.contrib包中有很多封装好的工具,最近在学习中用到了一些模块,在这里做一些笔记。

Estimators

位于tensorflow.contrib.learn
引用 深度解析TensorFlow组件Estimator:构建自定义Estimator 里面的观点,其作用是:

  • 是TensorFlow训练和评估模块的抽象和基类。它利用graph_actions.py的隐藏逻辑,提供像fit(),partial_fit(),evaluate()和predict()的基本功能;
  • 为monitors,checkpointing等初始化设置,并提供了构建和评估自定义模块的大部分逻辑;

基本上可以认为类似于我们一般应用开发中的Application类,定义了App的一些基本规范和逻辑,创建一个TensorFlow的训练程序就是从Estimator开始(当然,并不是一定要用这个才行,它只是相当于某种框架,定义了开发的一种标准方式)

定义Estimator

init(model_fn=None, model_dir=None, config=None, params=None, feature_engineering_fn=None)

  • model_fn: 模型定义,定义了train, eval, predict的实现
  • model_dir: log文件和训练参数的保存目录
  • config: Configuration object
  • params: dict of hyper parameters that will be passed into model_fn. Keys are names of parameters, values are basic python types.
  • feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of input_fn and returns features and labels which will be fed into model_fn. Please check model_fn for a definition of features and labels.

一般情况下model_fn, model_dir两个参数是必须的.

训练模型

fit(self, x, y, input_fn=None, batch_size=128, steps=None, max_steps=None, monitors=None)
从文档看,google推荐用input_fn的方式来提供训练数据, 当然老的传x, y的方式也能用,会提供一个内置的input_fn来包装一下。

评估模型

evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
steps=None, metrics=None, name=None)

老实说,不太明白这个怎么用。

预测

predict(x=None, input_fn=None, batch_size=None, outputs=None, as_iterable=True)

这个方法的输出格式取决于model_fn返回的prediction部分的定义,而且要通过迭代方式来访问。

还要其他方法,暂时用不到。

Estimator部分最重要的就是model_fn的定义,这个方法的申明是:

model_fn(features, labels, mode, params)
- features: 样本数据的x
- labels: 样本数据的y
- mode: 模式 有3种TRAIN/EVAL/INFER,根据这个参数,model_fn可以做特定的处理
- params: mode_fn需要的其他参数,dict数据结构

Estimator里面还有一个叫SKCompat的类,如果使用x,y而不是input_fn来传参数的形式,需要用这个类包装一下:

est = SKCompat(Estimator(…))

TensorFlow还封装了一些针对特定训练算法的Estimator,也就是说自己年model_fn都不用写了,具体参见Estimators目录下的文件。

(待续)

0 0
原创粉丝点击