TensorFlow argmax Function Explained

Source: Internet  Published in: Programmer Development Software  Editor: 程序博客网  Time: 2024/05/19 06:19
The code below is NumPy's argmax for masked arrays (numpy.ma.MaskedArray.argmax). It returns the position of the maximum value along a given axis:

def argmax(self, axis=None, fill_value=None, out=None):
    """
    Returns array of indices of the maximum values along the given axis.
    Masked values are treated as if they had the value fill_value.

    Parameters
    ----------
    axis : {None, integer}
        If None, the index is into the flattened array, otherwise along
        the specified axis
    fill_value : {var}, optional
        Value used to fill in the masked values.  If None, the output of
        maximum_fill_value(self._data) is used instead.
    out : {None, array}, optional
        Array into which the result can be placed. Its type is preserved
        and it must be of the right shape to hold the output.

    Returns
    -------
    index_array : {integer_array}

    Examples
    --------
    >>> a = np.arange(6).reshape(2,3)
    >>> a.argmax()
    5
    >>> a.argmax(0)
    array([1, 1, 1])
    >>> a.argmax(1)
    array([2, 2])
    """
    if fill_value is None:
        fill_value = maximum_fill_value(self._data)
    d = self.filled(fill_value).view(ndarray)
    return d.argmax(axis, out=out)

The TensorFlow example below makes this clearer:
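As a rough illustration of what the masked argmax above does, here is a pure-Python sketch for 2-D nested lists. The function name masked_argmax and its mask argument (True marks a masked entry) are hypothetical. The sketch also simplifies one detail: it uses float('-inf') as the default fill value, whereas maximum_fill_value picks the smallest value representable by the array's dtype. Either way, masked entries can never win the comparison.

```python
from typing import List, Optional, Sequence, Union

Matrix = List[List[float]]


def masked_argmax(data: Matrix,
                  mask: Optional[List[List[bool]]] = None,
                  axis: Optional[int] = None,
                  fill_value: Optional[float] = None) -> Union[int, List[int]]:
    """Hypothetical sketch of masked argmax on a 2-D nested list."""
    if fill_value is None:
        # Simplified stand-in for maximum_fill_value(): nothing real beats -inf.
        fill_value = float("-inf")

    # Step 1: replace masked entries with fill_value (like self.filled(fill_value)).
    filled = [[fill_value if (mask is not None and mask[i][j]) else x
               for j, x in enumerate(row)]
              for i, row in enumerate(data)]

    def first_max_index(values: Sequence[float]) -> int:
        # max() returns the first maximal item, so ties go to the lowest index.
        return max(range(len(values)), key=lambda k: values[k])

    if axis is None:
        # Index into the row-major flattened array.
        return first_max_index([x for row in filled for x in row])
    if axis == 0:
        # One index per column: compare entries down each column.
        return [first_max_index(col) for col in zip(*filled)]
    if axis == 1:
        # One index per row: compare entries across each row.
        return [first_max_index(row) for row in filled]
    raise ValueError("axis must be None, 0 or 1 for a 2-D input")
```

With a = [[0, 1, 2], [3, 4, 5]], the calls with axis=None, 0 and 1 give 5, [1, 1, 1] and [2, 2], matching the docstring examples. If the 5 is masked, the flattened result becomes 4.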

tf.argmax | tf.argmin

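tf.argmax(input=tensor, axis=axis) finds the position (index) of the maximum value of the given tensor along the specified axis. tf.argmin does the same for the minimum. The result has that axis removed and is int64 by default. If axis is omitted, TensorFlow reduces along axis 0. TensorFlow 1.x also accepts dimension as a deprecated alias for axis.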
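The axis semantics can be sketched in plain Python for a 2-D nested list. argmax_2d and argmin_2d are hypothetical helpers, not TensorFlow functions:

- axis=0 compares entries down each column, giving one index per column.
- axis=1 compares entries across each row, giving one index per row.

In both cases the reduced axis disappears from the result. Ties resolve to the first occurrence in this sketch, which NumPy also guarantees. The TensorFlow documentation, however, says the index returned for ties is not guaranteed.

```python
from typing import Callable, List, Sequence


def _first_best(values: Sequence[float], better: Callable[[float, float], bool]) -> int:
    # Scan left to right and move only on a strictly better value,
    # so ties resolve to the first occurrence.
    best = 0
    for k in range(1, len(values)):
        if better(values[k], values[best]):
            best = k
    return best


def _lines(m: List[List[float]], axis: int) -> List[Sequence[float]]:
    if axis == 0:
        return list(zip(*m))  # columns: compare down each column
    if axis == 1:
        return m              # rows: compare across each row
    raise ValueError("axis must be 0 or 1 for a 2-D input")


def argmax_2d(m: List[List[float]], axis: int) -> List[int]:
    """Hypothetical stand-in for tf.argmax(m, axis) on a 2-D nested list."""
    return [_first_best(line, lambda x, y: x > y) for line in _lines(m, axis)]


def argmin_2d(m: List[List[float]], axis: int) -> List[int]:
    """Hypothetical stand-in for tf.argmin(m, axis) on a 2-D nested list."""
    return [_first_best(line, lambda x, y: x < y) for line in _lines(m, axis)]


m = [[3, 7, 1],
     [9, 2, 8]]
print(argmax_2d(m, 0), argmax_2d(m, 1))  # [1, 0, 1] [1, 0]
print(argmin_2d(m, 0), argmin_2d(m, 1))  # [0, 1, 0] [2, 1]
```

A 2x3 input thus yields 3 indices for axis=0 and 2 indices for axis=1.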

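import tensorflow as tf  # TensorFlow 1.x API (graph mode with sessions)

# A 3x4 variable filled with uniform random values in [-1, 1)
a = tf.get_variable(name='a',
                    shape=[3, 4],
                    dtype=tf.float32,
                    initializer=tf.random_uniform_initializer(minval=-1, maxval=1))
b = tf.argmax(input=a, axis=0)  # index of the max in each column -> shape [4]
c = tf.argmax(input=a, axis=1)  # index of the max in each row    -> shape [3]

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())  # replaces the deprecated tf.initialize_all_variables()

# Output from one run; the values are random, so yours will differ.
print(sess.run(a))
# [[ 0.04261756 -0.34297419 -0.87816691 -0.15430689]
#  [ 0.18663144  0.86972666 -0.06103253  0.38307118]
#  [ 0.84588599 -0.45432305 -0.39736366  0.38526249]]
print(sess.run(b))
# [2 1 1 2]
print(sess.run(c))
# [0 1 0]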
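To see where [2 1 1 2] and [0 1 0] come from, the printed matrix can be reduced by hand. The sketch below copies the values from that single random run into a plain Python list and does not need TensorFlow. index_of_max is a hypothetical helper.

```python
# Values copied from the printed output of sess.run(a) (one random draw).
a = [
    [0.04261756, -0.34297419, -0.87816691, -0.15430689],
    [0.18663144, 0.86972666, -0.06103253, 0.38307118],
    [0.84588599, -0.45432305, -0.39736366, 0.38526249],
]


def index_of_max(values):
    # Position of the largest element (the first one if there is a tie).
    return max(range(len(values)), key=lambda k: values[k])


# axis=0: reduce over rows, giving one answer per column -> 4 results
b = [index_of_max(col) for col in zip(*a)]
# axis=1: reduce over columns, giving one answer per row -> 3 results
c = [index_of_max(row) for row in a]

print(b)  # [2, 1, 1, 2]
print(c)  # [0, 1, 0]
```

Column 3 is the closest call: 0.38526249 in row 2 edges out 0.38307118 in row 1.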
