libsvm在python下的使用及用绘制ROC曲线

来源:互联网 发布:淘宝卖大米的营销策略 编辑:程序博客网 时间:2024/04/30 12:18

libsvm在python下的使用简单示例

1. 首先是训练svm,libsvm的环境配置问题简要说明一下,假设你已经正确安装好python,numpy(推荐使用Anaconda,这个集成环境直接安装好了python和numpy,还有一些其他很多很常见的包,比如matplotlib,省去了配置环境的麻烦,然后你就可以写python程序了,如果你不想再命令行里面写,可以下个pycharm 它是python开发的IDE)。

2. 假设你下载好了libsvm,解压到X:\libsvm-3.18,X代表你的盘符,比如C,D.然后请特别确认一下,你的OS,python和libsvm要么都是32位(x86),要么都是64位(x64),然后假入你的电脑是win7,然后将X:\libsvm-3.18\windows下的libsvm.dll复制到你的系统盘目录C:\Windows\System32下,这样就可以在python下导入X:\libsvm-3.18\python下的两个svm.py,svmutil.py文件了。如果你的电脑是win8 ,那个libsvm.dll是不能用的,还得自己重新生成dll。我用的办法是打开vs2012 x64本机工具命令提示,切换到libsvm目录下,使用下面命令就可以生成在X:\libsvm-3.18\windows\libsvm.dll了 命令是:nmake -f Makefile clean all

3. 假设你已经知道libsvm在python下要训练的数据格式了,额,还是简要介绍一下libsvm在python下的训练数据格式吧。设x代表你的训练数据,那么,它应该是一个列表,形如x=[[],[],...,[]]嵌套在里面的列表表示你的训练数据,举个例子,x=[[1,2],[3,4],[5,6]]说明一共有3(len(x))个训练样本,第一个训练样本是1,2,第二个训练样本是3,4,第三个训练样本是5,6。接下来,y是你的训练样本的标签,格式为y=[].比如,举个例子,y=[1,-1,1],说明你的第一个训练样本来自第一类,第二个来自第二类,第三个来自第一类。

4. libsvm训练,我按照如下方式训练,当然不止这一种吧

<span style="font-size:18px;">problem = svm_problem(label, feat)model = svm_train(problem,parameter)svm_save_model(self.result_save_dir+'/'+model_save_path,model)</span>

这里,label相当于y,feat相当于x,parameter是svm_train的字符串训练参数,比如'-c 4 -t 0 -b 1' ,-t 0表示我用的是线性核函数,-c 4表示松弛变量的惩罚因子是4, -b 1表示我要求test输出每个测试用例属于每个类的概率,这个参数的请在python命令行下导入svm,svmutil后输入help(svm_train)

5. libsvm测试,接刚才上面的那个例子:

<span style="font-size:18px;">p_label,p_acc,p_val = svm_predict(label,test_features, model,'-b 1')</span>

这里的model就是你训练时候保存的那个model.,这里为了后面绘制ROC曲线,必须在svm_predict后面加上'-b 1',p_val返回的就是我们想要的那个概率。

 ROC的绘制步骤及在python下的相关代码

1. 结合上面的示例,首先取得我们需要的两个参数,一个数测试样本真实的label,一个是样本属于正样本的概率。

<span style="font-size:18px;">decision = [x for x,y in p_val]#Get the first element for each list embeded in p+valroc_param = [[i,j] for i,j in zip(label,decision)]roc_param.sort(key = itemgetter(1),reverse=True)</span>

2. 接下来直接给出python绘制ROC的函数

<span style="font-size:18px;">def plotroc(roc_param,labels):    fpr = []    tpr = []    numCurves = len(roc_param)    colorTable = ['#FFC125','#FFFF00','#FF83FA','#FF1493','#F08080',                  '#C1FFC1','#BF3EFF','#969696','#0D0D0D']    print 'Maximum supported number of colors is 10!'    print 'draw ROC curve....'    assert(numCurves<=len(colorTable))    nPositive = len(roc_param[0])/2    nNegative = len(roc_param[0])/2    i = 0    while i < numCurves:        tp = 0.0        fp = 0.0        tempTpr = []        tempFpr = []        for x in roc_param[i]:            if x[0] >0:                tp += 1.0            else:                fp += 1.0            tempTpr.append(tp/float(nPositive))            tempFpr.append(fp/float(nNegative))        tpr.append(tempTpr)        fpr.append(tempFpr)        i += 1    assert len(tpr) == numCurves    assert len(fpr) == numCurves    for i in xrange(numCurves):        plt.plot(fpr[i],tpr[i],color=colorTable[i],linewidth=2,label=labels[i])    plt.xlabel("FPR", fontsize=14)    plt.ylabel("TPR", fontsize=14)    plt.title("ROC Curve", fontsize=14)    x = [0.0, 1.0]    plt.plot(x, x, linestyle='dashed', color='red', linewidth=2, label='random')    plt.xlim(0.0, 1.0)    plt.ylim(0.0, 1.0)    plt.legend(fontsize=10, loc='lower right')    plt.tight_layout()    plt.savefig('ROCCurve.png')</span>


简要解释一下上述的代码,roc_param就是我们刚才的roc_param,label就是测试样本真实的label。一定注意,这里绘制ROC个人在这里理解的是必须是两类问题,正样本或者负样本。这里,我的这个函数能够绘制多条ROC曲线在同一张图上。也就是说,roc_param是一个列表,每一个元素代表一条ROC曲线所需要的数据。label同理。比如,roc_param=[[data1],[data2]],表示我要绘制两条ROC曲线,一条的数据是data1,一条的是data2,只是这里data1可能也是一个列表。

0 0
原创粉丝点击