tf.scatter_update tf.scatter_sub

来源:互联网 发布:安广网络宽带怎么样 编辑:程序博客网 时间:2024/05/01 07:18

tf.scatter_update

scatter_update(
ref,
indices,
updates,
use_locking=None,
name=None
)

scatter_sub(
ref,
indices,
updates,
use_locking=None,
name=None
)
在源码,函数的定义的位置在 tensorflow/Python/ops/gen_state_ops.py.
参数介绍:
ref: 原来的tensor;
indices: 原来tensor中要更新的索引值,同样也 tensor; 必须int
updates: 用于替代原来tensor的tensor值,注意,这个tensor和原来的tensor的最低维度要相同。

import tensorflow as tf import numpy as np with tf.Session() as sess1:    c = tf.Variable([[1,2,0],[2,3,4]], dtype=tf.float32, name='biases')     cc = tf.Variable([[1,2,0],[2,3,4]], dtype=tf.float32, name='biases1')     ccc = tf.Variable([0,1], dtype=tf.int32, name='biases2')     #对应label的centers-diff[0--]    centers = tf.scatter_sub(c,ccc,cc)    #centers = tf.scatter_sub(c,[0,1],cc)      #centers = tf.scatter_sub(c,[0,1],[[1,2,0],[2,3,4]])    #centers = tf.scatter_sub(c,[0,0,0],[[1,2,0],[2,3,4],[1,1,1]])    #即c[0]-[1,2,0] \ c[0]-[2,3,4]\ c[0]-[1,1,1],updates要减完:indices与updates元素个数相同    a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])      b = tf.scatter_update(a, [0, 1], [[1, 1, 0, 0], [1, 0, 4, 0]])      #b = tf.scatter_update(a, [0, 1,0], [[1, 1, 0, 0], [1, 0, 4, 0],[1, 1, 0, 1]])     init = tf.global_variables_initializer()     sess1.run(init)    print(sess1.run(centers))    print(sess1.run(b))[[ 0.  0.  0.] [ 0.  0.  0.]][[1 1 0 0] [1 0 4 0]][[-3. -4. -5.] [ 2.  3.  4.]][[1 1 0 1] [1 0 4 0]]