逻辑回归总结

来源:互联网 发布:最心酸的一句话知乎 编辑:程序博客网 时间:2024/06/14 11:58

线性回归的鲁棒性很差,在整个是实数域内,其敏感性是一致的。
逻辑回归是一种减小预测范围,将预测值限制在[0,1]内。逻辑回归可以简单的理解为,在线性回归的基础上,套用了一个逻辑函数。其通常用于分类问题。例如,输出y>0.5时,认为是正类,否则为负类。由于将输出值限定在[0.1]内,所以可以认为输出值是一个概率。

1. 理论推导

线性回归:

y=θ0+θ1x1++θnxn=θTx

逻辑函数:
y=σ(x)=11+ex

所以,逻辑回归核函数:
y=11+eθTx=hθ(x)

该表达式一个对数几率的特性(逻辑回归也称为对数几率回归)
对数几率:
ln(y1y)=θTx(1)

如果限定分类标签y为{0,1},则概率函数:
P(y|x,θ)=(hθ(x))y(1hθ(x))1y

似然函数:
L(θ)=i=1NP(y(i)|x(i),θ)

取对数:
l(θ)=ln(L(θ)) =i=1Ny(i)ln(hθ(x(i)))+(1y(i))ln(1hθ(x(i))=i=1Ny(i)ln(hθ(x(i)1hθ(x(i)))+ln(1hθ(x(i))

代入等式(1)
l(θ)=i=1Ny(i)θTx+ln(1hθ(x(i))

求梯度,首先sigmoid函数导数为:
σ(x)=σ(1σ)x

则:
l(θ)θ  =i=1Ny(i)x(i)hθ(x(i))x(i)=i=1N(y(i)y^(i))x(i)=i=1NError(x(i))x(i)

最后,更新模型权重:
θ:=θ+αl(θ)θ

2. 代码实现

我们采用python实现以上过程。其中,读取数据文件,效果评估用到了sklearn包提供 的一些工具,数据操作用到了numpy。

#!/usr/bin/env python# -*- coding: utf-8 -*-  # Copyright (c) 2017 - xiongjiezk <xiongjiezk@163.com>from sklearn import datasets, metricsimport mathimport numpy as npclass LogitRegression:    def __init__(self, grad, learn_rate, epoch):        self.weights = []        self.bias = 0.0        self.grad = grad        self.learn_rate = learn_rate        self.epoch = epoch    def sigmoid(self, x):        return 1.0 / (1 + math.exp(-x))    def activate(self, x):        poly_value = np.dot(self.weights, x) + self.bias        return self.sigmoid(poly_value)    def rmse(self, y, y_pred):        return np.sum(np.square(np.subtract(y, y_pred)))    def fit(self, X, y):        self.weights = np.random.normal(size=np.shape(X[0]))        self.bias = 1.0        for iter in range(0, self.epoch):            y_pred = []            if self.grad == 'grad_decent':                for i in range(0, len(X)):                    y_ = self.activate(X[i])                    y_pred.append(y_)                delta = np.zeros(np.shape(self.weights))                for i in range(0, len(X)):                    delta += np.array(y[i] - y_pred[i]) * np.array(X[i])                self.weights += self.learn_rate * delta                self.bias += self.learn_rate * np.sum(np.subtract(y, y_pred))            elif self.grad == 'stoc_grad_decent':                for i in range(0, len(X)):                    y_ = self.activate(X[i])                    self.weights += self.learn_rate * (y[i] - y_) * np.array(X[i])                    self.bias += self.learn_rate * (y[i] - y_)                    y_pred.append(y_)            else:                pass            loss = self.rmse(y, y_pred)            if iter % 100 == 0:                print('current epoch: %s, loss: %s' % (iter, loss))    def predict(self, X):        scores = []        for i in range(0, len(X)):            scores.append(self.activate(X[i]))        class_ = np.array([0, 1])        indices = (np.array(scores) > 0.5).astype(np.int)        return np.array(class_[indices])if __name__ == '__main__':    data_and_labels = datasets.load_svmlight_file('E:/data/logit/train.txt')    X_train = np.reshape(data_and_labels[0].data, data_and_labels[0].shape)    y_train = data_and_labels[1]    test_and_labels = datasets.load_svmlight_file('E:/data/logit/test.txt')    X_test = np.reshape(test_and_labels[0].data, test_and_labels[0].shape)    y_test = test_and_labels[1]    logit = LogitRegression('grad_decent', 0.01, 1000)    logit.fit(X_train, y_train)    y_pred = logit.predict(X_test)    print("Classification report for classifier %s:\n%s\n"          % ([logit.weights, logit.bias], metrics.classification_report(y_test, y_pred)))    print("Confusion matrix:\n%s" % metrics.confusion_matrix(y_test, np.array(y_pred)))

运行效果如下:

current epoch: 800, loss: 2.45524645114current epoch: 900, loss: 2.45569467863Classification report for classifier [array([ 0.70676676, -1.68913992]), 12.64359503469959]:             precision    recall  f1-score   support        0.0       1.00      0.88      0.93         8        1.0       0.92      1.00      0.96        12avg / total       0.95      0.95      0.95        20Confusion matrix:[[ 7  1] [ 0 12]]

训练数据文件

0    1:-0.017612    2:14.0530641    1:-1.395634    2:4.6625410    1:-0.752157    2:6.5386200    1:-1.322371    2:7.1528530    1:0.423363    2:11.0546771    1:0.406704    2:7.0673350    1:0.667394    2:12.7414521    1:-2.460150    2:6.8668050    1:0.569411    2:9.5487550    1:-0.026632    2:10.4277431    1:0.850433    2:6.9203340    1:1.347183    2:13.1755001    1:1.176813    2:3.1670200    1:-1.781871    2:9.0979531    1:-0.566606    2:5.7490031    1:0.931635    2:1.5895051    1:-0.024205    2:6.1518231    1:-0.036453    2:2.6909881    1:-0.196949    2:0.4441651    1:1.014459    2:5.7543991    1:1.985298    2:3.2306191    1:-1.693453    2:-0.5575400    1:-0.576525    2:11.7789221    1:-0.346811    2:-1.6787301    1:-2.124484    2:2.6724710    1:1.217916    2:9.5970150    1:-0.733928    2:9.0986871    1:-3.642001    2:-1.6180871    1:0.315985    2:3.5239530    1:1.416614    2:9.6192321    1:-0.386323    2:3.9892861    1:0.556921    2:8.2949840    1:1.224863    2:11.5873601    1:-1.347803    2:-2.4060511    1:1.196604    2:4.9518510    1:0.275221    2:9.5436470    1:0.470575    2:9.3324880    1:-1.889567    2:9.5426620    1:-1.527893    2:12.1505790    1:-1.185247    2:11.3093181    1:-0.445678    2:3.2973031    1:1.042222    2:6.1051550    1:-0.618787    2:10.3209861    1:1.152083    2:0.5484671    1:0.828534    2:2.6760450    1:-1.237728    2:10.5490331    1:-0.683565    2:-2.1661251    1:0.229456    2:5.9219380    1:-0.959885    2:11.5553360    1:0.492911    2:10.9933240    1:0.184992    2:8.7214880    1:-0.355715    2:10.3259760    1:-0.397822    2:8.0583970    1:0.824839    2:13.7303431    1:1.507278    2:5.0278661    1:0.099671    2:6.8358390    1:-0.344008    2:10.7174851    1:1.785928    2:7.7186450    1:-0.918801    2:11.5602171    1:-0.364009    2:4.7473001    1:-0.841722    2:4.1190831    1:0.490426    2:1.9605390    1:-0.007194    2:9.0757920    1:0.356107    2:12.4478630    1:0.342578    2:12.2811621    1:-0.810823    2:-1.4660181    1:2.530777    2:6.4768010    1:1.296683    2:11.6075590    1:0.475487    2:12.0400350    1:-0.783277    2:11.0097250    1:0.074798    2:11.0236501    1:-1.337472    2:0.4683390    1:-0.102781    2:13.7636511    1:-0.147324    2:2.8748460    1:0.518389    2:9.8870350    1:1.015399    2:7.5718821    1:-1.658086    2:-0.0272551    1:1.319944    2:2.1712281    1:2.056216    2:5.0199811    1:-0.851633    2:4.375691

测试数据文件

0    1:-1.510047    2:6.061992    1    1:-1.076637    2:-3.181888   0    1:1.821096    2:10.283990    1    1:3.010150    2:8.401766     1    1:-1.099458    2:1.688274    1    1:-0.834872    2:-1.733869   1    1:-0.846637    2:3.849075    0    1:1.400102    2:12.628781    1    1:1.752842    2:5.468166     1    1:0.078557    2:0.059736     1    1:0.089392    2:-0.715300    0    1:1.825662    2:12.693808    0    1:0.197445    2:9.744638     1    1:0.126117    2:0.922311     1    1:-0.679797    2:1.220530    1    1:0.677983    2:2.556666     0    1:0.761349    2:10.693862    1    1:-2.168791    2:0.143632    0    1:1.388610    2:9.341997     0    1:0.317029    2:14.739025    

参考资料:
https://www.cnblogs.com/sxron/p/5489214.html
http://blog.csdn.net/programmer_wei/article/details/52072939