TF Learn入门 —— 简介
来源:互联网 发布:linux安卓环境搭建 编辑:程序博客网 时间:2024/06/14 04:12
为什么选择 TensorFlow?
TensorFlow 为不同机器学习应用提供了良好的框架。
它将继续沿着分布式和基本管道式机器两个方向发展。
为什么选择 TensorFlow Learn?
更好的从 scikit-learn 单一机器学习过渡到更广的构建不同形态机器学习模型。你可以从使用 fit/predict 开始过渡到使用 TensorFlow API。
提供一组能与现存代码更好融合的参考模型。
安装
安装 TensorFlow 然后 import learn 从 from tensorflow.contrib.learn 或者使用 tf.contrib.learn
也可选择安装 scikit-learn 或者 pandas 以获取更多功能。
使用
把数据传递进 Estimator 前应将数据调校到均值为0或者单位标准差,Stochastic Gradient Descent 当变量尺度不同时无法保证正确下降回归。
类变量在传递进 Estimator 前应做处理。
线性分类
import tensorflow.contrib.learn.python.learn as learnfrom sklearn import datasets, metricsiris = datasets.load_iris()feature_columns = learn.infer_real_valued_columns_from_input(iris.data)classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns)classifier.fit(iris.data, iris.target, steps=200, batch_size=32)iris_prediction = list(classifier.predict(iris.data, as_iterable=True)score = metircs.accuracy_score(iris.target, iris_prediction)print('Accuracy: %f' % score)
线性回归
import tensorflow.contrib.learn.python.learn as learnfrom sklearn import datasets, metrics, preprocessingboston = datasets.load_boston()x = preprocessing.StandardScaler().fit_transform(boston.data)feature_columns = learn.infer_real_valued_columns_from_input(x)regressor = learn.LinearRegressor(feature_columns=feature_columns)regressor.fit(x, boston.target, steps=200, batch_size=32)boston_predictions = list(regressor.predict(x, as_iterable=True)score = metircs.mean_squared_error(boston_predictions, boston.target)print('MSE: %f' % score)
深度神经网络
import tensorflow.contrib.learn.python.learn as learnfrom sklearn import datasets, metricsiris = datasets.load_iris()feature_columns = learn.infer_real_valued_columns_from_input(iris.data)classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, feature_columns=feature_columns)classifier.fit(iris.data, iris.target, steps=200, batch_size=32)iris_predictions = list(classifier.predict(iris.data, as_iterable=True)score = metrics.accuracy_score(iris.target, iris_predictions)print('Accuracy: %f' % score)
定制模型
from sklearn import datasetsfrom sklearn import metricsimport tensorflow as tfimport tensorflow.contrib.layers.python.layers as layersimport tensorflow.contrib.learn.python.learn as learniris = datasets.load_iris()def my_model(features, label):labels = tf.one_hot(labels, 3, 1, 0) # on-value of 1 for each one-hot vector of length 3features = layers.stack(features, layers.fully_connected, [10, 20, 10])prediction, loss = (tf.contrib.learn.model.logistic_regression(features, labels))train_op = tf.contrib.layers.optimize_loss(loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',learning_rate=0.1)return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_opclassifier = learn.Estimator(model_fn=my_model)classifier.fit(iris.data, iris.traget, steps=1000)y_predicted = [p['class'] for p in classifier.predict(iris.data, as_iterable=True)]score = metrics.accuracy_score(iris.target, y_predicted)print('Accuracy: {0:f}'.format(score))
保存和恢复模型
每个评估器都支持 model_dir 参数,接受文件夹路径以保存模型
classifier = learn.DNNClassifier(..., model_dir='/tmp/my_model')
如果对同一个 Estimator 进行多次 fit 操作,当上一个操错结束后,训练为继续进行。
把检查点恢复到一个新的 Estimator ,只要把同样的 model_dir 参数传递给它。
图形化查看概要
如果在 Estimator 里提供了 model_dir 参数, TensorFlow 会将 loss 和变量的分布图概要写进这一目录(你也可以在定制模型操作中通过请求概要操作加入定制概要)
在 TensorBoard 中查看概要,
tensorboard --logdir=/tmp/tf_examples/my_model_1
然后载入显示 URL
- TF Learn入门 —— 简介
- TF Learn入门 —— 简单使用举例
- TF Learn入门 —— 稍复杂使用举例
- tf.contrib.learn快速入门
- TensorFlow-4: tf.contrib.learn 快速入门
- 【译】scikit-learn入门简介
- 深度学习笔记——深度学习框架TensorFlow(四)[高级API tf.contrib.learn]
- 深度学习笔记——深度学习框架TensorFlow(十)[Creating Estimators in tf.contrib.learn]
- TF.Learn组件
- tf入门
- tf 入门
- 深度学习笔记——深度学习框架TensorFlow(八)[Logging and Monitoring Basics with tf.contrib.learn]
- 深度学习笔记——深度学习框架TensorFlow(九)[Building Input Functions with tf.contrib.learn]
- tf.contrib.learn.preprocessing.VocabularyProcessor
- 7.1 Scikit-learn库简介及快速入门
- Kaggle入门——使用scikit-learn解决DigitRecognition问题
- Kaggle入门——使用scikit-learn解决DigitRecognition问题
- Kaggle入门——使用scikit-learn解决DigitRecognition问题
- 双缓冲技术
- Java Web 自定义MVC框架
- jxl导入/导出excel
- @RequestParam @RequestBody @ResponseBody区别 (1)
- ERROR 1045 (28000): Access denied for user 'root'@'localhost' (using password: YES) 解决
- TF Learn入门 —— 简介
- linux yum命令详解 yum(全称为 Yellow dog Updater, Modified)是一个在Fedora和RedHat以及SUSE中的Shell前端软件包管理器。基於RPM包管理,能
- if语句的陷阱
- 对于js中网络接口websocket,二进制数组arraybuffer,视图对象dataview学习记录。
- Struts2笔记10 向值栈放入或获取数据
- 批处理bat命令--获取当前盘符、当前目录
- Netty物联网高并发系统第一季
- 如何在把微信公众号生成链接
- Linux常用命令(六)——其它常用命令(未拓展)