scikit-learn / logistic regression: recognizing MNIST

Source: Internet | Published on: 淘宝古玩网 | Editor: 程序博客网 | Time: 2024/06/04 23:19

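The log-likelihood loss (cost function) of logistic regression, for a single sample:
$$\mathrm{Cost}(h_\theta(x), y) = \begin{cases} -\log\left(h_\theta(x)\right) & \text{if } y = 1 \\ -\log\left(1 - h_\theta(x)\right) & \text{if } y = 0 \end{cases}$$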
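When y = 1, the sample is a positive example. If the prediction is hθ(x) = 1, the cost for this sample alone is 0, meaning the prediction is exactly right. If every sample is predicted exactly right, the total cost is 0.
If instead the predicted probability is hθ(x) = 0, the cost tends to infinity, so the loss function imposes a very large penalty on this confidently wrong prediction. The case y = 0 is symmetric: the cost is 0 when hθ(x) = 0 and tends to infinity as hθ(x) approaches 1.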
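The per-sample cost above can be sketched in plain Python as follows. The function name `sample_cost` and the clipping constant `eps` are illustrative assumptions. The clipping only keeps `log(0)` from raising an error, and it shows the "tends to infinity" behaviour as a very large finite number.

```python
import math

def sample_cost(h, y, eps=1e-15):
    """Per-sample logistic cost: -log(h) if y == 1, -log(1 - h) if y == 0.

    h is the predicted probability h_theta(x) of the positive class.
    It is clipped to [eps, 1 - eps] so log(0) is never evaluated
    (eps is an illustrative choice, not part of the original formula).
    """
    h = min(max(h, eps), 1 - eps)
    return -math.log(h) if y == 1 else -math.log(1 - h)

# y = 1 and a perfect prediction: the cost is (essentially) 0
print(sample_cost(1.0, 1))
# y = 1 but the model is almost sure the sample is negative: the cost grows without bound
print(sample_cost(0.001, 1))
print(sample_cost(1e-12, 1))
# y = 0 is the mirror image
print(sample_cost(0.0, 0))
```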
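Averaging the two cases over all m training samples, the logistic cost is defined as:
$$J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log h_\theta\left(x^{(i)}\right) + \left(1 - y^{(i)}\right)\log\left(1 - h_\theta\left(x^{(i)}\right)\right)\right]$$
where $h_\theta(x) = \frac{1}{1 + e^{-\theta^T x}}$.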
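A minimal NumPy sketch of J(θ) and its gradient, minimized with plain batch gradient descent on a made-up four-point dataset. The toy data, learning rate, and iteration count are assumptions for illustration. scikit-learn's `LogisticRegression` minimizes a regularized version of this cost with its own solvers (e.g. liblinear), not with this loop.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_cost(theta, X, y, eps=1e-15):
    """J(theta) = -(1/m) * sum(y*log(h) + (1-y)*log(1-h))."""
    h = np.clip(sigmoid(X @ theta), eps, 1 - eps)
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))

def gradient(theta, X, y):
    """dJ/dtheta = (1/m) * X^T (h - y)."""
    m = X.shape[0]
    return X.T @ (sigmoid(X @ theta) - y) / m

# Toy 1-D data with a bias column of ones: the label is 1 when x > 0
X = np.array([[1.0, -2.0], [1.0, -1.0], [1.0, 1.0], [1.0, 2.0]])
y = np.array([0.0, 0.0, 1.0, 1.0])

theta = np.zeros(2)
initial_cost = logistic_cost(theta, X, y)   # equals log(2) when theta = 0 (h = 0.5 everywhere)
for _ in range(500):
    theta -= 0.5 * gradient(theta, X, y)    # learning rate 0.5 (arbitrary choice)
final_cost = logistic_cost(theta, X, y)
predicted = (sigmoid(X @ theta) > 0.5).astype(float)
print(initial_cost, final_cost, predicted)
```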

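Logistic regression is mainly used for binary classification. To apply it to a multi-class problem, you have to construct corresponding target variables and decompose the multi-class problem into several binary problems (for example, one-vs-rest) before classifying. Softmax regression handles multi-class problems directly: for classes 1 to k, it determines which class a sample belongs to.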
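The two approaches can be sketched with NumPy as follows. `one_vs_rest_targets` builds the k binary label vectors that one-vs-rest logistic regression trains on. `softmax` turns k class scores into one probability distribution. Both helper names and the example scores are hypothetical.

```python
import numpy as np

def one_vs_rest_targets(labels, k):
    """Turn integer labels 0..k-1 into k binary target vectors.

    Column j is 1 where the sample belongs to class j and 0 elsewhere,
    i.e. the target for the j-th binary "class j vs. rest" classifier.
    """
    return (labels[:, None] == np.arange(k)[None, :]).astype(int)

def softmax(scores):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    z = scores - scores.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

labels = np.array([0, 2, 1, 2])
targets = one_vs_rest_targets(labels, k=3)
print(targets)

# Hypothetical linear scores theta_j^T x for 2 samples and 3 classes
scores = np.array([[2.0, 1.0, 0.1],
                   [0.5, 0.5, 3.0]])
probs = softmax(scores)
print(probs.sum(axis=1))     # each row sums to 1
print(probs.argmax(axis=1))  # predicted class per sample
```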
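Selection rule:
If the classes can overlap (a sample may belong to several of them), use logistic regression, with one binary classifier per class. If the classes are mutually exclusive, use softmax regression.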
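The selection rule can be illustrated with a small NumPy sketch. It assumes hypothetical scores for one sample. With overlapping classes, each independent sigmoid is thresholded on its own, so a sample can receive several labels. With mutually exclusive classes, one softmax picks exactly one label. The 0.5 threshold is an assumption.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(scores):
    e = np.exp(scores - scores.max())
    return e / e.sum()

# Hypothetical scores of one sample for 3 classes
scores = np.array([2.0, 1.5, -1.0])

# Overlapping classes: k independent logistic regressions,
# each class is accepted on its own if its probability exceeds 0.5
multi_label = sigmoid(scores) > 0.5

# Mutually exclusive classes: one softmax, pick the single most likely class
single_label = int(np.argmax(softmax(scores)))
print(multi_label, single_label)
```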

Recognizing MNIST with logistic regression

#coding:utf-8
"""
python 3
sklearn 0.18
"""
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.linear_model import LogisticRegression
import input_data  # input_data.py from the TensorFlow 1.x MNIST tutorial, placed next to this script
import pickle
import datetime

start_time = datetime.datetime.now()

mnist = input_data.read_data_sets('mnist/', one_hot=False)
x = mnist.train.images
y = mnist.train.labels

# Hold out 20% of the training images as a validation set
# (a single train/validation split, not k-fold cross-validation)
train_data, validation_data, train_labels, validation_labels = train_test_split(x, y, test_size=0.2)

# Train a LogisticRegression classifier
clf = LogisticRegression(penalty='l2', tol=0.001)
clf.fit(train_data, train_labels)

# Predict the first 1000 test images in one batched call
test_images = mnist.test.images[0:1000]
test_labels = mnist.test.labels[0:1000]
predictions = clf.predict(test_images)

# Confusion matrix (rows: true digit, columns: predicted digit)
print(confusion_matrix(test_labels, predictions))
# Per-class precision, recall and F1
print(classification_report(test_labels, predictions))
print('test accuracy is:', accuracy_score(test_labels, predictions))

# Save the trained model
with open('logistic.pickle', 'wb') as f:
    pickle.dump(clf, f)

end_time = datetime.datetime.now()
print('total time is :', (end_time - start_time).seconds)
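The `input_data` helper comes from the TensorFlow 1.x tutorials and may not be available in recent TensorFlow releases. The following is a runnable sketch of the same train/predict/evaluate pipeline. As an assumption, it swaps in scikit-learn's bundled `load_digits` dataset (8x8 digits, no download), so its accuracy is not comparable to the MNIST result. `max_iter=1000` is an assumption added so that the default solver of current scikit-learn versions converges. The `penalty='l2'` used above is already the default.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

# load_digits: 1797 8x8 grayscale digits bundled with scikit-learn
digits = load_digits()
X = digits.data / 16.0          # pixel values 0..16 scaled to [0, 1]
y = digits.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

# tol=0.001 as in the script above; max_iter raised so the solver converges
clf = LogisticRegression(tol=0.001, max_iter=1000)
clf.fit(X_train, y_train)

predictions = clf.predict(X_test)   # one batched call for the whole test set
accuracy = accuracy_score(y_test, predictions)
print(confusion_matrix(y_test, predictions))
print('test accuracy is:', accuracy)
```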

Results

The recognition accuracy on the test set (the first 1000 test images) reaches 90%.
[Figure: screenshot of the test-set results (image not available)]
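The accuracy figure and the precision/recall columns of `classification_report` can all be read off the confusion matrix. The following NumPy sketch uses a made-up 3-class matrix, not the article's results. It follows scikit-learn's convention of rows as true classes and columns as predicted classes.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows: true class, columns: predicted class)
cm = np.array([[50,  2,  3],
               [ 4, 40,  6],
               [ 1,  3, 41]])

accuracy = np.trace(cm) / cm.sum()        # correct predictions lie on the diagonal
recall = np.diag(cm) / cm.sum(axis=1)     # per class: correct / all true samples of that class
precision = np.diag(cm) / cm.sum(axis=0)  # per class: correct / all samples predicted as that class
print(accuracy, recall, precision)
```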

Optimization

#coding:utf-8
"""
python 3
scikit-learn 0.18
"""
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import input_data  # input_data.py from the TensorFlow 1.x MNIST tutorial
import datetime

start_time = datetime.datetime.now()

mnist = input_data.read_data_sets('mnist', one_hot=False)
# Use only the first 1000 training images to keep the search fast
x = mnist.train.images[0:1000, :]
y = mnist.train.labels[0:1000]
train_data, validation_data, train_labels, validation_labels = train_test_split(x, y, test_size=0.1)

# Use GridSearchCV to find the best parameters.
# solver='liblinear' (the default in 0.18) supports both 'l1' and 'l2'
clf = LogisticRegression(solver='liblinear')
penalty_options = ['l1', 'l2']
# solver_options = ['liblinear', 'newton-cg', 'lbfgs', 'sag']  # 'newton-cg', 'lbfgs', 'sag' do not support 'l1'
tol_options = [0.0001, 0.00001, 0.000001, 0.0000001]
param_options = dict(penalty=penalty_options, tol=tol_options)
gridlog = GridSearchCV(clf, param_options, cv=10, scoring='accuracy', verbose=1)
gridlog.fit(train_data, train_labels)

print('best score is: ', str(gridlog.best_score_))
print('best params are: ', str(gridlog.best_params_))

end_time = datetime.datetime.now()
total_time = (end_time - start_time).seconds
print('total time is: ', total_time)

Results

[Figure: screenshot of the grid-search output (image not available)]
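GridSearchCV tries every combination of the values in `param_options` and scores each one with 10-fold cross-validation. The following standard-library sketch only enumerates that grid with `itertools.product` to show how much work the search above does. It does not reproduce GridSearchCV's internals.

```python
from itertools import product

penalty_options = ['l1', 'l2']
tol_options = [0.0001, 0.00001, 0.000001, 0.0000001]
cv = 10

# Every combination of the listed values is one candidate model
candidates = [dict(penalty=p, tol=t) for p, t in product(penalty_options, tol_options)]
total_fits = len(candidates) * cv   # each candidate is fit once per fold
print(len(candidates), total_fits)
```

With the default `refit=True`, GridSearchCV also refits the best candidate once on all of `train_data`, so the real number of fits is one more than this count.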
