tf.transpose()函数

来源:互联网 发布:微信回调域名不备案 编辑:程序博客网 时间:2024/06/04 19:14

tensorflow里面许多针对数组操作的函数,官方文档又看了没啥卵用,网上帖子直接copy官方文档而不解释,只能自己写个程序测试理解,以3个维度的tensor进行理解

tf.transpose()作为数组的转置函数,原型如下:

def transpose(a, perm=None, name="transpose"):  """Transposes `a`. Permutes the dimensions according to `perm`.

a:是传入的数组

perm:控制转置的操作,以perm = [0,1,2] 3个维度的数组为例, 0--代表的是最外层的一维, 1--代表外向内数第二维, 2--代表最内层的一维,这种perm是默认的值.现在以如下输入数组来理解这个函数和参数perm

input_x = [    [        [1, 2, 3, 4],        [5, 6, 7, 8],        [9, 10, 11, 12]    ],    [        [13, 14, 15, 16],        [17, 18, 19, 20],        [21, 22, 23, 24]    ]]

input_x 是一个 2x3x4的一个tensor, 假设perm = [1,0,2], 就是将最外2层转置,得到tensor应该是  3x2x4的一个张量,将input_x抽象化,不管第3维度

[

     [

          A,

          B,

          C

     ],

     [

          D,

          E,

          F,

     ]

]

变成2x3的tensor,类似于2x3的数组

[

     A  B  C

     D  E  F

]

转置变成 3x2的数组

[

    A  D

    B  E

    C  F

]

再将A-F换成具体的值,最终得到的张量是

[

  [

     [ 1  2  3  4]
     [13 14 15 16]

 ]
 [

    [ 5  6  7  8]
    [17 18 19 20]

  ]
  [

     [ 9 10 11 12]
     [21 22 23 24]

  ]

]

这就可以看出perm前两列交换的作用

如果 perm=[0,2,1]说明要交换内层里面的两个维度,从原来的2x3x4变成2x4x3的张量,就不抽象化了,结果就是

[

  [

      [ 1  5  9]
      [ 2  6 10]
      [ 3  7 11]
      [ 4  8 12]

  ]

  [

     [13 17 21]
     [14 18 22]
     [15 19 23]
     [16 20 24]

  ]

]

下面贴出我的代码:

import tensorflow as tfinput_x = [    [        [1, 2, 3, 4],        [5, 6, 7, 8],        [9, 10, 11, 12]    ],    [        [13, 14, 15, 16],        [17, 18, 19, 20],        [21, 22, 23, 24]    ]]result = tf.transpose(input_x, perm=[0, 2, 1])with tf.Session() as sess:    print(sess.run(result))

注意,使用print(result)只会打印tensor的name  shape  dtype信息

Tensor("transpose:0", shape=(2, 4, 3), dtype=int32)

想要打出数组的形式,使用session

result = tf.transpose(input_x, perm=[0, 2, 1])print(result)with tf.Session() as sess:    print(sess.run(result))








原创粉丝点击