tensorflow 中tf.gather(params, indices, validate_indices=None, name=None) 函数讲解

来源:互联网 发布:ios上传图片数组 编辑:程序博客网 时间:2024/05/21 20:30

tf.gather

tf.gather(params, indices, validate_indices=None, name=None, axis=0)

params 表示你输入的张量,indices表示你想要params张量中切片的维度,所以这个函数就是挑选出params中indices对应的数。

举例子

x = tf.constant(np.arange(8).reshape((2,2,2)))y = tf.gather(x,[0])sess  = tf.Session()print(sess.run(y))print('---------')print(sess.run(x))
[[[0 1]  [2 3]]]---------[[[0 1]  [2 3]] [[4 5]  [6 7]]]

就相当把x矩阵的第一个切片取出来了
想要其他轴切片,可以设置axis这个参数

阅读全文
0 0