tf.nn.sparse_softmax_cross_entropy_with_logits()
来源:互联网 发布:java 响应式编程 编辑:程序博客网 时间:2024/05/29 07:21
参考官方文档
format:sparse_softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=None)
Args:
_sentinel: Used to prevent positional parameters. Internal, do not use.(这个参数一般不用)
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
`labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
must be an index in `[0, num_classes)`. Other values will raise an
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.(这个labels参数要注意,它的shape必须是`[d_0, d_1, ..., d_{r-1}]'(而参数logits的shape是[d_0, d_1, ..., d_{r-1},num_classes]`其中差别自行体会) 数据类型必须是int32或者int64,且在labels中的每个值必须是在[0,num_classes),否则在这个操作运行在cpu的时候将会出现exception,运行在GPU的时候将会返回'NaN',而不是返回loss了,这个情况我遇到过,所以在看到输出的不是loss值而是‘NaN’时就应该仔细检查一下这个函数中的labels有没有符合条件)
logits: Unscaled log probabilities of shape
`[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
name: A name for the operation (optional).(这里注意shape和数据类型必须是float32和float64,一般很容易搞错,而tf.nn.softmax_cross_entropy_with_logits求 数据类型可以是float16 ,`float32` or `float64`.)
Returns:
A `Tensor` of the same shape as `labels` and of the same type as `logits`
with the softmax cross entropy loss.(返回值要和tensor相同的shape和labels相同的数据类型)
这个函数和tf.nn.softmax_cross_entropy_with_logits函数比较明显的区别在于它的参数labels的不同,这里的参数label是非稀疏表示的,比如表示一个3分类的一个样本的标签,稀疏表示的形式为[0,0,1]这个表示这个样本为第3个分类,而非稀疏表示就表示为2(因为从0开始算,0,1,2,就能表示三类),同理[0,1,0]就表示样本属于第二个分类,而其非稀疏表示为1。tf.nn.sparse_softmax_cross_entropy_with_logits()比tf.nn.softmax_cross_entropy_with_logits多了一步将labels稀疏化的操作。因为深度学习中,图片一般是用非稀疏的标签的,所以用tf.nn.sparse_softmax_cross_entropy_with_logits()的频率比tf.nn.softmax_cross_entropy_with_logits高。
栗子
import tensorflow as tf #our NN's output logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]]) #step1:do softmax y=tf.nn.softmax(logits) #true label #注意这里标签必须是浮点数,不然在后面计算tf.multiply时就会因为类型不匹配tf_log的float32数据类型而出错y_=tf.constant([[0,0,1.0],[0,0,1.0],[0,0,1.0]])#这个是稀疏的标签#step2:do log tf_log=tf.log(y)#step3:do mult pixel_wise_mult=tf.multiply(y_,tf_log)#step4:do cross_entropy cross_entropy = -tf.reduce_sum(pixel_wise_mult) #do cross_entropy just two step #将标签稠密化dense_y=tf.arg_max(y_,1)cross_entropy2_step1=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=dense_y,logits=logits)cross_entropy2_step2=tf.reduce_sum(cross_entropy2_step1)#dont forget tf.reduce_sum()!! with tf.Session() as sess: y_value,tf_log_value,pixel_wise_mult_value,cross_entropy_value=sess.run([y,tf_log,pixel_wise_mult,cross_entropy]) sparse_cross_entropy2_step1_value,sparse_cross_entropy2_step2_value=sess.run([cross_entropy2_step1,cross_entropy2_step2]) print("step1:softmax result=\n%s\n"%(y_value)) print("step2:tf_log_result result=\n%s\n"%(tf_log_value)) print("step3:pixel_mult=\n%s\n"%(pixel_wise_mult_value)) print("step4:cross_entropy result=\n%s\n"%(cross_entropy_value)) print("Function(softmax_cross_entropy_with_logits) result=\n%s\n"%(sparse_cross_entropy2_step1_value)) print("Function(tf.reduce_sum) result=\n%s\n"%(sparse_cross_entropy2_step2_value))
- tf.nn.sparse_softmax_cross_entropy_with_logits()
- tf.nn.sparse_softmax_cross_entropy_with_logits
- tf.nn.sparse_softmax_cross_entropy_with_logits的用法
- 对比两个函数tf.nn.softmax_cross_entropy_with_logits和tf.nn.sparse_softmax_cross_entropy_with_logits
- tf.nn.sparse_softmax_cross_entropy_with_logits()函数的用法
- tensorflow-BatchNormalization(tf.nn.moments及tf.nn.batch_normalization)
- tf.nn
- tensorflow学习:tf.nn.softmax_cross_entropy_with_logits()
- tensorflow(二):tf.nn.conv2d
- tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)
- tf.nn.top_k() tf.nn.in_top_k()
- tf.nn.top_k() tf.nn.in_top_k()
- tf.nn.conv2d 实例
- tf.nn.max_pool 实例
- tf.nn.atrous_conv2d 实例
- tf.nn.conv2d()
- tf.nn.in_top_k()
- tf.nn.ctc_loss
- spark之13:提交应用的方法(spark-submit)
- Tree packing opentrains
- 1:角谷猜想(程序设计与算法(一)第四周测验(2017夏季))
- 数据切分概述
- 初试注解 自定义实现FindViewById
- tf.nn.sparse_softmax_cross_entropy_with_logits()
- 给图片添加文字水印
- scala中的泛型
- MyBatis 易错点 面试点
- ssh安全只允许用户从指定的IP登陆
- spring配置文件中xsd引用问题
- Android 文件类型
- Android学习(四)Service的学习(上)
- java使用redis发布和订阅消息