Tensorflow实现经典损失函数

来源:互联网 发布:mac配置jenkins 编辑:程序博客网 时间:2024/05/22 06:40

Tensorflow实现经典损失函数

神经网络模型的效果以及优化的目标是通过损失函数(Loss function)来定义的。如何判断一个输出向量和期望的向量有多接近呢?

交叉熵(cross entropy)

交叉熵刻画了两个概率分布之间的距离,它是分类问题中使用比较广的一种损失函数。给定两个概率分布p和q,通过q来表示p的交叉熵为:

H(p,q)=p(x)logq(x)

如何将神经网络前向传播得到的结果也变成概率分布呢?Softmax回归就是一个非常常用的方法。

Softmax回归本身可以作为一个学习算法来优化分类结果,但在Tensorflow中,Softmax回归的参数被去掉了,它只是一层额外的处理层,将神经网络的输出变成一个概率分布。假设原始的神经网络的输出为y1,y2,...,yn,那么经过Softmax回归处理之后的输出为:

softmax(y)i=yi=eyinj=1eyj

Tensorflow实现交叉熵

cross_entropy = -tf.reduce_mean(    y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

其中y_代表正确结果,y代表预测结果,tf.clip_by_value函数可以将一个张量中的数值限制在一个范围之内,tf.log完成对张量中所有元素依次求对数,* 实现两个矩阵元素之间直接相乘(矩阵乘法需要使用tf.matmul函数来完成)。

因为交叉熵一般会与Softmax回归一起使用,所以Tensorflow对这两个功能进行了统一封装,并提供了tf.nn.softmax_cross_entropy_with_logits函数。比如可以直接通过下面的代码来实现使用了softmax回归之后的交叉熵损失函数:

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y, y_)

均方误差(MSE)

对于回归问题,最常用的损失函数是均方误差(MSE):

MSE(y,y)=ni=1(yiyi)2n

其中yi为一个batch中第i个数据的正确答案,而y为神经网络给出的预测值。一下代码展示了如何通过Tensorflow实现均方误差损失函数:

mse = tf.reduce_mean(tf.square(y_ - y))

自定义损失函数

通过自定义损失函数的方法,我们可以使得神经网络优化的结果更加接近实际问题的需求。例如我们在预测商品销售问题中使用的损失函数:

Loss(y,y)=i=1nf(yi,yi),   f(x,y)={a(xy)b(yx)x>yxy

在Tensorflow中可以通过以下代码实现这个损失函数:

loss = tf.reduce_sum(tf.select(tf.greater(v1, v2),                     (v1 - v2) * a, (v2-v1) * b))

tf.greater的输入是两个张量,此函数会比较这两个张量中每一个元素的大小,并返回比较结果。
tf.select函数有三个参数,第一个为选择条件,当选择条件为True时,tf.select函数会选择第二个参数中的值,否则使用第三个参数中的值。
在定义了损失函数之后,下面一个简单的神经网络程序来讲解如何利用自定义损失函数:

loss_less = 1loss_more = 10loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))train_step = tf.train.AdamOptimizer(0.001).minimize(loss)with tf.Session() as sess:    init_op = tf.global_variables_initializer()    sess.run(init_op)    STEPS = 5000    for i in range(STEPS):        start = (i*batch_size) % 128        end = (i*batch_size) % 128 + batch_size        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})        if i % 1000 == 0:            print("After %d training step(s), w1 is: " % (i))            print sess.run(w1), "\n"    print "Final w1 is: \n", sess.run(w1)
原创粉丝点击