How to use tf.transpose (the multi-dimensional case looks complicated, but it is actually simple)

Source: Internet  Editor: 程序博客网  Date: 2024/06/03 08:22

tf.transpose transposes a tensor. For low-dimensional inputs this is a plain transpose and needs no discussion; the documentation below covers it at a glance.

tf.transpose(a, perm=None, name='transpose')

Transposes a. Permutes the dimensions according to perm.

The returned tensor's dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1...0), where n is the rank of the input tensor. Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors.

For example:

# 'x' is [[1 2 3]
#         [4 5 6]]
tf.transpose(x) ==> [[1 4]
                     [2 5]
                     [3 6]]

# Equivalently
tf.transpose(x, perm=[1, 0]) ==> [[1 4]
                                  [2 5]
                                  [3 6]]

# 'perm' is more useful for n-dimensional tensors, for n > 2
# 'x' is   [[[1  2  3]
#            [4  5  6]]
#           [[7  8  9]
#            [10 11 12]]]
# Take the transpose of the matrices in dimension-0
tf.transpose(x, perm=[0, 2, 1]) ==> [[[1  4]
                                      [2  5]
                                      [3  6]]

                                     [[7 10]
                                      [8 11]
                                      [9 12]]]

Args:
  • a: A Tensor.
  • perm: A permutation of the dimensions of a.
  • name: A name for the operation (optional).

Returns:
  A transposed Tensor.
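The shape side of the rule quoted above can be sketched in plain Python, without TensorFlow. The helper name transposed_shape is made up for this illustration and is not part of the TensorFlow API; it only mirrors the documented rule that output dimension i takes its size from input dimension perm[i], with perm defaulting to (n-1...0).

```python
def transposed_shape(shape, perm=None):
    """Shape after a transpose, per the documented rule (hypothetical helper)."""
    n = len(shape)
    if perm is None:
        # Default: reverse the dimensions, i.e. perm = (n-1, ..., 0).
        perm = list(range(n - 1, -1, -1))
    if sorted(perm) != list(range(n)):
        raise ValueError("perm must be a permutation of 0..n-1")
    # Output dimension i has the size of input dimension perm[i].
    return tuple(shape[p] for p in perm)


print(transposed_shape((2, 3)))                # 2-D default: plain matrix transpose
print(transposed_shape((2, 2, 3), [0, 2, 1]))  # the 3-D example from the documentation
```

For the 2-D input the two sizes simply trade places; for the 3-D input the leading dimension stays where it is and only the last two are exchanged.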

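This article focuses on the higher-dimensional case.

To make the higher-dimensional case easier to picture, the example below is a stack of matrices.

First, some notation: 2 x (3*4) means two 3*4 matrices (strictly speaking, this is a rank-3 tensor of shape [2, 3, 4]).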

x = [[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[21,22,23,24],[25,26,27,28],[29,30,31,32]]]

Output:

---------------
[[[ 1  2  3  4]
  [ 5  6  7  8]
  [ 9 10 11 12]]

 [[21 22 23 24]
  [25 26 27 28]
  [29 30 31 32]]]
---------------


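Here is the key point:

The second argument of tf.transpose, perm, refers to the input dimensions by index. perm=[0,1,2] is the original order and leaves the tensor unchanged: 0 stands for the height of the 3-D array (the number of 2-D matrices), 1 for the rows of each matrix, and 2 for the columns of each matrix.
tf.transpose(x, perm=[1,0,2]) swaps the height and the rows of the 3-D array.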
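To see what this means element by element, here is a minimal pure-Python sketch of a rank-3 transpose. The function transpose3d is a name invented for this illustration; it is not how TensorFlow implements the operation, only a direct reading of the rule that the output index along axis a equals the input index along axis perm[a].

```python
def transpose3d(x, perm):
    """Pure-Python transpose of a rank-3 nested list (illustrative, not TensorFlow code)."""
    in_shape = (len(x), len(x[0]), len(x[0][0]))
    out_shape = [in_shape[p] for p in perm]
    out = [[[None] * out_shape[2] for _ in range(out_shape[1])]
           for _ in range(out_shape[0])]
    for i in range(in_shape[0]):
        for j in range(in_shape[1]):
            for k in range(in_shape[2]):
                src = (i, j, k)
                # Output index along axis a is the input index along axis perm[a].
                dst = [src[p] for p in perm]
                out[dst[0]][dst[1]][dst[2]] = x[i][j][k]
    return out


x = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
     [[21, 22, 23, 24], [25, 26, 27, 28], [29, 30, 31, 32]]]

swapped = transpose3d(x, [1, 0, 2])  # swap height and rows
for matrix in swapped:
    print(matrix)
```

With perm=[1,0,2], each printed matrix pairs up the same row taken from the two original matrices, which is what swapping height and rows amounts to.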

Here is a test program:

import tensorflow as tf

# Note: tf.Session is the TensorFlow 1.x API.

# 2 x (3*4): two 3*4 matrices
x = [[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[21,22,23,24],[25,26,27,28],[29,30,31,32]]]

# One transpose for each of the six possible orders of the three dimensions
a = tf.transpose(x, [0, 1, 2])
b = tf.transpose(x, [0, 2, 1])
c = tf.transpose(x, [1, 0, 2])
d = tf.transpose(x, [1, 2, 0])
e = tf.transpose(x, [2, 1, 0])
f = tf.transpose(x, [2, 0, 1])

with tf.Session() as sess:
    print('---------------')
    print(sess.run(a))
    print('---------------')
    print(sess.run(b))
    print('---------------')
    print(sess.run(c))
    print('---------------')
    print(sess.run(d))
    print('---------------')
    print(sess.run(e))
    print('---------------')
    print(sess.run(f))
    print('---------------')

We expect the results to have the following shapes:

a: 2 x 3*4

b: 2 x 4*3

c: 3 x 2*4

d: 3 x 4*2

e: 4 x 3*2

f: 4 x 2*3
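These expected shapes can be worked out mechanically. The short sketch below uses only the Python standard library and assumes the input shape (2, 3, 4) defined earlier; it lists every ordering of the three dimensions in the article's "n x r*c" notation.

```python
from itertools import permutations

in_shape = (2, 3, 4)  # 2 x (3*4), as defined earlier in the article

shapes = {}
for perm in permutations(range(3)):
    # Output dimension i takes its size from input dimension perm[i].
    out = tuple(in_shape[p] for p in perm)
    shapes[perm] = out
    print(f"perm={list(perm)}: {out[0]} x {out[1]}*{out[2]}")
```

Each shape is just the input sizes 2, 3 and 4 read off in the order given by perm, so perm=[2, 0, 1] gives 4 x 2*3.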

Running the script gives results that match, as shown below:

---------------
[[[ 1  2  3  4]
  [ 5  6  7  8]
  [ 9 10 11 12]]

 [[21 22 23 24]
  [25 26 27 28]
  [29 30 31 32]]]
---------------
[[[ 1  5  9]
  [ 2  6 10]
  [ 3  7 11]
  [ 4  8 12]]

 [[21 25 29]
  [22 26 30]
  [23 27 31]
  [24 28 32]]]
---------------
[[[ 1  2  3  4]
  [21 22 23 24]]

 [[ 5  6  7  8]
  [25 26 27 28]]

 [[ 9 10 11 12]
  [29 30 31 32]]]
---------------
[[[ 1 21]
  [ 2 22]
  [ 3 23]
  [ 4 24]]

 [[ 5 25]
  [ 6 26]
  [ 7 27]
  [ 8 28]]

 [[ 9 29]
  [10 30]
  [11 31]
  [12 32]]]
---------------
[[[ 1 21]
  [ 5 25]
  [ 9 29]]

 [[ 2 22]
  [ 6 26]
  [10 30]]

 [[ 3 23]
  [ 7 27]
  [11 31]]

 [[ 4 24]
  [ 8 28]
  [12 32]]]
---------------
[[[ 1  5  9]
  [21 25 29]]

 [[ 2  6 10]
  [22 26 30]]

 [[ 3  7 11]
  [23 27 31]]

 [[ 4  8 12]
  [24 28 32]]]
---------------


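Finally, to sum up:

[0, 1, 2] is the original order. Swapping any two of the numbers swaps the corresponding two dimensions of the input tensor. The two remaining orders, [1, 2, 0] and [2, 0, 1], move all three dimensions at once. In every case, output dimension i is input dimension perm[i].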
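As a small illustration of the summary, a perm that swaps two dimensions can be built by starting from the original order and exchanging two entries. The helper swap_perm is a hypothetical name used only for this sketch; it is not a TensorFlow function.

```python
def swap_perm(rank, a, b):
    """Build a perm that swaps dimensions a and b and keeps the rest (hypothetical helper)."""
    perm = list(range(rank))  # original order, e.g. [0, 1, 2]
    perm[a], perm[b] = perm[b], perm[a]
    return perm


print(swap_perm(3, 0, 1))  # height <-> rows
print(swap_perm(3, 1, 2))  # rows <-> columns
print(swap_perm(3, 0, 2))  # height <-> columns
```

The resulting list can be passed as the perm argument, for example tf.transpose(x, perm=swap_perm(3, 0, 1)).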