tf.split (API r0.12 / r0.9)

来源:互联网 发布:linux打包成zip文件 编辑:程序博客网 时间:2024/06/06 02:53

 tf.split (API r0.12 / r0.9)

r0.12

1. tf.split(split_dim, num_split, value, name='split')

Splits a tensor into num_split tensors along one dimension.
Splits value along dimension split_dim into num_split smaller tensors. Requires that num_split evenly divide value.shape[split_dim].

For example:

# 'value' is a tensor with shape [5, 30]
# Split 'value' into 3 tensors along dimension 1
split0, split1, split2 = tf.split(1, 3, value)
tf.shape(split0) ==> [5, 10]

Note: If you are splitting along an axis by the length of that axis, consider using unpack, e.g.

num_items = t.get_shape()[axis].value
[tf.squeeze(s, [axis]) for s in tf.split(axis, num_items, t)]

can be rewritten as
tf.unpack(t, axis=axis)

Args:
split_dim: A 0-D int32 Tensor. The dimension along which to split. Must be in the range [0, rank(value)).
num_split: A Python integer. The number of ways to split.
value: The Tensor to split.
name: A name for the operation (optional).

Returns:
num_split Tensor objects resulting from splitting value.

2. example1 - r0.9

# Example 1 (TensorFlow r0.9 argument order: split_dim, num_split, value).
# Splits a constant of shape (4, 6) into 2 pieces along dimension 0 and
# into 3 pieces along dimension 1, then prints every piece and its type.
import tensorflow as tf
import numpy as np

rows = [
    [0, 1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10, 11],
    [12, 13, 14, 15, 16, 17],
    [18, 19, 20, 21, 22, 23],
]
x = tf.constant(rows, dtype=np.float32)

# Each tf.split call returns a Python list of tensors.
y = tf.split(0, 2, x)  # 2 pieces along dimension 0
z = tf.split(1, 3, x)  # 3 pieces along dimension 1

with tf.Session() as sess:
    # The input tensor itself.
    input_x = sess.run(x)
    print("type(input_x):")
    print(type(input_x))
    print("input_x.shape:")
    print(input_x.shape)
    print('\n')

    # Pieces obtained by splitting along dimension 0.
    output_y = sess.run(y)
    print("output_y:")
    print(output_y)
    print("output_y[0-1]:")
    for step, piece in enumerate(output_y):
        print("output_y[%d]:" % (step))
        print(piece)
        print("type(output_y[%d]):%s" % (step, type(piece)))
    print('\n')

    # Pieces obtained by splitting along dimension 1.
    output_z = sess.run(z)
    print("output_z:")
    print(output_z)
    print("output_z[0-2]:")
    for step, piece in enumerate(output_z):
        print("output_z[%d]:" % (step))
        print(piece)
        print("type(output_z[%d]):%s" % (step, type(piece)))
    print('\n')

output:

type(input_x):
<type 'numpy.ndarray'>
input_x.shape:
(4, 6)


output_y:
[array([[  0.,   1.,   2.,   3.,   4.,   5.],
       [  6.,   7.,   8.,   9.,  10.,  11.]], dtype=float32), array([[ 12.,  13.,  14.,  15.,  16.,  17.],
       [ 18.,  19.,  20.,  21.,  22.,  23.]], dtype=float32)]
output_y[0-1]:
output_y[0]:
[[  0.   1.   2.   3.   4.   5.]
 [  6.   7.   8.   9.  10.  11.]]
type(output_y[0]):<type 'numpy.ndarray'>
output_y[1]:
[[ 12.  13.  14.  15.  16.  17.]
 [ 18.  19.  20.  21.  22.  23.]]
type(output_y[1]):<type 'numpy.ndarray'>


output_z:
[array([[  0.,   1.],
       [  6.,   7.],
       [ 12.,  13.],
       [ 18.,  19.]], dtype=float32), array([[  2.,   3.],
       [  8.,   9.],
       [ 14.,  15.],
       [ 20.,  21.]], dtype=float32), array([[  4.,   5.],
       [ 10.,  11.],
       [ 16.,  17.],
       [ 22.,  23.]], dtype=float32)]
output_z[0-2]:
output_z[0]:
[[  0.   1.]
 [  6.   7.]
 [ 12.  13.]
 [ 18.  19.]]
type(output_z[0]):<type 'numpy.ndarray'>
output_z[1]:
[[  2.   3.]
 [  8.   9.]
 [ 14.  15.]
 [ 20.  21.]]
type(output_z[1]):<type 'numpy.ndarray'>
output_z[2]:
[[  4.   5.]
 [ 10.  11.]
 [ 16.  17.]
 [ 22.  23.]]
type(output_z[2]):<type 'numpy.ndarray'>


Process finished with exit code 0

3. example2 - r0.9

# Example 2 (TensorFlow r0.9 argument order: split_dim, num_split, value).
# Takes a (batch_size, num_steps, num_input) constant with batch_size = 1,
# swaps the first two axes, flattens them, and splits the result into
# num_steps pieces, printing each intermediate value.
import tensorflow as tf
import numpy as np

batch_size = 1
num_steps = 6
num_input = 2

# x_anchor shape: (batch_size, n_steps, n_input)
x_anchor = tf.constant(
    [[[0, 1],
      [2, 3],
      [4, 5],
      [6, 7],
      [8, 9],
      [10, 11]]],
    dtype=np.float32)

# Swap the batch_size and num_steps axes.
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])

# Flatten to (num_steps * batch_size, num_input).
flat_shape = [num_steps * batch_size, num_input]
y_reshape = tf.reshape(y_anchor, flat_shape)

# Split data because the rnn cell needs a list of inputs for the RNN inner
# loop: n_steps * (batch_size, num_input).
# Under tf.__version__ == '1.3.0' the call would instead be written
#     tf.split(y_reshape, num_steps, 0)
# The form below is the one for tf.__version__ == '0.9.0'.
y_split = tf.split(0, num_steps, y_reshape)

with tf.Session() as sess:
    # Original input.
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print('\n')

    # After the transpose.
    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print('\n')

    # After the reshape.
    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print('\n')

    # After the split: a list with one entry per step.
    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print('\n')

    # Each entry on its own.
    print("output_split[0-5]:")
    for step, piece in enumerate(output_split):
        print("output_split[%d]:" % (step))
        print(piece)
        print("type(output_split[%d]):%s" % (step, type(piece)))
    print('\n')

    # Each entry wrapped in a one-element list.
    print("output_split[0-5]:")
    for step, piece in enumerate(output_split):
        wrapped = [piece]
        print("[output_split[%d]]:" % (step))
        print(wrapped)
        print("type([output_split[%d]]):%s" % (step, type(wrapped)))

output:

type(input_anchor):
<type 'numpy.ndarray'>
input_anchor.shape:
(1, 6, 2)


type(output_anchor):
<type 'numpy.ndarray'>
output_anchor.shape:
(6, 1, 2)
output_anchor:
[[[  0.   1.]]

 [[  2.   3.]]

 [[  4.   5.]]

 [[  6.   7.]]

 [[  8.   9.]]

 [[ 10.  11.]]]


type(output_reshape):
<type 'numpy.ndarray'>
output_reshape.shape:
(6, 2)
output_reshape:
[[  0.   1.]
 [  2.   3.]
 [  4.   5.]
 [  6.   7.]
 [  8.   9.]
 [ 10.  11.]]


type(output_split):
<type 'list'>
output_split:
[array([[ 0.,  1.]], dtype=float32), array([[ 2.,  3.]], dtype=float32), array([[ 4.,  5.]], dtype=float32), array([[ 6.,  7.]], dtype=float32), array([[ 8.,  9.]], dtype=float32), array([[ 10.,  11.]], dtype=float32)]


output_split[0-5]:
output_split[0]:
[[ 0.  1.]]
type(output_split[0]):<type 'numpy.ndarray'>
output_split[1]:
[[ 2.  3.]]
type(output_split[1]):<type 'numpy.ndarray'>
output_split[2]:
[[ 4.  5.]]
type(output_split[2]):<type 'numpy.ndarray'>
output_split[3]:
[[ 6.  7.]]
type(output_split[3]):<type 'numpy.ndarray'>
output_split[4]:
[[ 8.  9.]]
type(output_split[4]):<type 'numpy.ndarray'>
output_split[5]:
[[ 10.  11.]]
type(output_split[5]):<type 'numpy.ndarray'>


output_split[0-5]:
[output_split[0]]:
[array([[ 0.,  1.]], dtype=float32)]
type([output_split[0]]):<type 'list'>
[output_split[1]]:
[array([[ 2.,  3.]], dtype=float32)]
type([output_split[1]]):<type 'list'>
[output_split[2]]:
[array([[ 4.,  5.]], dtype=float32)]
type([output_split[2]]):<type 'list'>
[output_split[3]]:
[array([[ 6.,  7.]], dtype=float32)]
type([output_split[3]]):<type 'list'>
[output_split[4]]:
[array([[ 8.,  9.]], dtype=float32)]
type([output_split[4]]):<type 'list'>
[output_split[5]]:
[array([[ 10.,  11.]], dtype=float32)]
type([output_split[5]]):<type 'list'>

Process finished with exit code 0

4. example3 - r0.9

# Example 3 (TensorFlow r0.9 argument order: split_dim, num_split, value).
# Same pipeline as the batch_size = 1 case but with batch_size = 2: take a
# (batch_size, num_steps, num_input) constant, swap the first two axes,
# flatten them, and split the result into num_steps pieces.
import tensorflow as tf
import numpy as np

batch_size = 2
num_steps = 6
num_input = 2

# x_anchor shape: (batch_size, n_steps, n_input)
x_anchor = tf.constant(
    [[[0, 1],
      [2, 3],
      [4, 5],
      [6, 7],
      [8, 9],
      [10, 11]],
     [[12, 13],
      [14, 15],
      [16, 17],
      [18, 19],
      [20, 21],
      [22, 23]]],
    dtype=np.float32)

# Swap the batch_size and num_steps axes.
y_anchor = tf.transpose(x_anchor, perm=[1, 0, 2])

# Flatten to (num_steps * batch_size, num_input).
flat_shape = [num_steps * batch_size, num_input]
y_reshape = tf.reshape(y_anchor, flat_shape)

# Split data because the rnn cell needs a list of inputs for the RNN inner
# loop: n_steps * (batch_size, num_input).
# Under tf.__version__ == '1.3.0' the call would instead be written
#     tf.split(y_reshape, num_steps, 0)
# The form below is the one for tf.__version__ == '0.9.0'.
y_split = tf.split(0, num_steps, y_reshape)

with tf.Session() as sess:
    # Original input.
    input_anchor = sess.run(x_anchor)
    print("type(input_anchor):")
    print(type(input_anchor))
    print("input_anchor.shape:")
    print(input_anchor.shape)
    print('\n')

    # After the transpose.
    output_anchor = sess.run(y_anchor)
    print("type(output_anchor):")
    print(type(output_anchor))
    print("output_anchor.shape:")
    print(output_anchor.shape)
    print("output_anchor:")
    print(output_anchor)
    print('\n')

    # After the reshape.
    output_reshape = sess.run(y_reshape)
    print("type(output_reshape):")
    print(type(output_reshape))
    print("output_reshape.shape:")
    print(output_reshape.shape)
    print("output_reshape:")
    print(output_reshape)
    print('\n')

    # After the split: a list with one entry per step.
    output_split = sess.run(y_split)
    print("type(output_split):")
    print(type(output_split))
    print("output_split:")
    print(output_split)
    print('\n')

    # Each entry on its own.
    print("output_split[0-5]:")
    for step, piece in enumerate(output_split):
        print("output_split[%d]:" % (step))
        print(piece)
        print("type(output_split[%d]):%s" % (step, type(piece)))
    print('\n')

    # Each entry wrapped in a one-element list.
    print("output_split[0-5]:")
    for step, piece in enumerate(output_split):
        wrapped = [piece]
        print("[output_split[%d]]:" % (step))
        print(wrapped)
        print("type([output_split[%d]]):%s" % (step, type(wrapped)))

output:

type(input_anchor):
<type 'numpy.ndarray'>
input_anchor.shape:
(2, 6, 2)


type(output_anchor):
<type 'numpy.ndarray'>
output_anchor.shape:
(6, 2, 2)
output_anchor:
[[[  0.   1.]
  [ 12.  13.]]

 [[  2.   3.]
  [ 14.  15.]]

 [[  4.   5.]
  [ 16.  17.]]

 [[  6.   7.]
  [ 18.  19.]]

 [[  8.   9.]
  [ 20.  21.]]

 [[ 10.  11.]
  [ 22.  23.]]]


type(output_reshape):
<type 'numpy.ndarray'>
output_reshape.shape:
(12, 2)
output_reshape:
[[  0.   1.]
 [ 12.  13.]
 [  2.   3.]
 [ 14.  15.]
 [  4.   5.]
 [ 16.  17.]
 [  6.   7.]
 [ 18.  19.]
 [  8.   9.]
 [ 20.  21.]
 [ 10.  11.]
 [ 22.  23.]]


type(output_split):
<type 'list'>
output_split:
[array([[  0.,   1.],
       [ 12.,  13.]], dtype=float32), array([[  2.,   3.],
       [ 14.,  15.]], dtype=float32), array([[  4.,   5.],
       [ 16.,  17.]], dtype=float32), array([[  6.,   7.],
       [ 18.,  19.]], dtype=float32), array([[  8.,   9.],
       [ 20.,  21.]], dtype=float32), array([[ 10.,  11.],
       [ 22.,  23.]], dtype=float32)]


output_split[0-5]:
output_split[0]:
[[  0.   1.]
 [ 12.  13.]]
type(output_split[0]):<type 'numpy.ndarray'>
output_split[1]:
[[  2.   3.]
 [ 14.  15.]]
type(output_split[1]):<type 'numpy.ndarray'>
output_split[2]:
[[  4.   5.]
 [ 16.  17.]]
type(output_split[2]):<type 'numpy.ndarray'>
output_split[3]:
[[  6.   7.]
 [ 18.  19.]]
type(output_split[3]):<type 'numpy.ndarray'>
output_split[4]:
[[  8.   9.]
 [ 20.  21.]]
type(output_split[4]):<type 'numpy.ndarray'>
output_split[5]:
[[ 10.  11.]
 [ 22.  23.]]
type(output_split[5]):<type 'numpy.ndarray'>


output_split[0-5]:
[output_split[0]]:
[array([[  0.,   1.],
       [ 12.,  13.]], dtype=float32)]
type([output_split[0]]):<type 'list'>
[output_split[1]]:
[array([[  2.,   3.],
       [ 14.,  15.]], dtype=float32)]
type([output_split[1]]):<type 'list'>
[output_split[2]]:
[array([[  4.,   5.],
       [ 16.,  17.]], dtype=float32)]
type([output_split[2]]):<type 'list'>
[output_split[3]]:
[array([[  6.,   7.],
       [ 18.,  19.]], dtype=float32)]
type([output_split[3]]):<type 'list'>
[output_split[4]]:
[array([[  8.,   9.],
       [ 20.,  21.]], dtype=float32)]
type([output_split[4]]):<type 'list'>
[output_split[5]]:
[array([[ 10.,  11.],
       [ 22.,  23.]], dtype=float32)]
type([output_split[5]]):<type 'list'>

Process finished with exit code 0

5.