tensorflow(4)

来源:互联网 发布:彩虹六号软件 编辑:程序博客网 时间:2024/05/29 02:11

tf.estimator

A custom model

    tf.estimator不会将您锁定到其预定义的模型中。假设我们想创建一个没有内置到TensorFlow中的自定义模型。我们仍然可以保留高频抽象的数据集,给数据,训练等。为了说明,我们将展示如何使用我们对低级别TensorFlow API的知识来实现我们自己的LinearRegressor的等效模型。

    要定义一个与tf.estimator一起工作的自定义模型,我们需要使用tf.estimator.Estimator。tf.estimator.LinearRegressor实际上是tf.estimator.Estimator的一个子类。我们只需要提供一个函数model_fn来告诉tf.estimator如何评估预测,训练步骤和损失,而不是用Estimator的子分类。代码如下:

import numpy as npimport tensorflow as tf# Declare list of features, we only have one real-valued featuredef model_fn(features, labels, mode):  # Build a linear model and predict values  W = tf.get_variable("W", [1], dtype=tf.float64)  b = tf.get_variable("b", [1], dtype=tf.float64)  y = W*features['x'] + b  # Loss sub-graph  loss = tf.reduce_sum(tf.square(y - labels))  # Training sub-graph  global_step = tf.train.get_global_step()  optimizer = tf.train.GradientDescentOptimizer(0.01)  train = tf.group(optimizer.minimize(loss),                   tf.assign_add(global_step, 1))  # EstimatorSpec connects subgraphs we built to the  # appropriate functionality.  return tf.estimator.EstimatorSpec(      mode=mode,      predictions=y,      loss=loss,      train_op=train)estimator = tf.estimator.Estimator(model_fn=model_fn)# define our data setsx_train = np.array([1., 2., 3., 4.])y_train = np.array([0., -1., -2., -3.])x_eval = np.array([2., 5., 8., 1.])y_eval = np.array([-1.01, -4.1, -7., 0.])input_fn = tf.estimator.inputs.numpy_input_fn(    {"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)train_input_fn = tf.estimator.inputs.numpy_input_fn(    {"x": x_train}, y_train, batch_size=4, num_epochs=1000, shuffle=False)eval_input_fn = tf.estimator.inputs.numpy_input_fn(    {"x": x_eval}, y_eval, batch_size=4, num_epochs=1000, shuffle=False)# trainestimator.train(input_fn=input_fn, steps=1000)# Here we evaluate how well our model did.train_metrics = estimator.evaluate(input_fn=train_input_fn)eval_metrics = estimator.evaluate(input_fn=eval_input_fn)print("train metrics: %r"% train_metrics)print("eval metrics: %r"% eval_metrics)

运行结果:

train metrics: {'loss': 1.1786181e-11, 'global_step': 1000}
eval metrics: {'loss': 0.010100437, 'global_step': 1000}