TensorFlow softmax VS sparse softmax

来源:互联网 发布:c 遍历二维数组 编辑:程序博客网 时间:2024/05/20 06:28

sparse_softmax_cross_entropy_with_logits VS softmax_cross_entropy_with_logits

这两者都是计算分类问题的softmax loss的,所以两者的输出应该是一样的,唯一区别是两者的labels输入形似不一样。


Difference

在tensorflow中使用softmax loss的时候,会发现有两个softmax cross entropy。刚开始很难看出什么差别,结合程序看的时候,就很容易能看出两者差异。总的来说两者都是计算分类问题的softmax交叉熵损失,而两者使用的标签真值的形式不同。

  • sparse_softmax_cross_entropy_with_logits:
    使用的是实数来表示类别,数据类型为int16,int32,或者 int64,标签大小范围为[0,num_classes-1],标签的维度为[batch_size]大小。

  • softmax_cross_entropy_with_logits
    使用的是one-hot二进制码来表示类别,数据类型为float16,float32,或者float64,维度为[batch_size, num_classes]。这里需要说明一下的时,标签数据类型并不是Bool型的。这是因为实际上在tensorflow中,softmax_cross_entropy_with_logits中的每一个类别是一个概率分布,tensorflow中对该模块的说明中明确指出了这一点,Each row labels[i] must be a valid probability distribution。很显然,one-hot的二进码也可以看是一个有效的概率分布。

另外stackoverflow上面对两者的区别有一个总结说得很清楚,可以参考一下。


Common

有一点需要注意的是,softmax_cross_entropy_with_logits和sparse_softmax_cross_entropy_with_logits中的输入都需要unscaled logits,因为tensorflow内部机制会将其进行归一化操作以提高效率,什么意思呢?就是说计算loss的时候,不要将输出的类别值进行softmax归一化操作,输入就是wTX+b的结果。

tensorflow的说明是这样的:
Warning: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results.

至于为什么这样可以提高效率,简单地说就是把unscaled digits输入到softmax loss中在反向传播计算倒数时计算量更少,感兴趣的可以参考pluskid大神的博客Softmax vs. Softmax-Loss: Numerical Stability,博文里面讲得非常清楚了。另外说一下,看了大神的博文,不得不说大神思考问题和解决问题的能力真的很强!


Example

import tensorflow as tf#batch_size = 2labels = tf.constant([[0, 0, 0, 1],[0, 1, 0, 0]])logits = tf.constant([[-3.4, 2.5, -1.2, 5.5],[-3.4, 2.5, -1.2, 5.5]])loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)loss_s = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(labels,1), logits=logits)with tf.Session() as sess:      print "softmax loss:", sess.run(loss)    print "sparse softmax loss:", sess.run(loss_s)

Output:
softmax loss: [ 0.04988896 3.04988885]
sparse softmax loss: [ 0.04988896 3.04988885]


Reference

tensorflow:softmax_cross_entropy_with_logits
tensorflow:sparse_softmax_cross_entropy_with_logits
stackoverflow
Softmax vs. Softmax-Loss: Numerical Stability