tensorflow(1) mnist_softmax.py

来源:互联网 发布:软件测试中的单元测试 编辑:程序博客网 时间:2024/06/10 13:39

标签(空格分隔):tensorflow


1. 数据读取

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
input_data是一个模块,里面都是一些import语句,read_data_sets是tensorflow自带的专门用来读取mnist这个数据集的一个函数!

之前一直对数据读取不明白,终于找到了这两个函数

train_images = extract_images(f)train_labels = extract_labels(f, one_hot=one_hot)

输入为.gz压缩文件,输出为 numpy array

def extract_images(f):  """Extract the images into a 4D uint8 numpy array [index, y, x, depth].  Args:    f: A file object that can be passed into a gzip reader.  Returns:    data: A 4D uint8 numpy array [index, y, x, depth].
def extract_labels(f, one_hot=False, num_classes=10):  """Extract the labels into a 1D uint8 numpy array [index].  Args:    f: A file object that can be passed into a gzip reader.    one_hot: Does one hot encoding for the result.    num_classes: Number of classes for the one hot encoding.  Returns:    labels: a 1D uint8 numpy array.

2. placeholder

请注意,在sess.run的时候,placeholder必须用feed_dict赋值,

def placeholder(dtype, shape=None, name=None):  """Inserts a placeholder for a tensor that will be always fed.  **Important**: This tensor will produce an error if evaluated. Its value must  be fed using the `feed_dict` optional argument to `Session.run()`,  `Tensor.eval()`, or `Operation.run()`.  Args:    dtype: The type of elements in the tensor to be fed.    shape: The shape of the tensor to be fed (optional). If the shape is not specified, you can feed a tensor of any shape.    name: A name for the operation (optional).  Returns:    A `Tensor` that may be used as a handle for feeding a value, but not    evaluated directly.  """  return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)

For example:

  x = tf.placeholder(tf.float32, shape=(1024, 1024))  y = tf.matmul(x, x)  with tf.Session() as sess:    print(sess.run(y))  # ERROR: will fail because x was not fed.    rand_array = np.random.rand(1024, 1024)    print(sess.run(y, feed_dict={x: rand_array}))  # Will succeed.

3. cross_entropy

cross_entropy = tf.reduce_mean(      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

tf.reduce_mean
在指定的维度上计算一个张量的均值,如果不指定维度,求所有向量的均值
维度0是把最外面的一层括号去掉

Computes the mean of elements across dimensions of a tensor.  # 'x' is [[1., 1.]  #         [2., 2.]]  tf.reduce_mean(x) ==> 1.5  tf.reduce_mean(x, 0) ==> [1.5, 1.5]  tf.reduce_mean(x, 1) ==> [1.,  2.]

计算logits(预测值)和label(真值)之间的 softmax交叉熵
返回一个一维向量,向量长度和logits长度一样,为batch_Size的长度。内容为softmax交叉熵

def softmax_cross_entropy_with_logits(_sentinel=None,  labels=None, logits=None, dim=-1, name=None): Computes softmax cross entropy between `logits` and `labels`.   Args:    _sentinel: Used to prevent positional parameters. Internal, do not use.    labels: Each row `labels[i]` must be a valid probability distribution.    logits: Unscaled log probabilities.    dim: The class dimension. Defaulted to -1 which is the last dimension.    name: A name for the operation (optional).  Returns:    A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the    softmax cross entropy loss.

4.train_step

  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

(1)GradientDescentOptimizer是一个模板类,传入参数0.5为学习率,init构造函数即能创建一个梯度下降的优化器对象。 它的父类是Optimizer

class GradientDescentOptimizer(optimizer.Optimizer):  """Optimizer that implements the gradient descent algorithm.  """  def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):    """Construct a new gradient descent optimizer.

(2)minimize
最小化loss损失函数,(在本例中是最小化交叉熵,)是Optimizer类的成员函数

loss: A Tensor containing the value to minimize.
分为compute_decent和apply_decent两部分

返回更新var_list中所有变量的一个操作

 def minimize(self, loss, global_step=None, var_list=None,               gate_gradients=GATE_OP, aggregation_method=None,               colocate_gradients_with_ops=False, name=None,               grad_loss=None):    """Add operations to minimize `loss` by updating `var_list`.    Args:    loss: A `Tensor` containing the value to minimize.    global_step: Optional `Variable` to increment by one after the variables have been updated.    var_list: Optional list or tuple of `Variable` objects to update to minimize `loss`.  Defaults to the list of variables collected in the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.    ...    Returns:      An Operation that updates the variables in `var_list`.  If `global_step`      was not `None`, that operation also increments `global_step`.

5. sess,train

  sess = tf.InteractiveSession()  tf.global_variables_initializer().run()  #初始化所有变量  # Train  for _ in range(1000):    batch_xs, batch_ys = mnist.train.next_batch(100)    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

sess = tf.InteractiveSession()创建一个新交互的sess sess = tf.Session()
1000次循环,每次加载100个图片
重点是这个sess.run函数

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):    """Runs operations and evaluates tensors in `fetches`.    The optional `feed_dict` argument allows the caller to override    the value of tensors in the graph.

fetches是图,或者元素是图的list,touple …
feed_dict允许重写图中的变量

6. test

  # Test trained model  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  print(sess.run(accuracy, feed_dict={x: mnist.test.images,                                      y_: mnist.test.labels}))

accuracy也是一个图,是各种op的叠加。

tf.argmax(y,1),1是指维度,默认维度为0,返回指定维度上的最大值
tf.cast Casts a tensor to a new type.类型转换
tf.reduce_mean求均值,若没指定维度,就是求所有值的均值

7. tf.app.run

if __name__ == '__main__':  parser = argparse.ArgumentParser()  parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',                      help='Directory for storing input data')  FLAGS, unparsed = parser.parse_known_args()  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

parser = argparse.ArgumentParser() argparse是模块argparse.py的名字,ArgumentParser()是类,创建一个类对象。

feed_dict前面必须有placeholder对应?
[sys.argv[0]] + unparsed是什么意思
数据集是如何分割的

tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值
比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,我们可以用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。

原创粉丝点击