tf.nn.softmax_cross_entropy_with_logits()
来源:互联网 发布:centos 7服务器版安装 编辑:程序博客网 时间:2024/05/18 15:28
参考官方文档
format:softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, dim=-1, name=None)
Args:
_sentinel: Used to prevent positional parameters. Internal, do not use.(一般不用)
labels: Each row `labels[i]` must be a valid probability distribution.
logits: Unscaled log probabilities.
dim: The class dimension. Defaulted to -1 which is the last dimension.
name: A name for the operation (optional).
Returns:
A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the(也就是说返回的一个1维度的tensor,且它的数据类型和参数logits 一致)
`logits` and `labels` must have the same shape `[batch_size, num_classes]`and the same dtype (either `float16`, `float32`, or `float64`).(强调labels和logits必须有相同的数据类型和相同的shape,不然就会提示错误)
第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],num_classes就是分类的数量。单样本的话,大小就是num_classes
第二个参数labels:实际的标签,它的shape同上
顾名思义,具体的执行流程大概分为两步:
第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率),多样本的话就是输出[batchsize,num_classes]大小的矩阵
softmax的公式是:
第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵cross_entropy,公式如下:
其中 指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)
就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值显而易见,预测 越准确,结果的值越小(别忘了前面还有负号),当输入是单样本的向量[Y1,Y2,Y3...],则显然经过交叉熵之后得到一个数,而如果输入是[batchsize,num_classes]矩阵的话,则输出是batchsize大小的向量,但是得到的并不是我们想要的loss,还要做一个队向量求平均才是我们想要的loss。所以还要有个tf.reduce_sum的操作,最后才得到一个loss的数值。
为了有个感性认识,把这个函数里的计算步骤分步执行,如下
import tensorflow as tf #our NN's output logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]]) #step1:do softmax y=tf.nn.softmax(logits) #true label #注意这里标签必须是浮点数,不然在后面计算tf.multiply时就会因为类型不匹配tf_log的float32数据类型而出错y_=tf.constant([[0,0,1.0],[0,0,1.0],[0,0,1.0]])#step2:do log tf_log=tf.log(y)#step3:do mult pixel_wise_mult=tf.multiply(y_,tf_log)#step4:do cross_entropy cross_entropy = -tf.reduce_sum(pixel_wise_mult) #do cross_entropy just two step cross_entropy2_step1=tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=logits)cross_entropy2_step2=tf.reduce_sum(cross_entropy2_step1)#dont forget tf.reduce_sum()!! with tf.Session() as sess: y_value,tf_log_value,pixel_wise_mult_value,cross_entropy_value=sess.run([y,tf_log,pixel_wise_mult,cross_entropy]) cross_entropy2_step1_value,cross_entropy2_step2_value=sess.run([cross_entropy2_step1,cross_entropy2_step2]) print("step1:softmax result=\n%s\n"%(y_value)) print("step2:tf_log_result result=\n%s\n"%(tf_log_value)) print("step3:pixel_mult=\n%s\n"%(pixel_wise_mult_value)) print("step4:cross_entropy result=\n%s\n"%(cross_entropy_value)) print("Function(softmax_cross_entropy_with_logits) result=\n%s\n"%(cross_entropy2_step1_value)) print("Function(tf.reduce_sum) result=\n%s\n"%(cross_entropy2_step2_value))
得到的结果为
显然,把tf.nn.softmax_cross_entropy_with_logits()分步骤执行,得到了相同的结果都是1,22282。但还是要强调,两个参数labels和logits必须有相同的数据类型和相同的shape,不然就会提示错误
- tf.nn.softmax_cross_entropy_with_logits
- tf.nn.softmax_cross_entropy_with_logits()
- tf.nn.softmax_cross_entropy_with_logits
- tf.nn.softmax_cross_entropy_with_logits
- tf.nn.softmax_cross_entropy_with_logits的用法
- [译] TF-api(3) tf.nn.softmax_cross_entropy_with_logits
- tensorflow源码 tf.nn.softmax_cross_entropy_with_logits & tf.nn.sparse_softmax_cross_entropy_with_log
- 对比两个函数tf.nn.softmax_cross_entropy_with_logits和tf.nn.sparse_softmax_cross_entropy_with_logits
- 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法
- tf.nn.softmax_cross_entropy_with_logits()笔记及交叉熵
- 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法
- Tensorflow函数:tf.nn.softmax_cross_entropy_with_logits 讲解
- 交叉熵tf.nn.softmax_cross_entropy_with_logits的用法
- TensorFlow 介绍 tf.nn.softmax_cross_entropy_with_logits 的用法
- TensorFlow学习---tf.nn.softmax_cross_entropy_with_logits的用法
- tensorflow学习:tf.nn.softmax_cross_entropy_with_logits()
- 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法
- 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法
- 【FZU
- android在java代码中绘制矩形框
- 中缀式变后缀式
- Android Studio 如何导入aar包
- 安装latexdiff
- tf.nn.softmax_cross_entropy_with_logits()
- 如何用eclipse创建你的第一个servlet小程序
- JS--Array的常用方法map、reduce、filter、forEach、indexOf
- oracle biee 12c windows install
- js获取当前时间
- 电子技术相关网站
- c++ char与二进制互转
- nyoj--290--动物统计加强版
- Uvalive 7503