[译] TF-api(3) tf.nn.softmax_cross_entropy_with_logits

来源:互联网 发布:php清除所有cookie 编辑:程序博客网 时间:2024/05/22 06:37

tf.nn.softmax_cross_entropy_with_logits

args:

_sentinel: Used to prevent positional parameters. Internal, do not use.

从源码里面来看,这个参数的目的是不让用,因为如果你给它传了值,它会raise一个error出来。所以在传值的时候要指定logits和labels,就是帮助你别写错代码的。

labels: Each row labels[i] must be a valid probability distribution.

这是标签,一般是one hot的表示形式

logits: Unscaled log probabilities.

这个是输入tensor,也就是模型最后一层全连接的输出,shape一般是[batch_size, nb_class]。

dim: The class dimension. Defaulted to -1 which is the last dimension.

这个dim表示的是nb_class

name: A name for the operation (optional).


tf.nn.softmax_cross_entropy_with_logits()这个api实际上等同于tf.nn.softmax(),tf.log(),以及tf.reduce_sum()的组合。这个api实现的步骤主要分为三步:

  • 将input用softmax概率归一化:

    p(xj)=exjni=1exi

  • 与标签做交叉熵:

    h(xj)=ylabellogp(xj)

  • 将数据轴的数据进行加和:

    H(x)=tf.reducesum(h(x),axis=1)

    代码验证:

import tensorflow as tfdata = [[i for i in range(4)] for j in range(4)]label = [[0., 0., 1., 0.] for j in range(4)]x = tf.constant(value=data, dtype=tf.float32)y_label = tf.constant(value=label, dtype=tf.float32)H1 = tf.nn.softmax_cross_entropy_with_logits(logits=x, labels=y_label)p = tf.nn.softmax(x)h = -y_label * tf.log(p)H2 = tf.reduce_sum(h, axis=1)with tf.Session() as sess:    print(sess.run(H1))    print(sess.run(H2))

如果要计算loss的话需要再套一个tf.reduce_sum(),因为这个api返回的是一个向量,而不是一个单值。

阅读全文
0 0
原创粉丝点击