tensorflow(4)
来源:互联网 发布:彩虹六号软件 编辑:程序博客网 时间:2024/05/29 02:11
tf.estimator
A custom model
tf.estimator不会将您锁定到其预定义的模型中。假设我们想创建一个没有内置到TensorFlow中的自定义模型。我们仍然可以保留高频抽象的数据集,给数据,训练等。为了说明,我们将展示如何使用我们对低级别TensorFlow API的知识来实现我们自己的LinearRegressor的等效模型。
要定义一个与tf.estimator一起工作的自定义模型,我们需要使用tf.estimator.Estimator。tf.estimator.LinearRegressor实际上是tf.estimator.Estimator的一个子类。我们只需要提供一个函数model_fn来告诉tf.estimator如何评估预测,训练步骤和损失,而不是用Estimator的子分类。代码如下:
import numpy as npimport tensorflow as tf# Declare list of features, we only have one real-valued featuredef model_fn(features, labels, mode): # Build a linear model and predict values W = tf.get_variable("W", [1], dtype=tf.float64) b = tf.get_variable("b", [1], dtype=tf.float64) y = W*features['x'] + b # Loss sub-graph loss = tf.reduce_sum(tf.square(y - labels)) # Training sub-graph global_step = tf.train.get_global_step() optimizer = tf.train.GradientDescentOptimizer(0.01) train = tf.group(optimizer.minimize(loss), tf.assign_add(global_step, 1)) # EstimatorSpec connects subgraphs we built to the # appropriate functionality. return tf.estimator.EstimatorSpec( mode=mode, predictions=y, loss=loss, train_op=train)estimator = tf.estimator.Estimator(model_fn=model_fn)# define our data setsx_train = np.array([1., 2., 3., 4.])y_train = np.array([0., -1., -2., -3.])x_eval = np.array([2., 5., 8., 1.])y_eval = np.array([-1.01, -4.1, -7., 0.])input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)train_input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=4, num_epochs=1000, shuffle=False)eval_input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_eval}, y_eval, batch_size=4, num_epochs=1000, shuffle=False)# trainestimator.train(input_fn=input_fn, steps=1000)# Here we evaluate how well our model did.train_metrics = estimator.evaluate(input_fn=train_input_fn)eval_metrics = estimator.evaluate(input_fn=eval_input_fn)print("train metrics: %r"% train_metrics)print("eval metrics: %r"% eval_metrics)
运行结果:
train metrics: {'loss': 1.1786181e-11, 'global_step': 1000}eval metrics: {'loss': 0.010100437, 'global_step': 1000}
阅读全文
0 0
- tensorflow(4)
- Tensorflow学习笔记4:分布式Tensorflow
- ubuntu1.4安装tensorflow
- 浅入浅出TensorFlow 4
- TensorFlow学习日记4
- TensorFlow基础知识4-变量
- Learning Tensorflow (4)
- tensorflow基础使用4
- TensorFlow学习笔记4
- 宣布 TensorFlow r1.4
- [TensorFlow]学习手记 4
- 4、TensorFLow 数学运算
- tensorflow
- TensorFlow
- TensorFlow
- tensorflow
- tensorflow
- tensorflow
- hello world
- JAVA项目中发布WebService服务——调用方式
- [Unity 热更新]tolua原理及实践
- 堆栈与动态分配内存空间
- java基础之IO(重点)
- tensorflow(4)
- instr函数
- 【NOIP模拟】 (11.3) T2 排列
- “小微”企业的集体“大”商标
- Alibaba开源Fastjson讲解和应用
- Rxjava和retorfit的混合使用
- 二分查找的递归和非递归实现
- Twitter 产品设计师王源专访
- (4.1.36)android Graphics 图形学解析