[matplotlib] 绘制Cross-Validation的误差图

来源:互联网 发布:手机卡能注册几个淘宝 编辑:程序博客网 时间:2024/06/05 11:49

概述:

在调整模型参数的时候,往往会进行交叉验证(Cross-Validation)。绘制交叉验证的误差图。

数据:

k是需要调整的参数, 从k_choices中选取

k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]

假设经过验证以后k_to_accuracies字典里保存了k取不同值时多次验证的准确性:

k_to_accuracies = {    1: [0.24, 0.23, 0.24, 0.25, 0.29],    3: [0.17, 0.23, 0.32, 0.22, 0.23],    5: [0.12, 0.21, 0.27, 0.19, 0.18],    8: [0.13, 0.23, 0.26, 0.16, 0.2],    10: [0.16, 0.18, 0.24, 0.16, 0.19],    12: [0.17, 0.19, 0.24, 0.2, 0.26],    15: [0.17, 0.23, 0.19, 0.12, 0.14],     20: [0.12, 0.17, 0.19, 0.12, 0.2],    50: [0.2, 0.16, 0.17, 0.16, 0.14],     100: [0.16, 0.15, 0.19, 0.19, 0.19],}

绘图

绘图的代码如下:

for k in k_choices:  accuracies = k_to_accuracies[k]  plt.scatter([k] * len(accuracies), accuracies)# plot the trend line with error bars that correspond to standard deviationaccuracies_mean = np.array([np.mean(v) for k,v in sorted(k_to_accuracies.items())])accuracies_std = np.array([np.std(v) for k,v in sorted(k_to_accuracies.items())])plt.errorbar(k_choices, accuracies_mean, yerr=accuracies_std)plt.title('Cross-validation on k')plt.xlabel('k')plt.ylabel('Cross-validation accuracy')plt.show()

这里写图片描述

0 0