tf.concat (API r0.12 / r0.9)

Source: Internet | Editor: 程序博客网 | Date: 2024/06/05 20:54


 r0.12

1. tf.concat(concat_dim, values, name='concat')

Concatenates tensors along one dimension.
Concatenates the list of tensors values along dimension concat_dim. If values[i] has shape

$$[D_0, D_1, \ldots, D_{\mathrm{concat\_dim}}(i), \ldots, D_n],$$

the concatenated result has shape

$$[D_0, D_1, \ldots, R_{\mathrm{concat\_dim}}, \ldots, D_n]$$

where

$$R_{\mathrm{concat\_dim}} = \sum_i D_{\mathrm{concat\_dim}}(i).$$

That is, the data from the input tensors is joined along the concat_dim dimension.
The number of dimensions of the input tensors must match, and all dimensions except concat_dim must be equal.

For example:

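import tensorflow as tf  # r0.12 signature: tf.concat(concat_dim, values)

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat(0, [t1, t2])  # ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat(1, [t1, t2])  # ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]

t3 = tf.zeros([2, 3])  # tensor t3 with shape [2, 3] (values are arbitrary)
t4 = tf.ones([2, 3])   # tensor t4 with shape [2, 3] (values are arbitrary)
tf.shape(tf.concat(0, [t3, t4]))  # ==> [4, 3]
tf.shape(tf.concat(1, [t3, t4]))  # ==> [2, 6]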
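The shape rule above can be sketched in plain Python. This is a minimal illustration, not TensorFlow code: the helper name concat_shape is hypothetical, and it only handles static shapes given as lists of ints.

```python
def concat_shape(shapes, concat_dim):
    """Return the shape produced by concatenating inputs of the given shapes.

    Hypothetical helper: every dimension except concat_dim must match,
    and the output size along concat_dim is the sum of the input sizes.
    """
    if not shapes:
        raise ValueError("need at least one shape")
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("all inputs must have the same number of dimensions")
    if not 0 <= concat_dim < rank:
        raise ValueError("concat_dim out of range")
    result = list(shapes[0])
    for s in shapes[1:]:
        for d in range(rank):
            # Every dimension other than concat_dim must be equal.
            if d != concat_dim and s[d] != result[d]:
                raise ValueError(
                    "dimension %d differs: %d vs %d" % (d, s[d], result[d]))
    # R_concat_dim = sum over i of D_concat_dim(i)
    result[concat_dim] = sum(s[concat_dim] for s in shapes)
    return result


print(concat_shape([[2, 3], [2, 3]], 0))  # [4, 3]
print(concat_shape([[2, 3], [2, 3]], 1))  # [2, 6]
```

Raising on mismatched dimensions mirrors the requirement that all dimensions except concat_dim be equal.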

Note: If you are concatenating along a new axis, consider using tf.pack. For example,

tf.concat(axis, [tf.expand_dims(t, axis) for t in tensors])

can be rewritten as

tf.pack(tensors, axis=axis)
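The same equivalence can be checked with NumPy. NumPy is used here only as a stand-in, assuming np.concatenate, np.expand_dims and np.stack behave analogously to tf.concat, tf.expand_dims and tf.pack; the tensors list is made-up sample data.

```python
import numpy as np

# Made-up sample data: three arrays of shape (2, 3).
tensors = [np.arange(6).reshape(2, 3) + 10 * k for k in range(3)]
axis = 1

# Concatenate along a new axis by first inserting a size-1 dimension...
via_concat = np.concatenate([np.expand_dims(t, axis) for t in tensors], axis=axis)
# ...which is what stacking (tf.pack in TensorFlow r0.12) does directly.
via_stack = np.stack(tensors, axis=axis)

print(via_concat.shape)                       # (2, 3, 3)
print(np.array_equal(via_concat, via_stack))  # True
```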

Args:
concat_dim: 0-D int32 Tensor. Dimension along which to concatenate.
values: A list of Tensor objects or a single Tensor.
name: A name for the operation (optional).

Returns:
A Tensor resulting from concatenation of the input tensors.


2. example1 - r0.9.0

import tensorflow as tf
import numpy as np

# Two inputs, each of shape [1, 4, 3].
t1 = tf.constant([[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]], dtype=np.float32)
t2 = tf.constant([[[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]]], dtype=np.float32)

# r0.9 signature: tf.concat(concat_dim, values); concatenate along each axis.
matrix0 = tf.concat(0, [t1, t2])
matrix1 = tf.concat(1, [t1, t2])
matrix2 = tf.concat(2, [t1, t2])

# Shape of each concatenation, evaluated as a tensor.
ops_shape0 = tf.shape(tf.concat(0, [t1, t2]))
ops_shape1 = tf.shape(tf.concat(1, [t1, t2]))
ops_shape2 = tf.shape(tf.concat(2, [t1, t2]))

with tf.Session() as sess:
    input_t1 = sess.run(t1)
    print("input_t1.shape:")
    print(input_t1.shape)
    print('\n')

    input_t2 = sess.run(t2)
    print("input_t2.shape:")
    print(input_t2.shape)
    print('\n')

    output_t1 = sess.run(matrix0)
    print("output_t1.shape:")
    print(output_t1.shape)
    print("output_t1:")
    print(output_t1)
    print('\n')

    output_t2 = sess.run(matrix1)
    print("output_t2.shape:")
    print(output_t2.shape)
    print("output_t2:")
    print(output_t2)
    print('\n')

    output_t3 = sess.run(matrix2)
    print("output_t3.shape:")
    print(output_t3.shape)
    print("output_t3:")
    print(output_t3)
    print('\n')

    output_shape0 = sess.run(ops_shape0)
    output_shape1 = sess.run(ops_shape1)
    output_shape2 = sess.run(ops_shape2)
    print("output_shape0:")
    print(output_shape0)
    print("output_shape1:")
    print(output_shape1)
    print("output_shape2:")
    print(output_shape2)

output:

input_t1.shape:
(1, 4, 3)

input_t2.shape:
(1, 4, 3)

output_t1.shape:
(2, 4, 3)
output_t1:
[[[  0.   1.   2.]
  [  3.   4.   5.]
  [  6.   7.   8.]
  [  9.  10.  11.]]

 [[ 12.  13.  14.]
  [ 15.  16.  17.]
  [ 18.  19.  20.]
  [ 21.  22.  23.]]]

output_t2.shape:
(1, 8, 3)
output_t2:
[[[  0.   1.   2.]
  [  3.   4.   5.]
  [  6.   7.   8.]
  [  9.  10.  11.]
  [ 12.  13.  14.]
  [ 15.  16.  17.]
  [ 18.  19.  20.]
  [ 21.  22.  23.]]]

output_t3.shape:
(1, 4, 6)
output_t3:
[[[  0.   1.   2.  12.  13.  14.]
  [  3.   4.   5.  15.  16.  17.]
  [  6.   7.   8.  18.  19.  20.]
  [  9.  10.  11.  21.  22.  23.]]]

output_shape0:
[2 4 3]
output_shape1:
[1 8 3]
output_shape2:
[1 4 6]

Process finished with exit code 0
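The shapes in example 1 can be cross-checked without the old TensorFlow release by using NumPy's np.concatenate, which follows the same shape rule. This is NumPy, not TensorFlow; the arrays are built with np.arange to hold the same values as t1 and t2 above.

```python
import numpy as np

# Same values as example 1: two float32 arrays of shape (1, 4, 3).
t1 = np.arange(0, 12, dtype=np.float32).reshape(1, 4, 3)
t2 = np.arange(12, 24, dtype=np.float32).reshape(1, 4, 3)

# Concatenate along each axis in turn and report the resulting shape.
for axis in range(3):
    out = np.concatenate([t1, t2], axis=axis)
    print(axis, out.shape)
# 0 (2, 4, 3)
# 1 (1, 8, 3)
# 2 (1, 4, 6)
```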

3. example2 - r0.9.0

import tensorflow as tf
import numpy as np

# Two inputs, each of shape [2, 3].
t1 = tf.constant([[0, 1, 2], [3, 4, 5]], dtype=np.float32)
t2 = tf.constant([[6, 7, 8], [9, 10, 11]], dtype=np.float32)

# r0.9 signature: tf.concat(concat_dim, values); join rows (0) and columns (1).
matrix0 = tf.concat(0, [t1, t2])
matrix1 = tf.concat(1, [t1, t2])

ops_shape0 = tf.shape(tf.concat(0, [t1, t2]))
ops_shape1 = tf.shape(tf.concat(1, [t1, t2]))

with tf.Session() as sess:
    input_t1 = sess.run(t1)
    print("input_t1.shape:")
    print(input_t1.shape)
    print('\n')

    input_t2 = sess.run(t2)
    print("input_t2.shape:")
    print(input_t2.shape)
    print('\n')

    output_t1 = sess.run(matrix0)
    print("output_t1.shape:")
    print(output_t1.shape)
    print("output_t1:")
    print(output_t1)
    print('\n')

    output_t2 = sess.run(matrix1)
    print("output_t2.shape:")
    print(output_t2.shape)
    print("output_t2:")
    print(output_t2)
    print('\n')

    output_shape0 = sess.run(ops_shape0)
    output_shape1 = sess.run(ops_shape1)
    print("output_shape0:")
    print(output_shape0)
    print("output_shape1:")
    print(output_shape1)

output:

input_t1.shape:
(2, 3)

input_t2.shape:
(2, 3)

output_t1.shape:
(4, 3)
output_t1:
[[  0.   1.   2.]
 [  3.   4.   5.]
 [  6.   7.   8.]
 [  9.  10.  11.]]

output_t2.shape:
(2, 6)
output_t2:
[[  0.   1.   2.   6.   7.   8.]
 [  3.   4.   5.   9.  10.  11.]]

output_shape0:
[4 3]
output_shape1:
[2 6]

Process finished with exit code 0
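Example 2 can likewise be reproduced with NumPy's np.concatenate (NumPy, not TensorFlow). The extra array t3 is made up to illustrate the requirement that all dimensions except the concatenation axis be equal: NumPy raises ValueError when they differ.

```python
import numpy as np

# Same values as example 2: two float32 arrays of shape (2, 3).
t1 = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32)
t2 = np.array([[6, 7, 8], [9, 10, 11]], dtype=np.float32)

rows = np.concatenate([t1, t2], axis=0)  # stack rows: shape (4, 3)
cols = np.concatenate([t1, t2], axis=1)  # join columns: shape (2, 6)
print(rows.shape, cols.shape)  # (4, 3) (2, 6)

# Made-up array whose dimension 1 differs from t1 (4 vs 3).
t3 = np.zeros((2, 4), dtype=np.float32)
try:
    np.concatenate([t1, t3], axis=0)  # dimension 1 must match, so this fails
except ValueError as err:
    print("ValueError:", err)

# Along axis 1 only dimension 0 must match, so (2, 3) + (2, 4) -> (2, 7).
print(np.concatenate([t1, t3], axis=1).shape)  # (2, 7)
```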


