tensorflow API:tf.split

来源:互联网 发布:数据库的基本语句 编辑:程序博客网 时间:2024/06/05 20:34

split(value, num_or_size_splits, axis=0, num=None, name=’split’)
Splits a tensor into sub tensors.

If `num_or_size_splits` is an integer type, `num_split`, then splits `value`along dimension `axis` into `num_split` smaller tensors.Requires that `num_split` evenly divides `value.shape[axis]`.

如果参数num_or_size_splits是整数,则把value切片为该整数个
If num_or_size_splits is not an integer type, it is presumed to be a Tensor size_splits, then splitsvalueinto len(size_splits) pieces.
否则给出的是期望在axis上切片下来维度list。
The shape of the i-th piece has the same size as the value except along dimension axis where the size is size_splits[i].
除了axis的维度不一样,切片过后其他维度保持不变。

例子:
# ‘value’ is a tensor with shape [5, 30]
# Split ‘value’ into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0) # [5, 4]
tf.shape(split1) # [5, 15]
tf.shape(split2) # [5, 11]
# Split ‘value’ into 3 tensors along dimension 1
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0) # [5, 10]
“`
我来写个例子:

tensor = [[1,2,3],         [4,5,6],         [7,8,9]]
with tf.Session() as sess:    """沿着1轴切片"""    tensor1,tensor2,tensor3 = tf.split(tensor,num_or_size_splits=3,axis=1)    print(tensor1.eval())    print('--------------')    """沿着0轴切片"""    tensor1,tensor2,tensor3 = tf.split(tensor,num_or_size_splits=3,axis=0)    print('--------------')    print(tensor1.eval())    """给出切片list"""    tensor1, tensor2 = tf.split(tensor,num_or_size_splits=[1,2],axis=0)    print('--------------')    print(tensor2.eval())    """关于num参数还不清楚,这里略去了"""

输出:

[[1] [4] [7]]----------------------------[[1 2 3]]--------------[[4 5 6] [7 8 9]]
原创粉丝点击