tensorflow中获取shape的方法比较

来源:互联网 发布:ios软件源 编辑:程序博客网 时间:2024/05/20 04:14

tf.shape(xxx) 和 xxx.get_shape()比较

  1. 相同点:都可以得到tensor xxx 的尺寸
  2. 不同点:tf.shape(xxx)中xxx数据的类型可以是tensor,list,array;而xxx.get_shape()中的xxx的数据类型必须是tensor,且返回的是一个tuple.可以通过xxx.get_shape().as_list()得到一个list。

例如:

x= tf.truncated_normal([32, 32, 3], dtype=tf.float32)print(tf.shape(x))print(x.get_shape())print(x.get_shape().as_list())
  • 1
  • 2
  • 3
  • 4
  • 5

输出:

Tensor("Shape:0", shape=(3,), dtype=int32)(32, 32, 3)[32, 32, 3]
  • 1
  • 2
  • 3

注意:dtype=int32是tf.shape()这个op的输出类型,默认为tf.int32。

原创粉丝点击