在 Python 中使用 Scikit-learn 的 SVM 算法

来源:互联网 发布:开淘宝网店的流程视频 编辑:程序博客网 时间:2024/05/21 14:46
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 21 12:51:31 2017

@author: Administrator

Learning-curve experiment: train scikit-learn's ``SVC`` on growing leading
fractions of the training set and plot the resulting test-set accuracy.
"""
from sklearn.ensemble import RandomForestClassifier  # NOTE(review): unused here
import numpy as np
import pandas as pd
import xgboost as xgb  # NOTE(review): unused here; the model below is an SVC
from pandas import DataFrame
import matplotlib.pyplot as plt
from sklearn.svm import SVC


def _read_whitespace_table(path):
    """Load a headerless, whitespace-delimited text file as a DataFrame.

    Uses a whitespace-regex separator, the documented replacement for the
    deprecated ``delim_whitespace=True`` option.  ``index_col=0`` moves the
    first field of every row into the index.
    """
    # NOTE(review): for the X_* feature files, index_col=0 means the first
    # field becomes the row index rather than a feature -- confirm the files
    # really start with a row-id column, otherwise one feature is dropped.
    return pd.read_table(path, header=None, encoding='gb2312',
                         sep=r'\s+', index_col=0)


def _label_vector(labels):
    """Return *labels* as a 1-D NumPy array.

    A DataFrame contributes its first column; anything else (e.g. a Series)
    is converted as-is.
    """
    if isinstance(labels, pd.DataFrame):
        labels = labels.iloc[:, 0]
    return np.asarray(labels)


def random_tree_predict_accurate(m, n, train_data, train_label,
                                 predict_data, predict_label):
    """Fit an SVC on a leading slice of the training data; return test accuracy.

    Despite its name, this trains scikit-learn's ``SVC`` with default
    hyper-parameters, not a random forest.

    Parameters
    ----------
    m : float
        Fraction of training rows to use: the first ``int(m * n_rows)`` rows.
    n : float
        Fraction of feature columns to use: the first ``int(n * n_cols)``
        columns, applied to both ``train_data`` and ``predict_data``.
    train_data, predict_data : pandas.DataFrame
        Feature matrices for training and evaluation.
    train_label, predict_label : pandas.DataFrame or pandas.Series
        Class labels; for a DataFrame the first column is used.  Neither
        argument is modified.

    Returns
    -------
    float
        Fraction of ``predict_data`` rows whose predicted label equals the
        corresponding (positional) entry of ``predict_label``.

    Raises
    ------
    ValueError
        If ``m``/``n`` select no rows or no columns, if ``predict_label`` is
        empty, or if prediction and label lengths differ.
    """
    n_rows = int(m * train_data.shape[0])
    n_cols = int(n * train_data.shape[1])
    if n_rows < 1 or n_cols < 1:
        raise ValueError(
            'm and n must select at least one training row and one feature')

    x_train = train_data.iloc[:n_rows, :n_cols]
    # SVC expects a 1-D target; a one-column DataFrame triggers a warning.
    y_train = _label_vector(train_label)[:n_rows]
    # Predict on the same leading feature columns the model was trained on;
    # using every column fails as soon as n < 1.
    x_test = predict_data.iloc[:, :n_cols]
    y_true = _label_vector(predict_label)
    if y_true.size == 0:
        raise ValueError('predict_label is empty')

    clf = SVC()
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_test)
    if y_pred.shape[0] != y_true.shape[0]:
        raise ValueError(
            'predict_data and predict_label have different lengths')
    # Positional comparison: no dependence on index alignment.
    return float(np.mean(y_pred == y_true))


def main():
    """Run the learning-curve experiment on the four text files in the CWD."""
    train_data = _read_whitespace_table('X_train.txt')
    predict_data = _read_whitespace_table('X_test.txt')
    # Each label file has one field per line, which index_col=0 moves into
    # the index; reset_index() turns it back into a column, and "- 1" shifts
    # every label down by one (presumably 1-based class ids -- confirm).
    train_label = _read_whitespace_table('y_train.txt').reset_index() - 1
    predict_label = _read_whitespace_table('y_test.txt').reset_index() - 1

    # Fraction of training rows: 0.1, 0.2, ..., 1.0.  Built from integers so
    # no floating-point error accumulates and 1.0 is hit exactly.
    fractions = [k / 10.0 for k in range(1, 11)]
    sample_counts = []
    accurate_vary = []
    for m in fractions:
        print(m)  # the fraction actually used for this run
        sample_counts.append(int(m * train_data.shape[0]))
        # n=1: always use every feature column.
        accurate_vary.append(random_tree_predict_accurate(
            m, 1, train_data, train_label, predict_data, predict_label))

    plt.xlabel('number of data')
    plt.ylabel('accurate')
    # Plot against the real number of training samples to match the x label.
    plt.plot(sample_counts, accurate_vary,
             color="red", linewidth=2.5, linestyle="-")
    plt.show()


if __name__ == '__main__':
    main()

参考文献：
http://www.cnblogs.com/harvey888/p/5852687.html
http://blog.csdn.net/zouxy09/article/details/17292011

原创粉丝点击