实习点滴(11)--TensorFlow快速计算“多分类问题”的混淆矩阵以及精确率、召回率、F1值、准确率

来源:互联网 发布:jquery与js的区别 编辑:程序博客网 时间:2024/06/07 17:42

        在机器学习中,我们会利用一些指标(混淆矩阵、精确率、召回率、F1值、准确率)来判断我们模型的好坏,从而改进优化模型。下面介绍如何在TensorFlow下快速计算这些指标。

        1、混淆矩阵

        confusion_matrix = tf.contrib.metrics.confusion_matrix(labels_pred_all, labels_all, num_classes=None, dtype=tf.int32, name=None, weights=None)        confusion_matrix = sess.run(confusion_matrix)

        因为第一步所计算出来的混淆矩阵是一个Tensor,所以需要进行转换。

        具体api详解:

        https://haosdent.gitbooks.io/tensorflow-document/content/api_docs/python/contrib.metrics.html#confusion_matrix

        值得注意的是:所计算出来的混淆矩阵,列是真实值(也就是期望值),行是预测值

        2、四大指标:

        有了混淆矩阵,计算四大指标就好办了。

        accu = [0,0,0,0,0]        column = [0,0,0,0,0]        line = [0,0,0,0,0]        accuracy = 0        recall = 0        precision = 0        for i in range(0,5):            accu[i] = confusion_matrix[i][i]        for i in range(0,5):           for j in range(0,5):               column[i]+=confusion_matrix[j][i]        for i in range(0,5):           for j in range(0,5):               line[i]+=confusion_matrix[i][j]        for i in range(0,5):            accuracy += float(accu[i])/len_labels_all        for i in range(0,5):            if column[i] != 0:                recall+=float(accu[i])/column[i]        recall = recall / 5        for i in range(0,5):            if line[i] != 0:                precision+=float(accu[i])/line[i]        precision = precision / 5        f1_score = (2 * (precision * recall)) / (precision + recall)
阅读全文
1 0
原创粉丝点击