TensorFlow-5: 用 tf.contrib.learn 来构建输入函数
来源:互联网 发布:nginx负载均衡策略 编辑:程序博客网 时间:2024/06/14 10:06
学习资料:
https://www.tensorflow.org/get_started/input_fn
对应的中文翻译:
http://studyai.site/2017/03/06/%E3%80%90Tensorflow%20r1.0%20%E6%96%87%E6%A1%A3%E7%BF%BB%E8%AF%91%E3%80%91%E9%80%9A%E8%BF%87tf.contrib.learn%E6%9D%A5%E6%9E%84%E5%BB%BA%E8%BE%93%E5%85%A5%E5%87%BD%E6%95%B0/
今天学习用 tf.contrib.learn 来建立 input funciton, 并用 DNN 对 Boston Housing 数据集进行回归预测。
问题:
- 给一组波士顿房屋价格数据,要用神经网络回归模型来预测房屋价格的中位数
- 数据集可以从官网教程下载:
https://www.tensorflow.org/get_started/input_fn - 它包括以下特征:
- 我们需要预测的是MEDV这个标签,以每一千美元为单位
一共有 5 步:
- 导入 CSV 格式的数据集
- 建立神经网络回归模型
- 用训练数据集训练模型
- 评价模型的准确率
- 对新样本数据进行分类
代码:
地址:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/input_fn/boston.py
"""DNNRegressor with custom input_fn for Housing dataset."""from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport itertoolsimport pandas as pdimport tensorflow as tftf.logging.set_verbosity(tf.logging.INFO)COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age", "dis", "tax", "ptratio", "medv"]FEATURES = ["crim", "zn", "indus", "nox", "rm", "age", "dis", "tax", "ptratio"]LABEL = "medv"def input_fn(data_set): feature_cols = {k: tf.constant(data_set[k].values) for k in FEATURES} labels = tf.constant(data_set[LABEL].values) return feature_cols, labelsdef main(unused_argv): # Load datasets training_set = pd.read_csv("boston_train.csv", skipinitialspace=True, skiprows=1, names=COLUMNS) test_set = pd.read_csv("boston_test.csv", skipinitialspace=True, skiprows=1, names=COLUMNS) # Set of 6 examples for which to predict median house values prediction_set = pd.read_csv("boston_predict.csv", skipinitialspace=True, skiprows=1, names=COLUMNS) # Feature cols feature_cols = [tf.contrib.layers.real_valued_column(k) for k in FEATURES] # Build 2 layer fully connected DNN with 10, 10 units respectively. regressor = tf.contrib.learn.DNNRegressor(feature_columns=feature_cols, hidden_units=[10, 10], model_dir="/tmp/boston_model") # Fit regressor.fit(input_fn=lambda: input_fn(training_set), steps=5000) # Score accuracy ev = regressor.evaluate(input_fn=lambda: input_fn(test_set), steps=1) loss_score = ev["loss"] print("Loss: {0:f}".format(loss_score)) # Print out predictions y = regressor.predict(input_fn=lambda: input_fn(prediction_set)) # .predict() returns an iterator; convert to a list and print predictions predictions = list(itertools.islice(y, 6)) print("Predictions: {}".format(str(predictions)))if __name__ == "__main__": tf.app.run()
今天主要的知识点就是输入函数
在上面的代码中我们可以看到,输入数据时用的是 pandas,可以直接读取 CSV 文件
为了识别数据集中哪些是列,哪些是特征,哪些是预测标签,需要把这三者定义出来
在定义神经网络回归模型时,我们建立一个具有两层隐藏层的神经网络,每一层具有 10 个神经元节点,
接下来就是建立输入函数,它的作用就是把输入数据传递给回归模型,它可以接受 pandas 的 Dataframe 结构,并将特征和标签列作为 Tensors 返回
在训练时,只需要把训练数据集传递给输入函数,用 fit 迭代5000步
评价模型时,也是将测试数据集传递给输入函数,再用 evaluate
预测时,同样将预测数据集传递给输入函数
关于 输入函数:
昨天学到读取 CSV 文件的方法适用于不需要对原来的数据有什么操作的时候
但是当需要对数据进行特征工程时,我们就需要有一个输入函数来把数据的预处理给封装起来,再传递给模型
输入函数的基本框架:
def my_input_fn(): # Preprocess your data here... # ...then return 1) a mapping of feature columns to Tensors with # the corresponding feature data, and 2) a Tensor containing labels return feature_cols, labels
输入函数必须返回下面两种值:
feature_cols
:是一个字典,key 就是特征列的名字,value 就是 tensor,包含了相应的数据
labels
:返回包含标签数据的 tensor,即所想要预测的目标
如果特征/标签数据存在pandas数据帧中或numpy数组中,那么需要将其转换为Tensor,然后从 input_fn 中返回。
对于稀疏数据
大多数值为0的数据,应该填充一个 SparseTensor,
下面例子,就是定义了一个具有3行和5列的二维 SparseTensor。在 [0,1] 上的元素的值为 6,[2,4] 上的元素值为 0.5,其他值为 0:
sparse_tensor = tf.SparseTensor(indices=[[0,1], [2,4]], values=[6, 0.5], dense_shape=[3, 5])
[[0, 6, 0, 0, 0] [0, 0, 0, 0, 0] [0, 0, 0, 0, 0.5]]
推荐阅读
历史技术博文链接汇总
也许可以找到你想要的
- TensorFlow-5: 用 tf.contrib.learn 来构建输入函数
- 使用tf.contrib.learn构建输入函数
- 05:Tensorflow高级API的进阶--利用tf.contrib.learn建立输入函数
- tensorflow之tf.contrib.learn Quickstart
- TensorFlow-4: tf.contrib.learn 快速入门
- TensorFlow学习笔记6----tf.contrib.learn Quickstart
- [TensorFlow实战练习]3-高层API-tf.contrib.learn练习
- tensorflow学习笔记(六):TF.contrib.learn大杂烩
- TensorFlow学习笔记12----Creating Estimators in tf.contrib.learn
- tensorflow中tf.contrib.learn.preprocessing.VocabularyProcessor理解
- tensorflow学习笔记十四:TF官方教程学习 tf.contrib.learn Quickstart
- tf.contrib.learn.preprocessing.VocabularyProcessor
- tf.contrib.learn快速入门
- tensorflow学习笔记十五:tensorflow官方文档学习 Logging and Monitoring Basics with tf.contrib.learn
- TensorFlow学习笔记10----Logging and Monitoring Basics with tf.contrib.learn
- TensorFlow学习笔记11----Building Input Functions with tf.contrib.learn
- 深度学习笔记——深度学习框架TensorFlow(四)[高级API tf.contrib.learn]
- 深度学习笔记——深度学习框架TensorFlow(十)[Creating Estimators in tf.contrib.learn]
- JS核心系列:浅谈原型对象和原型链
- 数据样本的选择方法
- 【Shell】删除文档中重复内容
- 功能如此齐全 也许是最精良的免费数据可视化软件
- BPS流程怎样设计业务与流程的结合
- TensorFlow-5: 用 tf.contrib.learn 来构建输入函数
- 为Vue2集成UIkit
- IOS OC声明变量在@interface括号中与使用@property的区别
- 原生和jQuery的ajax用法
- hadoop基石HDFS
- Java 创建二叉树并遍历
- WebStorm 配置SVN、启动浏览器、跨域
- 美团一面
- Java序列化中transient修饰符的作用