tensorflow中get_shape函数的应用

来源:互联网 发布:mysql oxc000007b 编辑:程序博客网 时间:2024/05/29 05:11

get_shape函数主要用于获取一个张量的维度,并且输出张量 每个维度上面的值,如果是二维矩阵,也就是输出行和列的值,使用非常方便。

例如:

import tensorflow as tf;  with tf.Session() as sess:A = tf.random_normal(shape=[3,4])
print A.get_shape()print A.get_shape

输出:

(3, 4)
<bound method Tensor.get_shape of <tf.Tensor 'random_normal:0' shape=(3, 4) dtype=float32>>

注意:第一个输出是一个元祖,就是数值,而第二输出就是一个张量的对象,里面包含更多的东西,在不同的情况下,使用不同的方式。如果你需要输出某一个维度上面的值那就用下面的这种方式就好了。

A.get_shape()[0]
这就表示第一个维度。

原创粉丝点击