Custom Python matplotlib plot

来源:互联网 发布:samba端口号作用 编辑:程序博客网 时间:2024/06/04 17:43

Adding custom plot parameters (matplotlib rcParams for fonts, figure size, and label sizes):


"""Plot per-class precision-recall curves plus one threshold-recall curve.

Loads saved ground-truth labels and classifier scores from text files,
computes a precision-recall curve for each class in ``className`` with
scikit-learn, draws them on one figure (together with the score thresholds
of the first class against recall), saves the figure to ``tmp.pdf`` and
shows it.
"""
import matplotlib.pyplot as plt
import numpy as np
from itertools import cycle
# NOTE(review): several of the imports below are not used by this script;
# they are kept in case other code relies on them.
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier

# Load data.
# NOTE(review): presumably y_test is a binarized (one column per class) label
# matrix and y_score holds the matching per-class scores, each with at least
# len(className) columns -- confirm against the code that wrote these files.
y_test = np.loadtxt('./y_test_201708101147.txt')
y_score = np.loadtxt('./y_score_201708101147.txt')

# Display names of the classes, in column order of y_test / y_score.
className = ['s1', 'm1', 'max_s1_m1']

# Compute precision, recall and the score thresholds for each class.
precision = dict()
recall = dict()
threshold = dict()
for i in range(len(className)):
    precision[i], recall[i], threshold[i] = precision_recall_curve(
        y_test[:, i], y_score[:, i])

# Setup plot details.
params = {'legend.fontsize': 30,
          'font.family': 'sans-serif',
          'font.sans-serif': 'Calibri',
          'figure.figsize': (15, 15),
          'axes.labelsize': 40,
          'axes.titlesize': 40,
          'xtick.labelsize': 30,
          'ytick.labelsize': 30}
plt.rcParams.update(params)

# Line width of the threshold-recall curve.
lw = 2
# (matplotlib format string, marker size) for each class's PR curve,
# in the same order as className.
curveStyles = [('r^', 6), ('b.', 4), ('y--', 6)]

# Plot the Precision-Recall curve for each class.
plt.clf()
for i, (name, (fmt, size)) in enumerate(zip(className, curveStyles)):
    plt.plot(recall[i], precision[i], fmt, markersize=size,
             label='Precision-Recall curve of {0}'.format(name))

# precision_recall_curve returns one more recall value than thresholds
# (the final recall point has no threshold), so drop the last recall value
# to align the two arrays.
plt.plot(recall[0][:-1], threshold[0], color='cornflowerblue', lw=lw,
         label='Threshold-Recall curve of s1')

plt.xlim([0.0, 1.0])
# NOTE(review): thresholds share this y-axis; if y_score values are not in
# [0, 1], parts of the threshold curve will be clipped -- confirm.
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision/Threshold')
plt.title('Precision-Recall curve to multi-class')
plt.legend(loc="lower right")
plt.savefig('tmp.pdf')
plt.show()


原创粉丝点击