tf.shape()与tf.get_shape()

来源:互联网 发布:双十一淘宝店铺宣言 编辑:程序博客网 时间:2024/05/22 02:28
import tensorflow as tfimport numpy as npx=tf.constant([[1,2,3],[4,5,6]])y =[[1,2,3],[4,5,6]]z= np.arange(24).reshape([2,3,4])sess = tf.Session()x_shape= tf.shape(x)y_shape=tf.shape(y)z_shape=tf.shape(z)print(x_shape)print(y_shape)print(z_shape)print(sess.run(x_shape))print(sess.run(y_shape))print(sess.run(z_shape))
Tensor("Shape_12:0", shape=(2,), dtype=int32)Tensor("Shape_13:0", shape=(2,), dtype=int32)Tensor("Shape_14:0", shape=(3,), dtype=int32)[2 3][2 3][2 3 4]
print(x.get_shape)print(x.get_shape())print(x.get_shape().as_list())#print(sess.run(x.get_shape))
<bound method Tensor.get_shape of <tf.Tensor 'Const_4:0' shape=(2, 3) dtype=int32>>(2, 3)[2, 3]

不能使用sess.run()是因为get_shape()返回的不是tensor,而是元组。

print(y.get_shape())
---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-13-0ef76bb118ce> in <module>()----> 1 print(y.get_shape())AttributeError: 'list' object has no attribute 'get_shape'
print(z.get_shape())
---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-14-07fd160b4a19> in <module>()----> 1 print(z.get_shape())AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

tf.shape(a)和a.get_shape()比较

相同点:都能得到tensor a 的尺寸

不同点:tf.shape()中a数据类型可以是tensor,list,array。
a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组。

原创粉丝点击