tf.reduce_sum (API r1.3)

Source: Internet  Editor: 程序博客网  Time: 2024/05/29 03:43


1. tf.reduce_sum

reduce_sum(
    input_tensor,
    axis=None,
    keep_dims=False,
    name=None,
    reduction_indices=None
)

Defined in tensorflow/python/ops/math_ops.py.
See the guide: Math > Reduction

Computes the sum of elements across dimensions of a tensor.

Reduces input_tensor along the dimensions given in axis. Unless keep_dims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keep_dims is true, the reduced dimensions are retained with length 1.

If axis is None, all dimensions are reduced, and a tensor with a single element is returned.
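The shape rule described above can be checked without TensorFlow. The sketch below is a hypothetical pure-Python helper, reduced_shape (not part of any TensorFlow API), that computes only the output shape of the reduction. It assumes non-negative axis indices and treats axis=None as "reduce every dimension".

```python
def reduced_shape(shape, axis=None, keep_dims=False):
    """Output shape of a sum reduction (hypothetical helper, not a TF API).

    shape: tuple of dimension sizes of the input tensor.
    axis: None (reduce every dimension), an int, or a list of ints.
          Assumes non-negative indices.
    keep_dims: if True, reduced dimensions are kept with length 1.
    """
    rank = len(shape)
    if axis is None:
        axes = set(range(rank))      # None means: reduce all dimensions
    elif isinstance(axis, int):
        axes = {axis}
    else:
        axes = set(axis)

    if keep_dims:
        # Each reduced dimension survives with length 1, so the rank is unchanged.
        return tuple(1 if i in axes else d for i, d in enumerate(shape))
    # Otherwise each reduced dimension is dropped: the rank shrinks by len(axes).
    return tuple(d for i, d in enumerate(shape) if i not in axes)


print(reduced_shape((2, 3)))                     # ()
print(reduced_shape((2, 3), 0))                  # (3,)
print(reduced_shape((2, 3), 1, keep_dims=True))  # (2, 1)
print(reduced_shape((2, 3), [0, 1]))             # ()
```

An empty tuple stands for a rank-0 (scalar) result, which is what "a tensor with a single element" amounts to here.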

For example:


# 'x' is [[1, 1, 1],
#         [1, 1, 1]]
tf.reduce_sum(x) ==> 6
tf.reduce_sum(x, 0) ==> [2, 2, 2]
tf.reduce_sum(x, 1) ==> [3, 3]
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
tf.reduce_sum(x, [0, 1]) ==> 6
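To make the example concrete without a TensorFlow install, here is a minimal pure-Python sketch that reproduces the five results above on a 2-D nested list. reduce_sum_2d is a hypothetical helper written for illustration only; it handles rank-2 input and is not how TensorFlow implements the op.

```python
def reduce_sum_2d(x, axis=None, keep_dims=False):
    """Sum a 2-D nested list along `axis` (hypothetical helper, not a TF API)."""
    # Normalise axis to a set of dimension indices; None means "all".
    if axis is None:
        axes = {0, 1}
    elif isinstance(axis, int):
        axes = {axis}
    else:
        axes = set(axis)

    if axes == {0, 1}:
        total = sum(sum(row) for row in x)
        return [[total]] if keep_dims else total
    if axes == {0}:
        col_sums = [sum(col) for col in zip(*x)]  # one sum per column
        return [col_sums] if keep_dims else col_sums
    if axes == {1}:
        row_sums = [sum(row) for row in x]        # one sum per row
        return [[s] for s in row_sums] if keep_dims else row_sums
    raise ValueError("axis must be None, 0, 1, or a list of those")


x = [[1, 1, 1],
     [1, 1, 1]]
print(reduce_sum_2d(x))                     # 6
print(reduce_sum_2d(x, 0))                  # [2, 2, 2]
print(reduce_sum_2d(x, 1))                  # [3, 3]
print(reduce_sum_2d(x, 1, keep_dims=True))  # [[3], [3]]
print(reduce_sum_2d(x, [0, 1]))             # 6
```

Summing along axis 0 collapses the rows (one result per column), while axis 1 collapses the columns (one result per row); keep_dims only wraps the same numbers in an extra length-1 level.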

Args:
input_tensor: The tensor to reduce. Should have numeric type.
axis: The dimensions to reduce. If None (the default), reduces all dimensions.
keep_dims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.

Returns:
The reduced tensor.




numpy compatibility:
Equivalent to np.sum
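As a rough cross-check of this equivalence, the sketch below repeats the earlier all-ones example with np.sum, assuming NumPy is installed. Note the spelling differences: NumPy's flag is keepdims (TF r1.3 uses keep_dims), and several axes are passed to NumPy as a tuple.

```python
import numpy as np

# The 2x3 all-ones tensor from the documented example.
x = np.ones((2, 3), dtype=np.int32)

print(np.sum(x))                               # 6
print(np.sum(x, axis=0))                       # [2 2 2]
print(np.sum(x, axis=1))                       # [3 3]
print(np.sum(x, axis=1, keepdims=True).shape)  # (2, 1)
print(np.sum(x, axis=(0, 1)))                  # 6
```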

2. Example 1

import tensorflow as tf
import numpy as np

t1 = tf.constant([[0, 1, 2], [3, 4, 5]], dtype=np.float32)

rs0 = tf.reduce_sum(t1)                     # sum of all elements
rs1 = tf.reduce_sum(t1, 0)                  # sum along axis 0 (one sum per column)
rs2 = tf.reduce_sum(t1, 1)                  # sum along axis 1 (one sum per row)
rs3 = tf.reduce_sum(t1, 1, keep_dims=True)  # like rs2, but axis 1 is kept with length 1
rs4 = tf.reduce_sum(t1, [0, 1])             # sum over both axes, same as rs0

with tf.Session() as sess:
    input_t1 = sess.run(t1)
    print("input_t1.shape:")
    print(input_t1.shape)
    print("input_t1:")
    print(input_t1)
    print('\n')

    output0 = sess.run(rs0)
    print("output0.shape:")
    print(output0.shape)
    print("output0:")
    print(output0)
    print('\n')

    output1 = sess.run(rs1)
    print("output1.shape:")
    print(output1.shape)
    print("output1:")
    print(output1)
    print('\n')

    output2 = sess.run(rs2)
    print("output2.shape:")
    print(output2.shape)
    print("output2:")
    print(output2)
    print('\n')

    output3 = sess.run(rs3)
    print("output3.shape:")
    print(output3.shape)
    print("output3:")
    print(output3)
    print('\n')

    output4 = sess.run(rs4)
    print("output4.shape:")
    print(output4.shape)
    print("output4:")
    print(output4)

output:

input_t1.shape:
(2, 3)
input_t1:
[[ 0.  1.  2.]
 [ 3.  4.  5.]]

output0.shape:
()
output0:
15.0

output1.shape:
(3,)
output1:
[ 3.  5.  7.]

output2.shape:
(2,)
output2:
[  3.  12.]

output3.shape:
(2, 1)
output3:
[[  3.]
 [ 12.]]

output4.shape:
()
output4:
15.0

Process finished with exit code 0
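output2 and output3 hold the same numbers; only the shape differs. A common reason to keep the reduced axis is broadcasting the result back against the input. The sketch below uses NumPy as a stand-in for TensorFlow (an assumption, so it runs without a TF 1.x session) and divides each row of t1 by its own sum. Dividing by the (2,) result instead would fail, because a (2, 3) array does not broadcast against shape (2,).

```python
import numpy as np

# Same values as t1 in the example above.
t1 = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32)

row_sums = np.sum(t1, axis=1)                      # shape (2,), like output2
row_sums_kept = np.sum(t1, axis=1, keepdims=True)  # shape (2, 1), like output3

# Same numbers, different shapes: the kept axis is just a length-1 dimension.
assert np.array_equal(row_sums_kept, row_sums.reshape(2, 1))

# The (2, 1) shape broadcasts against the (2, 3) input, so each row is
# divided by its own sum; every row of `fractions` then sums to about 1.
fractions = t1 / row_sums_kept
print(fractions.sum(axis=1))
```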
