LogisticRegression: Handwritten Digit Recognition

来源:互联网 发布:守望先锋安娜技能数据 编辑:程序博客网 时间:2024/06/05 07:23
# -*- coding:utf-8 -*-
"""Handwritten digit classification with PCA + softmax logistic regression.

Reads ``dataset/train.csv`` (label in the first column, pixel columns after
it) and ``dataset/test.csv`` (pixel columns only), reduces dimensionality with
PCA, tunes the regularization strength ``C`` by grid-search cross-validation,
prints diagnostics, and writes the test-set predictions to ``test.csv``.
"""
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.pipeline import make_pipeline
from sklearn import metrics
from sklearn.decomposition import PCA

TRAIN_PATH = 'dataset/train.csv'
TEST_PATH = 'dataset/test.csv'
OUTPUT_PATH = 'test.csv'

# Divisor used to bring raw pixel intensities near [0, 1]; assumes the CSV
# stores 8-bit grayscale values 0..255 — TODO confirm against the data.
PIXEL_MAX = 255.0

# Candidate regularization strengths (inverse penalty) for the grid search.
# Spread over several orders of magnitude because scaling the inputs moves the
# optimum; the original narrow grid (0.18-0.22) was tuned for raw pixels.
C_GRID = (0.01, 0.03, 0.1, 0.3, 1.0, 3.0)


def showimg(dig_csv, index):
    """Display one digit from a labelled DataFrame and print its label.

    Args:
        dig_csv: DataFrame whose rows are ``[label, pixel_0, ..., pixel_783]``
            (the training-CSV layout; rows without a leading label will not
            reshape correctly).
        index: Row label passed to ``dig_csv.loc``.
    """
    row = dig_csv.loc[index].values
    img_label = row[0]
    img_values = row[1:].reshape(28, 28)
    # Passing cmap directly avoids plt.gray(), which changes the global
    # default colormap and can open an extra empty figure before matshow()
    # creates its own.
    plt.matshow(img_values, cmap='gray')
    plt.show()
    print(img_label)


def load_data(train_path=TRAIN_PATH, test_path=TEST_PATH):
    """Load the CSVs and return ``(X_train, y_train, X_test)`` as arrays.

    Pixel values are divided by ``PIXEL_MAX``. On raw 0..255 inputs the lbfgs
    solver is badly conditioned and tends to stop at ``max_iter`` before
    converging, which is a common reason for low accuracy.
    """
    train_csv = pd.read_csv(train_path)
    test_csv = pd.read_csv(test_path)
    y_train = train_csv.values[:, 0]
    X_train = train_csv.values[:, 1:] / PIXEL_MAX
    X_test = test_csv.values / PIXEL_MAX
    return X_train, y_train, X_test


def fit_model(X_train, y_train, c_grid=C_GRID, cv=5):
    """Grid-search ``C`` for a PCA -> logistic-regression pipeline.

    PCA lives inside the pipeline so that each CV fold fits its own PCA.
    This keeps the cross-validated score from seeing the held-out fold.

    Args:
        X_train: Scaled training pixels.
        y_train: Training labels.
        c_grid: Candidate values for LogisticRegression's ``C``.
        cv: Number of cross-validation folds.

    Returns:
        The fitted ``GridSearchCV``. It is refit on all of ``X_train`` with the
        best ``C`` and can ``predict`` raw (scaled) pixel arrays directly.
    """
    # lbfgs fits a multinomial (softmax) model for multiclass targets by
    # default in scikit-learn >= 0.22. The explicit multi_class argument is
    # deprecated in recent releases, so it is omitted. max_iter is raised from
    # the default 100 so the solver can actually converge.
    # (LogisticRegressionCV is an alternative; after fitting, the chosen
    # regularization strength is exposed as its ``C_`` attribute.)
    pipeline = make_pipeline(
        PCA(n_components=0.95),  # keep components explaining 95% of variance
        LogisticRegression(solver='lbfgs', max_iter=1000),
    )
    search = GridSearchCV(
        pipeline, {'logisticregression__C': list(c_grid)}, cv=cv
    )
    search.fit(X_train, y_train)
    return search


def write_submission(predictions, path=OUTPUT_PATH):
    """Write predictions as a two-column CSV with 1-based ``ImageId``.

    The ``ImageId,Label`` header follows the Kaggle Digit Recognizer
    submission format, assumed to be the target of this script. The original
    wrote the DataFrame index plus an unnamed ``0`` column, giving the header
    ``,0``.
    """
    submission = pd.DataFrame({
        'ImageId': range(1, len(predictions) + 1),
        'Label': predictions,
    })
    submission.to_csv(path, index=False)


def main():
    """Train, report diagnostics, and write test-set predictions."""
    X_train, y_train, X_test = load_data()
    model = fit_model(X_train, y_train)

    print('best params:', model.best_params_)
    # Cross-validated accuracy is the honest generalization estimate.
    print('cv accuracy:', model.best_score_)

    # Training accuracy is optimistic: the model was fit on these samples.
    pre_y_train = model.predict(X_train)
    print('train accuracy:', metrics.accuracy_score(y_train, pre_y_train))
    print(metrics.confusion_matrix(y_train, pre_y_train))

    write_submission(model.predict(X_test))


if __name__ == '__main__':
    main()

不知为何用 softmax(多项逻辑回归)识别手写数字的准确率很低,不知道该如何调整。

原创粉丝点击