Logistic回归实现鸢尾花分类

来源:互联网 发布:2016淘宝刷单怎么操作 编辑:程序博客网 时间:2024/05/04 03:29

分类效果:


数据集示例:

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa

#coding:utf-8
"""Classify the Iris dataset with logistic regression and plot the result.

Reads ``10.iris.data`` (comma-separated, no header row, class name in
column 4). It encodes the class names as integers and fits a
StandardScaler + LogisticRegression pipeline on the first two features
(sepal length and width). It then draws the predicted class regions
together with the samples.
"""
import numpy as np
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn import preprocessing
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline


def encode_labels(data):
    """Split the raw Iris frame into features and integer class labels.

    Args:
        data: DataFrame read with ``header=None``. Columns 0-3 hold the
            numeric measurements and column 4 holds the class name.

    Returns:
        ``(x, y, iris_types)``:

        - ``x`` is a float array of columns 0-3.
        - ``y`` is an int column vector (shape ``(n, 1)``) of class indices.
        - ``iris_types`` lists the class names in first-appearance order,
          so ``iris_types[k]`` is the name of label ``k``.

        ``data`` itself is not modified.
    """
    iris_types = data[4].unique()
    # DataFrame.set_value() no longer exists (removed in pandas 1.0).
    # Map each name to its index instead.
    label_of = {name: i for i, name in enumerate(iris_types)}
    x = data.iloc[:, :4].values.astype(float)
    y = data[4].map(label_of).values.astype(int).reshape(-1, 1)
    return x, y, iris_types


def plot_decision_regions(model, x, y):
    """Show the model's predicted class over a grid covering ``x``'s range.

    Args:
        model: fitted classifier whose ``predict`` accepts a ``(k, 2)`` array.
        x: ``(n, 2)`` array of sepal length / sepal width.
        y: integer class labels of ``x``, used to colour the sample points.
    """
    n_cols, n_rows = 500, 500  # grid resolution along column 0 / column 1
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # range of column 0
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # range of column 1
    t1 = np.linspace(x1_min, x1_max, n_cols)
    t2 = np.linspace(x2_min, x2_max, n_rows)
    x1, x2 = np.meshgrid(t1, t2)  # grid sample points
    x_test = np.stack((x1.ravel(), x2.ravel()), axis=1)  # one row per grid point

    # The axis labels and title are Chinese, so pick a CJK-capable font.
    mpl.rcParams['font.sans-serif'] = [u'simHei']
    # Draw minus signs with an ASCII hyphen instead of U+2212.
    mpl.rcParams['axes.unicode_minus'] = False
    cm_light = mpl.colors.ListedColormap(['#77E0A0', '#FF8080', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])

    # Predicted class for every grid point, reshaped back to the grid layout.
    y_grid = model.predict(x_test).reshape(x1.shape)
    plt.figure(facecolor='w')
    plt.pcolormesh(x1, x2, y_grid, cmap=cm_light)  # predicted regions
    plt.scatter(x[:, 0], x[:, 1], c=np.ravel(y), edgecolors='k', s=50,
                cmap=cm_dark)  # the samples themselves
    plt.xlabel(u'花萼长度', fontsize=14)
    plt.ylabel(u'花萼宽度', fontsize=14)
    plt.xlim(x1_min, x1_max)
    plt.ylim(x2_min, x2_max)
    plt.grid()  # show grid lines
    plt.title(u'鸢尾花Logistic回归分类效果 - 标准化', fontsize=17)
    plt.show()


def main(path='10.iris.data'):
    """Load the data at ``path``, fit the model, report accuracy and plot."""
    data = pd.read_csv(path, header=None)
    x, y, iris_types = encode_labels(data)
    print('iris_types----')
    print(iris_types)
    print('data[4]----')
    print(data[4])
    print('data----')
    print(data)
    print('----')
    print('i,type')
    for i, iris_type in enumerate(iris_types):
        print(i, iris_type)

    # Use only the first two feature columns so the result can be drawn in 2-D.
    x = x[:, :2]
    lr = Pipeline([('sc', StandardScaler()), ('clf', LogisticRegression())])
    lr.fit(x, y.ravel())
    y_hat = lr.predict(x)
    print(u'准确度:%.2f%%' % (100 * np.mean(y_hat == y.ravel())))

    plot_decision_regions(lr, x, y)


if __name__ == "__main__":
    main()


完整代码下载地址:http://download.csdn.net/detail/hb707934728/9810808

0 0
原创粉丝点击