tensorflow(1) mnist_softmax.py
Source: Internet | Editor: 程序博客网 | Time: 2024/06/10 13:39
Tags: tensorflow
1. Reading the data
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
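input_data is a module that contains little more than import statements. read_data_sets is a function that ships with TensorFlow specifically for reading the MNIST dataset.
For a long time I did not understand how the data was read. Then I finally found these two functions: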
train_images = extract_images(f)
train_labels = extract_labels(f, one_hot=one_hot)
Their input is a .gz-compressed file and their output is a NumPy array:
def extract_images(f):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

  Args:
    f: A file object that can be passed into a gzip reader.

  Returns:
    data: A 4D uint8 numpy array [index, y, x, depth].
  """
  ...

def extract_labels(f, one_hot=False, num_classes=10):
  """Extract the labels into a 1D uint8 numpy array [index].

  Args:
    f: A file object that can be passed into a gzip reader.
    one_hot: Does one hot encoding for the result.
    num_classes: Number of classes for the one hot encoding.

  Returns:
    labels: a 1D uint8 numpy array.
  """
  ...
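The following is a minimal stdlib-only sketch of what the two functions do. It assumes the standard MNIST IDX layout:

- A gzipped, big-endian uint32 header. It starts with a magic number (2051 for images, 2049 for labels), followed by the item count and, for images, the row and column sizes.
- One unsigned byte per pixel or per label after the header.

The names read_idx_images, read_idx_labels and to_one_hot are hypothetical. The sketch also returns plain bytes and lists rather than NumPy arrays.

```python
import gzip
import io
import struct

# Hypothetical stdlib-only stand-ins for extract_images / extract_labels.
IMAGE_MAGIC = 2051  # IDX magic number for MNIST image files
LABEL_MAGIC = 2049  # IDX magic number for MNIST label files


def read_idx_images(f):
    """Parse a gzipped IDX image file into (num_images, rows, cols, pixels)."""
    with gzip.GzipFile(fileobj=f) as gz:
        # Header: four big-endian uint32 values.
        magic, num, rows, cols = struct.unpack(">IIII", gz.read(16))
        if magic != IMAGE_MAGIC:
            raise ValueError("bad magic number %d in image file" % magic)
        # One uint8 per pixel, row-major, image after image.
        pixels = gz.read(num * rows * cols)
    return num, rows, cols, pixels


def to_one_hot(labels, num_classes=10):
    """Turn e.g. [3] into [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]]."""
    return [[1 if c == label else 0 for c in range(num_classes)] for label in labels]


def read_idx_labels(f, one_hot=False, num_classes=10):
    """Parse a gzipped IDX label file into a list of ints (or one-hot rows)."""
    with gzip.GzipFile(fileobj=f) as gz:
        magic, num = struct.unpack(">II", gz.read(8))
        if magic != LABEL_MAGIC:
            raise ValueError("bad magic number %d in label file" % magic)
        labels = list(gz.read(num))  # bytes -> list of ints 0..255
    return to_one_hot(labels, num_classes) if one_hot else labels


# Tiny synthetic files: two 2x2 images and two labels.
fake_images = gzip.compress(struct.pack(">IIII", IMAGE_MAGIC, 2, 2, 2) + bytes(range(8)))
fake_labels = gzip.compress(struct.pack(">II", LABEL_MAGIC, 2) + bytes([3, 7]))
num, rows, cols, pixels = read_idx_images(io.BytesIO(fake_images))
labels = read_idx_labels(io.BytesIO(fake_labels), one_hot=True)
```

The one_hot=True path mirrors what read_data_sets(FLAGS.data_dir, one_hot=True) asks for: each label becomes a row of num_classes entries with a single 1.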
2. placeholder
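Note: when you call sess.run, every placeholder that the requested result depends on must be given a value through feed_dict. Otherwise the call fails.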
def placeholder(dtype, shape=None, name=None):
  """Inserts a placeholder for a tensor that will be always fed.

  **Important**: This tensor will produce an error if evaluated. Its value must
  be fed using the `feed_dict` optional argument to `Session.run()`,
  `Tensor.eval()`, or `Operation.run()`.

  Args:
    dtype: The type of elements in the tensor to be fed.
    shape: The shape of the tensor to be fed (optional). If the shape is not
      specified, you can feed a tensor of any shape.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that may be used as a handle for feeding a value, but not
    evaluated directly.
  """
  return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
For example:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

x = tf.placeholder(tf.float32, shape=(1024, 1024))
y = tf.matmul(x, x)

with tf.Session() as sess:
  print(sess.run(y))  # ERROR: will fail because x was not fed.

  rand_array = np.random.rand(1024, 1024)
  print(sess.run(y, feed_dict={x: rand_array}))  # Will succeed.
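The rule about feeding placeholders follows from deferred evaluation. Building y = tf.matmul(x, x) only records an operation; a value for x is needed only when the graph is actually run.

The toy below models this in pure Python. Placeholder, Mul and run are hypothetical stand-ins, not TensorFlow's implementation. The point is only that a placeholder has no value of its own and must get one from feed_dict.

```python
class Placeholder:
    """Toy graph node with no value of its own (hypothetical, not tf.placeholder)."""

    def __init__(self, name):
        self.name = name

    def evaluate(self, feed_dict):
        # Like TensorFlow, refuse to run if the placeholder was not fed.
        if self not in feed_dict:
            raise ValueError("You must feed a value for placeholder %r" % self.name)
        return feed_dict[self]


class Mul:
    """Toy op node: multiplies the values of its two input nodes."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def evaluate(self, feed_dict):
        return self.a.evaluate(feed_dict) * self.b.evaluate(feed_dict)


def run(node, feed_dict=None):
    """Toy analogue of sess.run: evaluate a node under the given feeds."""
    return node.evaluate(feed_dict or {})


x = Placeholder("x")
y = Mul(x, x)          # building the graph computes nothing yet
print(run(y, {x: 3}))  # 9
```

Calling run(y) without a feed raises ValueError, just as sess.run(y) fails in the example above.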
3. cross_entropy
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
tf.reduce_mean
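Computes the mean of a tensor along the specified dimension(s). If no dimension is specified, it averages over all elements.
Reducing over dimension 0 strips off the outermost pair of brackets: the rows are averaged element-wise.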
Computes the mean of elements across dimensions of a tensor.

# 'x' is [[1., 1.]
#         [2., 2.]]
tf.reduce_mean(x) ==> 1.5
tf.reduce_mean(x, 0) ==> [1.5, 1.5]
tf.reduce_mean(x, 1) ==> [1., 2.]
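Here is a pure-Python sketch of the three cases above. It assumes the input is a 2-D nested list, and reduce_mean here is a hypothetical helper, not the TensorFlow function.

```python
def reduce_mean(x, axis=None):
    """Hypothetical pure-Python analogue of tf.reduce_mean for a 2-D list."""
    if axis is None:
        # No axis: average every element.
        flat = [v for row in x for v in row]
        return sum(flat) / len(flat)
    if axis == 0:
        # Axis 0: drop the outer brackets, i.e. average the rows element-wise.
        return [sum(col) / len(col) for col in zip(*x)]
    if axis == 1:
        # Axis 1: average within each row.
        return [sum(row) / len(row) for row in x]
    raise ValueError("axis must be None, 0 or 1 for a 2-D input")


x = [[1., 1.],
     [2., 2.]]
print(reduce_mean(x))     # 1.5
print(reduce_mean(x, 0))  # [1.5, 1.5]
print(reduce_mean(x, 1))  # [1.0, 2.0]
```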
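tf.nn.softmax_cross_entropy_with_logits computes the softmax cross entropy between logits (the predictions) and labels (the ground truth).
It returns a 1-D tensor whose length is the number of rows in logits, i.e. batch_size. Each entry is the softmax cross entropy of one example.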
def softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None,
                                      dim=-1, name=None):
  """Computes softmax cross entropy between `logits` and `labels`.

  Args:
    _sentinel: Used to prevent positional parameters. Internal, do not use.
    labels: Each row `labels[i]` must be a valid probability distribution.
    logits: Unscaled log probabilities.
    dim: The class dimension. Defaulted to -1 which is the last dimension.
    name: A name for the operation (optional).

  Returns:
    A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
    softmax cross entropy loss.
  """
  ...
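The per-example value can be written out by hand. This sketch assumes each row of labels is a probability distribution (one-hot for MNIST). It computes minus the sum over j of labels[j] * log(softmax(logits)[j]), using the usual log-sum-exp trick for numerical stability.

softmax_cross_entropy is a hypothetical name. Averaging its output corresponds to the tf.reduce_mean step in the cross_entropy line.

```python
import math


def softmax_cross_entropy(labels, logits):
    """Hypothetical helper: per-example softmax cross entropy, one value per row.

    labels: list of rows, each a probability distribution (e.g. one-hot).
    logits: list of rows of unscaled scores, same shape as labels.
    """
    losses = []
    for label_row, logit_row in zip(labels, logits):
        # log-sum-exp with the max subtracted, for numerical stability
        m = max(logit_row)
        log_z = m + math.log(sum(math.exp(z - m) for z in logit_row))
        # log softmax(z)_j = z_j - log_z; cross entropy = -sum_j p_j * log q_j
        losses.append(-sum(p * (z - log_z) for p, z in zip(label_row, logit_row)))
    return losses  # length == batch_size


labels = [[1, 0], [0, 1]]
logits = [[0.0, 0.0], [0.0, 2.0]]
per_example = softmax_cross_entropy(labels, logits)
cross_entropy = sum(per_example) / len(per_example)  # the reduce_mean step
```

For the first row the two logits are equal, so the softmax assigns probability 0.5 to the true class and the loss is log 2.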
4. train_step
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
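(1) GradientDescentOptimizer is a class whose parent class is Optimizer. The argument 0.5 is the learning rate. Calling the constructor (__init__) creates a gradient-descent optimizer object.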
class GradientDescentOptimizer(optimizer.Optimizer):
  """Optimizer that implements the gradient descent algorithm.
  """

  def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
    """Construct a new gradient descent optimizer.
    ...
    """
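(2) minimize
Minimizes the loss function (in this example, the cross entropy). minimize is a member function of the Optimizer class.
loss: A Tensor containing the value to minimize.
Internally it consists of two parts: compute_gradients and apply_gradients.
It returns an operation that updates all the variables in var_list.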
def minimize(self, loss, global_step=None, var_list=None,
             gate_gradients=GATE_OP, aggregation_method=None,
             colocate_gradients_with_ops=False, name=None,
             grad_loss=None):
  """Add operations to minimize `loss` by updating `var_list`.

  Args:
    loss: A `Tensor` containing the value to minimize.
    global_step: Optional `Variable` to increment by one after the
      variables have been updated.
    var_list: Optional list or tuple of `Variable` objects to update to
      minimize `loss`. Defaults to the list of variables collected in
      the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
    ...

  Returns:
    An Operation that updates the variables in `var_list`. If `global_step`
    was not `None`, that operation also increments `global_step`.
  """
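The split of minimize into computing gradients and applying them can be sketched without TensorFlow. In this hypothetical example, gradients come from an analytic formula for a toy quadratic loss rather than from TensorFlow's automatic differentiation. Each update is plain gradient descent: var <- var - learning_rate * grad. The function names only echo compute_gradients and apply_gradients; they are not the real methods.

```python
def compute_gradients(grad_fn, params):
    """Return (gradient, name) pairs, echoing the (grad, var) pairs of compute_gradients."""
    grads = grad_fn(params)
    return [(grads[name], name) for name in params]


def apply_gradients(grads_and_vars, params, learning_rate):
    """Plain gradient descent: var <- var - learning_rate * grad."""
    for grad, name in grads_and_vars:
        params[name] -= learning_rate * grad


def minimize(grad_fn, params, learning_rate):
    """One training step = compute_gradients followed by apply_gradients."""
    apply_gradients(compute_gradients(grad_fn, params), params, learning_rate)


# Toy loss (hypothetical): L(w, b) = (w - 3)^2 + (b + 1)^2, minimum at w = 3, b = -1.
def loss(p):
    return (p["w"] - 3) ** 2 + (p["b"] + 1) ** 2


def loss_grad(p):
    # Analytic partial derivatives of the toy loss.
    return {"w": 2 * (p["w"] - 3), "b": 2 * (p["b"] + 1)}


params = {"w": 0.0, "b": 0.0}
for _ in range(100):
    minimize(loss_grad, params, learning_rate=0.1)
```

With learning_rate=0.5, as in the tutorial, this particular loss would be solved in a single step. Each update multiplies the distance to the minimum by 1 - 2 * 0.5 = 0. The smaller rate 0.1 shows the gradual convergence that is more typical of real training.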
5. sess, train
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()  # initialize all variables
# Train
for _ in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
sess = tf.InteractiveSession()
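Creates a new interactive session. The only difference from sess = tf.Session() is that an InteractiveSession installs itself as the default session when it is constructed. That is why tf.global_variables_initializer().run() works without naming a session.
The loop runs 1000 iterations and loads a batch of 100 images each time.
The key piece is the sess.run function: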
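The loop relies on mnist.train.next_batch to hand out 100 examples at a time. The real method is more involved. The class below is a simplified, hypothetical stand-in with three assumptions:

- The data fits in memory as Python lists.
- The data is reshuffled at the start of every epoch.
- A batch that crosses an epoch boundary wraps around into the next epoch.

```python
import random


class SimpleDataSet:
    """Hypothetical, simplified stand-in for mnist.train (not TensorFlow's DataSet)."""

    def __init__(self, images, labels, seed=0):
        if len(images) != len(labels) or not images:
            raise ValueError("need the same, non-zero number of images and labels")
        self._images = list(images)
        self._labels = list(labels)
        self._rng = random.Random(seed)
        self._order = list(range(len(self._images)))
        self._rng.shuffle(self._order)  # shuffle once before the first epoch
        self._pos = 0                   # position within the current epoch

    def next_batch(self, batch_size):
        """Return (images, labels) lists holding the next batch_size examples."""
        picked = []
        while len(picked) < batch_size:
            if self._pos == len(self._order):
                # Epoch finished: reshuffle and start the next one.
                self._rng.shuffle(self._order)
                self._pos = 0
            take = min(batch_size - len(picked), len(self._order) - self._pos)
            picked.extend(self._order[self._pos:self._pos + take])
            self._pos += take
        return ([self._images[i] for i in picked],
                [self._labels[i] for i in picked])


# 250 fake "images" (just integers) with labels 0-9.
train = SimpleDataSet(list(range(250)), [i % 10 for i in range(250)])
for _ in range(1000):
    batch_xs, batch_ys = train.next_batch(100)
    # the tutorial would run: sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
```

Shuffling per epoch means every example is seen once per pass, while the batch order still differs between passes.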
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
  """Runs operations and evaluates tensors in `fetches`.

  The optional `feed_dict` argument allows the caller to override
  the value of tensors in the graph.
  ...
  """
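fetches: a single graph element, or a list, tuple, … whose elements are graph elements.
feed_dict: lets the caller override the values of tensors in the graph.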
6. test
# Test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))
accuracy is also just a node in the graph, built by composing several ops.
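tf.argmax(y, 1)
The 1 is the dimension (axis) to reduce over; the default is 0. It returns the index of the largest value along that dimension, not the value itself.
tf.cast
Casts a tensor to a new type (type conversion).
tf.reduce_mean
Computes the mean. If no dimension is specified, it averages over all values.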
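The four ops in the test code can be traced step by step in plain Python. This sketch uses hypothetical helpers argmax and accuracy. It assumes y holds per-class scores and y_ holds one-hot labels, one row per example.

```python
def argmax(row):
    """Index of the largest value in a row (first one on ties), like tf.argmax on axis 1."""
    return max(range(len(row)), key=lambda j: row[j])


def accuracy(y, y_):
    """Fraction of rows where the predicted class equals the true class."""
    predicted = [argmax(row) for row in y]                 # tf.argmax(y, 1)
    actual = [argmax(row) for row in y_]                   # tf.argmax(y_, 1)
    correct = [p == a for p, a in zip(predicted, actual)]  # tf.equal
    as_float = [float(c) for c in correct]                 # tf.cast(..., tf.float32)
    return sum(as_float) / len(as_float)                   # tf.reduce_mean


y = [[0.1, 0.7, 0.2],   # predicts class 1
     [0.8, 0.1, 0.1],   # predicts class 0
     [0.3, 0.3, 0.4]]   # predicts class 2
y_ = [[0, 1, 0],        # true class 1
      [0, 0, 1],        # true class 2
      [0, 0, 1]]        # true class 2
print(accuracy(y, y_))  # 0.666..., two of three predictions match
```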
7. tf.app.run
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str,
                      default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
parser = argparse.ArgumentParser()
argparse is the module (argparse.py) and ArgumentParser is a class in it. This line creates an instance of that class.
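parse_known_args differs from parse_args in that it does not reject unknown options. It returns a namespace with the recognized arguments plus a list of the leftover strings.

A sketch with a simulated command line shows what [sys.argv[0]] + unparsed produces: the program name followed by everything argparse did not consume. The tutorial then passes that list to tf.app.run as argv. The --learning_rate flag and the /data/mnist path are made up for the example.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str,
                    default='/tmp/tensorflow/mnist/input_data',
                    help='Directory for storing input data')

# Simulated command line; --learning_rate is a made-up flag this parser does not know.
argv = ['mnist_softmax.py', '--data_dir', '/data/mnist', '--learning_rate', '0.1']
FLAGS, unparsed = parser.parse_known_args(argv[1:])
print(FLAGS.data_dir)  # /data/mnist
print(unparsed)        # ['--learning_rate', '0.1']

# Program name first, then whatever argparse did not recognize.
passed_on = [argv[0]] + unparsed
print(passed_on)       # ['mnist_softmax.py', '--learning_rate', '0.1']
```

In the real script, parse_known_args() is called with no arguments, so it reads sys.argv[1:]. sys.argv[0] plays the role of argv[0] here.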
Must every entry in feed_dict have a corresponding placeholder?
What does [sys.argv[0]] + unparsed mean?
How is the dataset split?
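On the split question, here is one data point. In the TensorFlow 1.x MNIST helper, read_data_sets has a validation_size parameter (5000 by default), and it takes the validation set from the front of the 60,000 training examples. That leaves 55,000 examples in mnist.train, while the 10,000 test examples become mnist.test.

A hypothetical sketch of that slicing, using indices in place of images:

```python
def split_train_validation(train_images, train_labels, validation_size=5000):
    """Hypothetical helper: hold out the first validation_size training examples."""
    if not 0 <= validation_size <= len(train_images):
        raise ValueError("validation_size should be between 0 and %d" % len(train_images))
    validation = (train_images[:validation_size], train_labels[:validation_size])
    train = (train_images[validation_size:], train_labels[validation_size:])
    return train, validation


# Stand-in for the 60,000 MNIST training examples (indices instead of pixels).
images = list(range(60000))
labels = [i % 10 for i in range(60000)]
(train_x, train_y), (val_x, val_y) = split_train_validation(images, labels)
print(len(train_x), len(val_x))  # 55000 5000
```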
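tf.argmax is a very useful function. It returns the index of the largest value of a tensor along a given dimension.
For example, tf.argmax(y, 1) gives the label the model predicts for each input x, and tf.argmax(y_, 1) gives the correct label. tf.equal then checks whether each prediction matches the true label: the same index position means a match.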